spring webclient circuit breaker

14 December 2019

spring webclient를 이용해 간단한 circuit breaker를 만들어보도록 하겠습니다.
먼저 아래와 같이 ConcurrentHashMap을 이용해 WebClient가 호출한 host별 응답 status 통계를 만들고, 그 통계를 바탕으로 circuit을 open할 host를 선정하는 컴포넌트를 만들어줍니다.

package org.shashaka.io;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Component;

import javax.annotation.PostConstruct;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Keeps per-host response statistics and the set of hosts whose circuit is open.
 *
 * <p>Counters are stored under keys of the form {@code <url>_<statusCode>}, where
 * {@code url} is whatever the caller passes to {@link #addHostStatus(String, int)}.
 * A host becomes a candidate for opening when its count of "down" responses
 * (502, 503, 504) exceeds {@code circuit.open.threshold} and the share of those
 * responses exceeds {@code circuit.open.rate}.
 *
 * <p>Both collections are concurrent ones, so recording and evaluation may happen
 * on different threads.
 */
@Component
@Slf4j
public class CircuitBreakerUtils {

    /** Separates the host part from the status code inside an accumulator key. */
    private static final char KEY_SEPARATOR = '_';

    /** Response counters keyed by {@code <url>_<statusCode>}. */
    private final ConcurrentHashMap<String, AtomicLong> hostAccumulator = new ConcurrentHashMap<>();

    /** Hosts for which requests are currently rejected. */
    private final ConcurrentHashMap.KeySetView<String, Boolean> openedHost = ConcurrentHashMap.newKeySet();

    /** Number of down responses a host must exceed before it can be opened. */
    @Value("${circuit.open.threshold}")
    private Long threshold;

    /** Share of down responses (0.0 - 1.0) a host must exceed before it can be opened. */
    @Value("${circuit.open.rate}")
    private Double rate;

    /**
     * Resets all state. The collections are created eagerly at construction time,
     * so this only has to empty them.
     */
    @PostConstruct
    public void init() {
        hostAccumulator.clear();
        openedHost.clear();
    }

    /**
     * Counts one response with the given status for the given host.
     *
     * @param url        host key, e.g. {@code http://example.com}
     * @param statusCode HTTP status code of the response
     */
    public void addHostStatus(String url, int statusCode) {
        String key = url + KEY_SEPARATOR + statusCode;
        // compute() keeps "create or increment" atomic with respect to concurrent
        // removals performed by clearHostAccumulator()/addOpendHost().
        hostAccumulator.compute(key, (k, counter) -> {
            if (counter == null) {
                return new AtomicLong(1);
            }
            counter.incrementAndGet();
            return counter;
        });
    }

    /**
     * Evaluates the collected counters and returns the hosts whose circuit should be opened.
     *
     * @return hosts whose down-response count is above the threshold and whose
     *         down-response rate is above the configured rate; empty if there are none
     */
    public List<String> getTargetHostForOpen() {
        Map<String, Long> totalCounts = new HashMap<>();
        Map<String, Long> errorCounts = new HashMap<>();

        for (Map.Entry<String, AtomicLong> entry : hostAccumulator.entrySet()) {
            String key = entry.getKey();
            int separator = key.lastIndexOf(KEY_SEPARATOR);
            if (separator < 0) {
                continue;
            }
            String url = key.substring(0, separator);
            long count = entry.getValue().get();

            totalCounts.merge(url, count, Long::sum);
            if (isDownStatus(key.substring(separator + 1))) {
                errorCounts.merge(url, count, Long::sum);
            }
        }

        List<String> targetUrl = new ArrayList<>();
        for (Map.Entry<String, Long> entry : totalCounts.entrySet()) {
            String url = entry.getKey();
            long totalCount = entry.getValue();
            long errorCount = errorCounts.getOrDefault(url, 0L);

            // Every stored counter starts at 1, so totalCount is never zero here.
            double errorRate = (double) errorCount / totalCount;

            if (errorCount > threshold && errorRate > rate) {
                targetUrl.add(url);
            }
        }

        return targetUrl;
    }

    /**
     * Tells whether a status code means the upstream host is down.
     *
     * @param status HTTP status code as a decimal string
     * @return {@code true} for 502, 503 and 504
     * @throws NumberFormatException if {@code status} is not a valid integer
     */
    public Boolean isDownStatus(String status) {
        int code = Integer.parseInt(status);
        return code == HttpStatus.BAD_GATEWAY.value()
                || code == HttpStatus.SERVICE_UNAVAILABLE.value()
                || code == HttpStatus.GATEWAY_TIMEOUT.value();
    }

    /** Drops all collected response counters. */
    public void clearHostAccumulator() {
        hostAccumulator.clear();
    }

    /**
     * @param url host key
     * @return {@code true} if the circuit for {@code url} is currently open
     */
    public boolean isOpendHost(String url) {
        return openedHost.contains(url);
    }

    /**
     * Opens the circuit for the given hosts and discards their counters.
     *
     * @param urls host keys to open
     */
    public void addOpendHost(List<String> urls) {
        openedHost.addAll(urls);

        // Match on the exact host part of the key; a substring match would also
        // wipe the counters of hosts that merely share a prefix with an opened host.
        Set<String> opened = new HashSet<>(urls);
        hostAccumulator.keySet().removeIf(key -> {
            int separator = key.lastIndexOf(KEY_SEPARATOR);
            return separator >= 0 && opened.contains(key.substring(0, separator));
        });
    }

    /**
     * Closes the circuit for the given host.
     *
     * @param url host key
     */
    public void removeOpenedHost(String url) {
        openedHost.remove(url);
    }

    /**
     * @return the live set of opened hosts (not a copy)
     */
    public ConcurrentHashMap.KeySetView<String, Boolean> getOpenedHost() {
        return openedHost;
    }

}
그 다음, 아래와 같이 WebClient filter와 JettyClientHttpConnector에서 CircuitBreakerUtils를 호출하여 통계를 만들 수 있게 해줍니다.
package org.shashaka.io;

import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.reactive.ClientHttpConnector;
import org.springframework.http.client.reactive.JettyClientHttpConnector;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

import java.net.URI;

@Configuration
public class WebClientConfig {

    @Autowired
    private CircuitBreakerUtils circuitBreakerUtils;

    private ClientHttpConnector clientHttpConnector() {
        return new JettyClientHttpConnector(new HttpClient(new SslContextFactory.Client()) {
            @Override
            public Request newRequest(URI uri) {
                Request request = super.newRequest(uri);
                return circuiltApplied(request);
            }
        });
    }

    private Request circuiltApplied(Request request) {
        request.onResponseSuccess(response -> {
            URI uri = request.getURI();
            String url = uri.getScheme() + "://" + uri.getRawAuthority();
            circuitBreakerUtils.addHostStatus(url, response.getStatus());
        });
        return request;
    }

    @Bean
    public WebClient webClient() {
        return WebClient.builder()
                .clientConnector(clientHttpConnector())
                .filter(onRequest())
                .build();
    }

    private ExchangeFilterFunction onRequest() {
        return ExchangeFilterFunction.ofRequestProcessor(clientRequest -> {
            String url = clientRequest.url().getScheme() + "://" + clientRequest.url().getAuthority();
            if (circuitBreakerUtils.isOpendHost(url)) {
                throw new RuntimeException("circuit open for " + url);
            }
            return Mono.just(clientRequest);
        });
    }

}
마지막으로 아래와 같이 scheduling을 통해 response status 통계를 주기적으로 체크하고, 해당 통계를 바탕으로 circuit breaker를 열고 닫도록 설정해줍니다.
package org.shashaka.io;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableScheduling;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.web.reactive.function.client.WebClient;

import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

@Configuration
@EnableScheduling
@Slf4j
public class CircuitBreakerConfig {

    public WebClient circuitBreakerWebClient() {
        return WebClient.create();
    }

    @Autowired
    private CircuitBreakerUtils circuitBreakerUtils;

    @Scheduled(fixedRateString = "${circuit.open.interval}")
    public void clearHostAccumulator() {
        circuitBreakerUtils.clearHostAccumulator();
    }

    @Scheduled(fixedRateString = "${circuit.open.sample_interval}")
    public void addOpenedHost() {

        ConcurrentHashMap.KeySetView opendHost = circuitBreakerUtils.getOpenedHost();

        for (Object url : opendHost) {

            circuitBreakerWebClient()
                    .get()
                    .uri(url.toString())
                    .exchange()
                    .subscribe(clientResponse -> {
                        if (!circuitBreakerUtils.isDownStatus(String.valueOf(clientResponse.rawStatusCode()))) {
                            circuitBreakerUtils.removeOpenedHost(url.toString());
                        }
                    });
        }

        List targetHost = circuitBreakerUtils.getTargetHostForOpen();

        if (!targetHost.isEmpty()) {
            circuitBreakerUtils.addOpendHost(targetHost);
        }
    }
}
mocky.io를 통해 테스트를 하면 아래와 같이 circuit이 open되고 close되는 것을 확인할 수 있습니다.
$ http :8080/test
HTTP/1.1 200
Connection: keep-alive
Content-Length: 80
Content-Type: text/plain;charset=UTF-8
Date: Sat, 14 Dec 2019 13:07:10 GMT
Keep-Alive: timeout=60

503 Service Unavailable from GET http://www.mocky.io/v2/5df4be9f3000004d00111bba

$ http :8080/test
HTTP/1.1 200
Connection: keep-alive
Content-Length: 36
Content-Type: text/plain;charset=UTF-8
Date: Sat, 14 Dec 2019 13:07:13 GMT
Keep-Alive: timeout=60

circuit open for http://www.mocky.io