Vova Bilyachat

Melbourne, Australia

How to create a Spring Boot Reactive (WebFlux) WebFilter and protect your service with Active Directory

09 May 2021

The reactive and Servlet stacks take slightly different approaches to the same functionality. In this post I want to show how to implement a simple filter the reactive way. WebFilter is the contract for intercepting web requests; it can be used for security, rewriting headers and so on.

What is this post about?

  • Register Active Directory App
  • Implement a reactive WebFilter in Spring Boot
  • Validate an Active Directory JWT token

Register AD App

For this example, we need to create two AD apps: one for our protected microservice (Service A) and one for the consumer service. To do so:

  1. Go to Active Directory.
  2. Register a new app: App registrations -> New registration.
  3. Register Service A.
  4. Register the Application ID URI: click "Add an Application ID URI", then click “Set” on the next screen.
  5. Then create the consumer app and create a client secret for it (save the secret, since you will need it to make requests).

Get Active Directory token

To get an AD JWT token, run the following request (Note: that if you convert)

Replace TENANT_ID, CLIENT_SECRET, CLIENT_ID and APP_ID_URI with your own values.

wget --quiet -O - \
  --method POST \
  --timeout=0 \
  --header 'Content-Type: application/x-www-form-urlencoded' \
  --body-data 'client_secret=CLIENT_SECRET&grant_type=client_credentials&client_id=CLIENT_ID&resource=APP_ID_URI' \
  'https://login.microsoftonline.com/TENANT_ID/oauth2/token'
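If the consumer service is itself written in Java, the same client-credentials request can be built with the JDK's java.net.http client. This is a minimal sketch: TokenRequests is a hypothetical helper, and its parameters stand in for the TENANT_ID, CLIENT_ID, CLIENT_SECRET and APP_ID_URI placeholders. It targets the same v1 endpoint as the wget command, which takes a "resource" parameter (the v2 endpoint uses "scope" instead).

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpRequest;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical helper that builds the same client-credentials request as the wget command.
final class TokenRequests {
    private TokenRequests() {}

    // Encodes parameters as application/x-www-form-urlencoded, keeping the map's order.
    static String formBody(Map<String, String> params) {
        return params.entrySet().stream()
                .map(e -> URLEncoder.encode(e.getKey(), StandardCharsets.UTF_8)
                        + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8))
                .collect(Collectors.joining("&"));
    }

    // Builds a POST to the v1 token endpoint of the given tenant.
    static HttpRequest clientCredentials(String tenantId, String clientId,
                                         String clientSecret, String appIdUri) {
        Map<String, String> params = new LinkedHashMap<>();
        params.put("grant_type", "client_credentials");
        params.put("client_id", clientId);
        params.put("client_secret", clientSecret);
        // The v1 endpoint identifies the target API with "resource" (the App ID URI)
        params.put("resource", appIdUri);
        return HttpRequest.newBuilder()
                .uri(URI.create("https://login.microsoftonline.com/" + tenantId + "/oauth2/token"))
                .header("Content-Type", "application/x-www-form-urlencoded")
                .POST(HttpRequest.BodyPublishers.ofString(formBody(params)))
                .build();
    }
}
```

The request can then be sent with HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()); the JSON response carries the token in its access_token field. URL-encoding each value matters because secrets and App ID URIs (api://...) contain reserved characters.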


Spring Boot

Generate a project with WebFlux, then add the following dependency. We need it to decode and validate the JWT token:

        <dependency>
            <groupId>org.springframework.security</groupId>
            <artifactId>spring-security-oauth2-jose</artifactId>
            <version>5.4.6</version>
        </dependency>

Web filter

package com.vob.webflux.webfilter.filter;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.jwt.*;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors;

@Component
public class AdAuthNFilter implements WebFilter {
    private static final String AUTH_HEADER = "X-Server-Authorization";
    public static final String HEADER_PREFIX = "Bearer ";
    private final ReactiveJwtDecoder jwtDecoder;

    public AdAuthNFilter(@Value("${jwt.iss}") String issuer,
                         @Value("${jwt.aud}") String aud,
                         @Value("${jwt.jwk-uri}") String jwkUrl) {
        // Verifies token signatures with the keys published at the JWK Set URI
        NimbusReactiveJwtDecoder decoder = NimbusReactiveJwtDecoder.withJwkSetUri(jwkUrl).build();
        // After the signature check, validate audience, issuer and exp/nbf timestamps
        decoder.setJwtValidator(new DelegatingOAuth2TokenValidator<>(
                new JwtAudValidator(aud),
                new JwtIssuerValidator(issuer),
                new JwtTimestampValidator()));
        this.jwtDecoder = decoder;
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        return Mono
                // defer() also turns exceptions thrown synchronously by decode() into error signals
                .defer(() -> {
                    String token = resolveToken(exchange.getRequest());
                    if (!StringUtils.hasText(token)) {
                        return Mono.<Jwt>error(new BadJwtException("Authorisation token is missing or is not a Bearer token"));
                    }
                    return jwtDecoder.decode(token);
                })
                // Handle token errors BEFORE calling the chain, so errors raised by downstream
                // handlers are not turned into 401 responses. After writing the 401 response
                // the Mono completes empty, which skips the flatMap below.
                .onErrorResume(JwtValidationException.class, err -> handleError(exchange, err).then(Mono.<Jwt>empty()))
                .onErrorResume(err -> handleError(exchange, err).then(Mono.<Jwt>empty()))
                .flatMap(tokenJwt -> chain.filter(exchange));
    }

    private Mono<Void> handleError(ServerWebExchange exchange, JwtValidationException ex) {
        // A JwtValidationException can carry several validation errors; return all of them
        return writeResponse(exchange, ex.getErrors().stream()
                .map(OAuth2Error::getDescription)
                .collect(Collectors.joining(", ")));
    }

    private Mono<Void> handleError(ServerWebExchange exchange, Throwable ex) {
        return writeResponse(exchange, ex.getMessage());
    }

    private Mono<Void> writeResponse(ServerWebExchange exchange, String message) {
        ServerHttpResponse response = exchange.getResponse();
        response.setStatusCode(HttpStatus.UNAUTHORIZED);
        // The body is a plain string, so declare text/plain rather than application/json
        response.getHeaders().setContentType(MediaType.TEXT_PLAIN);
        // Some exceptions have no message; avoid a NullPointerException in that case
        String body = message != null ? message : "Unauthorized";
        return response.writeWith(
                Flux.just(response.bufferFactory().wrap(body.getBytes(StandardCharsets.UTF_8))));
    }

    private String resolveToken(ServerHttpRequest request) {
        String bearerToken = request.getHeaders().getFirst(AUTH_HEADER);
        if (StringUtils.hasText(bearerToken) && bearerToken.startsWith(HEADER_PREFIX)) {
            return bearerToken.substring(HEADER_PREFIX.length()).trim();
        }
        return "";
    }
}
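resolveToken() only accepts the exact prefix "Bearer ". HTTP authentication scheme names are case-insensitive (RFC 7235), so a client sending "bearer ..." would be rejected. A more lenient, JDK-only variant is sketched below; BearerHeaders and extractToken are hypothetical names, not part of Spring, and resolveToken could delegate to it.

```java
import java.util.Optional;

// Hypothetical helper: extracts the token from a "Bearer <token>" header value,
// matching the scheme name case-insensitively.
final class BearerHeaders {
    private static final String SCHEME = "Bearer ";

    private BearerHeaders() {}

    static Optional<String> extractToken(String headerValue) {
        // regionMatches returns false if the value is shorter than the scheme prefix
        if (headerValue == null
                || !headerValue.regionMatches(true, 0, SCHEME, 0, SCHEME.length())) {
            return Optional.empty();
        }
        String token = headerValue.substring(SCHEME.length()).trim();
        // "Bearer " with nothing after it is treated as a missing token
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }
}
```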

What does it do?

  • Creates a JWT decoder that also validates the audience, the issuer and the token timestamps (exp and nbf).
  • Expects the X-Server-Authorization header to carry a Bearer JWT token. If the token is missing or invalid, a 401 HTTP status is returned with an error message.
    • If the token is valid, we step into flatMap and return chain.filter(exchange).
    • All JwtValidationException errors fall into .onErrorResume(JwtValidationException.class, ...), because a single JwtValidationException can contain multiple errors and I want to return them all.
    • Other errors go into the catch-all .onErrorResume(err -> ...).
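To make the validation step more concrete, here is a simplified model of the claim checks the decoder's validator chain performs once the signature has been verified. It is not Spring's implementation: ClaimChecks, its parameters and its error messages are hypothetical. It mirrors two behaviours of the real setup, as I understand them: all checks run and every failure is collected (as DelegatingOAuth2TokenValidator does), and exp/nbf are compared with a clock-skew allowance (JwtTimestampValidator defaults to 60 seconds).

```java
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of audience, issuer and timestamp validation on already-decoded claims.
final class ClaimChecks {
    private ClaimChecks() {}

    static List<String> validate(String iss, List<String> aud, Instant exp, Instant nbf,
                                 String expectedIss, String expectedAud,
                                 Instant now, Duration clockSkew) {
        List<String> errors = new ArrayList<>();
        // A missing aud claim is treated as a failure, as in JwtAudValidator
        if (aud == null || !aud.contains(expectedAud)) {
            errors.add("The aud claim is not valid");
        }
        if (!expectedIss.equals(iss)) {
            errors.add("The iss claim is not valid");
        }
        // Expired only if exp lies before now minus the allowed skew
        if (exp != null && now.minus(clockSkew).isAfter(exp)) {
            errors.add("Token expired at " + exp);
        }
        // Not yet valid only if nbf lies after now plus the allowed skew
        if (nbf != null && now.plus(clockSkew).isBefore(nbf)) {
            errors.add("Token not valid before " + nbf);
        }
        return errors;
    }
}
```

Collecting every failure is what makes it worthwhile to join all errors in the JwtValidationException handler: a token with the wrong audience that has also expired reports both problems at once.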

How does a WebFilter work?

To pass the request on to the next WebFilter, return chain.filter(exchange). To stop and return a response yourself, set the response status code and complete the response:

exchange.getResponse().setStatusCode(HttpStatus.BAD_REQUEST);
return exchange.getResponse().setComplete();
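The pass-or-stop contract can be illustrated with a deliberately simplified, synchronous model. This is not Spring's API (the real one is reactive, returns Mono<Void> and works on a ServerWebExchange); SimpleFilter, SimpleChain and FilterChainModel are hypothetical names, and requests and responses are plain strings.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical, synchronous stand-in for WebFilterChain
interface SimpleChain {
    // Passes the request to the next filter, or to the handler after the last filter
    String next(String request);
}

// Hypothetical, synchronous stand-in for WebFilter
interface SimpleFilter {
    // Either calls chain.next(request) to continue, or returns a response to stop
    String filter(String request, SimpleChain chain);
}

final class FilterChainModel {
    private FilterChainModel() {}

    // Runs filters[index..] in order and finally the handler; each filter receives
    // a chain that represents "everything after me".
    static String run(List<SimpleFilter> filters, int index, String request,
                      Function<String, String> handler) {
        if (index == filters.size()) {
            return handler.apply(request);
        }
        SimpleChain rest = next -> run(filters, index + 1, next, handler);
        return filters.get(index).filter(request, rest);
    }
}
```

For example, an authentication filter written as (req, chain) -> req.startsWith("token:") ? chain.next(req) : "401 Unauthorized" lets requests with a token reach the handler and answers all others itself, without the handler ever running.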

To write a custom response message instead:

return exchange
        .getResponse()
        .writeWith(Flux.just(exchange.getResponse().bufferFactory().wrap("String message".getBytes(StandardCharsets.UTF_8))));

To return JSON, serialize your object to a string or bytes and set the Content-Type header to application/json.
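As an illustration of the JSON case, here is a minimal, dependency-free serializer for an error body of the form {"errors":[...]}. ErrorBodies and the errors field name are hypothetical. In a Spring Boot WebFlux application Jackson's ObjectMapper is normally on the classpath and is the better choice for real payloads; this sketch only shows what has to happen (quoting and escaping) before the bytes are handed to writeWith.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical helper that renders a list of error messages as {"errors":["..."]}
final class ErrorBodies {
    private ErrorBodies() {}

    static String toJson(List<String> errors) {
        return errors.stream()
                .map(ErrorBodies::quote)
                .collect(Collectors.joining(",", "{\"errors\":[", "]}"));
    }

    // Wraps a string in quotes, escaping quote, backslash and control characters
    private static String quote(String s) {
        StringBuilder sb = new StringBuilder("\"");
        for (char c : s.toCharArray()) {
            switch (c) {
                case '"': sb.append("\\\""); break;
                case '\\': sb.append("\\\\"); break;
                case '\n': sb.append("\\n"); break;
                case '\r': sb.append("\\r"); break;
                case '\t': sb.append("\\t"); break;
                default:
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.append('"').toString();
    }
}
```

The resulting string could be passed to writeResponse in place of the raw message, together with a Content-Type of application/json.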

Audience validator

I haven't found an audience validator in spring-security-oauth2-jose, so I wrote my own.

package com.vob.webflux.webfilter.filter;

import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.security.oauth2.jwt.Jwt;

import java.util.List;

public class JwtAudValidator implements OAuth2TokenValidator<Jwt> {
    private final String aud;
    private final OAuth2Error error;

    public JwtAudValidator(String aud) {
        this.aud = aud;
        // "invalid_token" is the RFC 6750 error code for a token that fails validation
        this.error = new OAuth2Error("invalid_token", "The aud claim is not valid",
                "https://tools.ietf.org/html/rfc6750#section-3.1");
    }

    @Override
    public OAuth2TokenValidatorResult validate(Jwt jwt) {
        // getAudience() returns null when the token has no aud claim
        List<String> audiences = jwt.getAudience();
        if (audiences != null && audiences.contains(aud)) {
            return OAuth2TokenValidatorResult.success();
        }
        return OAuth2TokenValidatorResult.failure(this.error);
    }
}
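The rule behind the audience check comes from RFC 7519, section 4.1.3: the aud claim may be a single case-sensitive string or an array of strings. Spring's Jwt.getAudience() should already normalize the claim to a list, which is why the validator can simply call contains. The sketch below applies the same rule to a raw claim value; AudienceRule is a hypothetical helper, and treating a missing claim as a failure is this post's policy rather than an RFC requirement.

```java
import java.util.Collection;

// Hypothetical helper: does a raw "aud" claim value include the expected audience?
final class AudienceRule {
    private AudienceRule() {}

    static boolean matches(Object audClaim, String expected) {
        // Single-audience form: "aud": "api://..."; comparison is case-sensitive
        if (audClaim instanceof String) {
            return audClaim.equals(expected);
        }
        // General form: "aud": ["api://...", ...]
        if (audClaim instanceof Collection) {
            return ((Collection<?>) audClaim).contains(expected);
        }
        // Missing claim (null) or an unexpected type never matches
        return false;
    }
}
```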

Properties

jwt.jwk-uri=https://login.microsoftonline.com/common/discovery/keys
jwt.iss=https://sts.windows.net/TENANT_ID/
jwt.aud=api://fb05b788-bea9-45ee-831d-886271460ed0

Demo

Error

Error service call

Success

Successful service call

GitHub

https://github.com/vovikdrg/webflux-webfilter-example