WebMvcTest to Swagger Yaml (1)

19 July 2020

Spring Boot에서 @WebMvcTest를 실행하는 동안 Swagger YAML 파일을 자동으로 생성해 주는 예제를 공유하겠습니다.
실제 소스코드를 보면서 동작 방식을 설명하겠습니다. 먼저 아래와 같이 build.gradle에 dependency를 추가합니다. JUnit 테스트를 위한 라이브러리와 YAML 파일 생성을 위한 라이브러리(jackson-dataformat-yaml)로 구성되어 있습니다.

plugins {
    id 'org.springframework.boot' version '2.2.2.RELEASE'
    id 'io.spring.dependency-management' version '1.0.8.RELEASE'
    id 'java'
}

group = 'com.example'
version = '0.0.1-SNAPSHOT'
sourceCompatibility = '1.8'

repositories {
    mavenCentral()
}

configurations {
    // Everything declared as an annotation processor is also put on the compile-only classpath.
    compileOnly {
        extendsFrom annotationProcessor
    }
}

dependencies {
    // Lombok for the main source set.
    compileOnly 'org.projectlombok:lombok:1.18.12'
    annotationProcessor 'org.projectlombok:lombok:1.18.12'

    // Lombok for the test source set.
    testCompileOnly 'org.projectlombok:lombok:1.18.12'
    testAnnotationProcessor 'org.projectlombok:lombok:1.18.12'

    implementation 'org.springframework.boot:spring-boot-starter-web'
    testImplementation('org.springframework.boot:spring-boot-starter-test') {
        // JUnit 5 only: drop the JUnit 4 (vintage) engine.
        exclude group: 'org.junit.vintage', module: 'junit-vintage-engine'
    }

    // YAML support for Jackson, used to write swagger.yaml.
    // The deprecated 'compile' configuration (removed in Gradle 7) is replaced by 'implementation'.
    // NOTE(review): the Spring Boot plugin also manages Jackson versions; confirm the explicit
    // 2.11.1 is intended and compatible with the managed jackson-databind.
    implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-yaml', version: '2.11.1'
}

test {
    // Run the tests on the JUnit Platform (JUnit 5).
    useJUnitPlatform()
}
그 이후, 아래와 같이 간단한 RestController를 생성해줍니다. 간단히 path, query, body를 가진 POST API 입니다.
package org.shashaka.io.controller;

import org.shashaka.io.model.PostRequest;
import org.shashaka.io.model.PostResponse;
import org.springframework.web.bind.annotation.*;

/**
 * Sample REST controller with a single POST endpoint below {@code /swagger}.
 * The endpoint takes a path variable, two query parameters and a JSON body.
 */
@RestController
@RequestMapping("/swagger")
public class SwaggerController {

    /**
     * Handles {@code POST /swagger/post/{id}}.
     *
     * @param id      value of the {@code {id}} path segment
     * @param latest  optional query parameter; not used by the implementation
     * @param first   optional query parameter; not used by the implementation
     * @param request request body
     * @return a response built from the name carried by the request body
     */
    @PostMapping("/post/{id}")
    public PostResponse postMethod(
            @PathVariable(required = false) Integer id,
            @RequestParam(required = false) String latest,
            @RequestParam(required = false) String first,
            @RequestBody PostRequest request
    ) {
        return new PostResponse(request.getName());
    }
}
그리고 아래와 같이 MockMvc 사용시 중간에서 request와 response 값을 체크할 filter를 만들어줍니다.
package org.shashaka.io.swagger;

import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.shashaka.io.swagger.utils.SwaggerGenerator;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.TestComponent;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;

/**
 * Servlet filter meant to be registered on a {@code MockMvc} instance: it lets the request
 * pass through the filter chain and then hands the request to {@link SwaggerGenerator}
 * so that the Swagger YAML file is written.
 */
@Slf4j
@TestComponent
public class SwaggerDocFilter implements Filter {

    @Autowired
    private SwaggerGenerator swaggerGenerator;

    /**
     * Runs the rest of the chain first, then documents the request.
     *
     * <p>Checked exceptions from the chain or from the generator are rethrown as-is
     * via {@code @SneakyThrows}.
     *
     * @param request  the current request; only {@link MockHttpServletRequest} instances are documented
     * @param response the current response (passed on to the chain untouched)
     * @param chain    the remaining filter chain
     */
    @SneakyThrows
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) {

        chain.doFilter(request, response);

        // The generator needs the mock request type; skip anything else instead of
        // failing the request with a ClassCastException.
        if (!(request instanceof MockHttpServletRequest)) {
            log.warn("Skipping Swagger generation for unsupported request type: {}",
                    request == null ? null : request.getClass().getName());
            return;
        }

        swaggerGenerator.writeYaml((MockHttpServletRequest) request);
    }
}
그리고 위 필터를 이용하여 새로운 MockMvc Bean을 생성해줍니다. 사용자가 MockMvc Bean을 직접 등록하였으므로, @Autowired로 주입받을 때에는 자동 설정된 MockMvc 대신 사용자가 등록한 MockMvc가 주입됩니다.
package org.shashaka.io.swagger;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.DefaultMockMvcBuilder;

/**
 * Test configuration that publishes a {@link MockMvc} bean with the
 * {@link SwaggerDocFilter} registered as a servlet filter.
 */
@Slf4j
@TestConfiguration
@ComponentScan
public class SwaggerDocConfig {

    /** Filter that triggers the Swagger YAML generation after each request. */
    @Autowired
    private SwaggerDocFilter docFilter;

    /**
     * Builds the {@link MockMvc} instance from the supplied builder after adding the doc filter.
     *
     * @param builder the MockMvc builder bean available in the test context
     * @return a MockMvc whose requests pass through {@link SwaggerDocFilter}
     */
    @Bean
    public MockMvc mockMvc(DefaultMockMvcBuilder builder) {
        DefaultMockMvcBuilder withDocFilter = builder.addFilters(docFilter);
        return withDocFilter.build();
    }
}

아래와 같이 WebMvcTest를 작성해주도록 합니다. 여기서 주의할 점은 앞에서 만든 SwaggerDocConfig를 @Import로 등록해 주어야 한다는 점입니다.
package org.shashaka.io;

import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.shashaka.io.model.PostRequest;
import org.shashaka.io.swagger.SwaggerDocConfig;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest;
import org.springframework.context.annotation.Import;
import org.springframework.http.MediaType;
import org.springframework.test.web.servlet.MockMvc;

import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;


/**
 * Web MVC slice test for the {@code /swagger} endpoints.
 * {@link SwaggerDocConfig} is imported so that the injected {@link MockMvc} is the one
 * built with the Swagger documentation filter.
 */
@Slf4j
@Import(SwaggerDocConfig.class)
@WebMvcTest
public class SwaggerControllerTest {

    @Autowired
    private ObjectMapper objectMapper;

    @Autowired
    private MockMvc mockMvc;

    /**
     * Posts a JSON body to {@code /swagger/post/1} and expects HTTP 200 with the name echoed back.
     */
    @Test
    public void post_test() throws Exception {

        PostRequest request = new PostRequest();
        request.setName("name");
        String body = objectMapper.writeValueAsString(request);

        mockMvc.perform(
                post("/swagger/post/1")
                        .contentType(MediaType.APPLICATION_JSON)
                        // Pass the String itself: body.getBytes() would encode with the
                        // platform default charset and could corrupt non-ASCII JSON.
                        .content(body))
                .andDo(print())
                .andExpect(status().isOk())
                .andExpect(content().json("{\"name\":\"name\"}"));
    }
}
실제 Swagger Document를 생성할 SwaggerGenerator는 아래와 같습니다. 각 Test가 실행될 때마다 handlerMapping에서 해당 요청을 처리하는 handler(HandlerMethod)를 가져오고, API가 선언된 controller 메서드의 annotation과 parameter, return type 정보를 reflection을 이용해 읽어옵니다. 그렇게 읽어온 path, query, header, body, responseBody class 정보를 swagger.yaml 파일에 write해주는 역할을 합니다.
package org.shashaka.io.swagger.utils;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import lombok.extern.slf4j.Slf4j;
import org.shashaka.io.swagger.model.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.TestComponent;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.core.MethodParameter;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

import java.io.File;
import java.lang.reflect.Field;
import java.lang.reflect.Parameter;
import java.util.*;

/**
 * Builds a Swagger 2.0 style document from the requests seen during MockMvc tests and
 * writes it to {@code swagger.yaml}.
 *
 * <p>For every request the matching controller method is looked up through the
 * {@link RequestMappingHandlerMapping}; its path, path variables, query parameters,
 * request headers, request body type and return type are read via reflection and added
 * to the {@link SwaggerDocument} held by this instance.
 */
@TestConfiguration
@TestComponent
@Slf4j
public class SwaggerGenerator {

    private static final String SWAGGER_IN_PATH = "path";
    private static final String SWAGGER_IN_QUERY = "query";
    private static final String SWAGGER_IN_HEADER = "header";
    private static final String SWAGGER_IN_BODY = "body";

    private static final String DEFINITION_REF_PREFIX = "#/definitions/";
    private static final String OUTPUT_FILE_NAME = "swagger.yaml";

    /** Java types rendered as the Swagger type {@code integer}. */
    private static final Set<Class<?>> INTEGER_TYPES = new HashSet<>(Arrays.<Class<?>>asList(
            byte.class, short.class, int.class, long.class,
            Byte.class, Short.class, Integer.class, Long.class));

    /** Java types rendered as the Swagger type {@code number}. */
    private static final Set<Class<?>> NUMBER_TYPES = new HashSet<>(Arrays.<Class<?>>asList(
            float.class, double.class, Float.class, Double.class));

    /** Reused for every write instead of building a new mapper per request. */
    private static final ObjectMapper YAML_MAPPER = new ObjectMapper(new YAMLFactory());

    @Autowired
    private RequestMappingHandlerMapping handlerMapping;

    SwaggerDocument document = new SwaggerDocument();
    Map<String, SwaggerDefinition> definitions = new HashMap<>();

    /**
     * Adds the endpoint hit by the given request to the document and writes the whole
     * document to {@code swagger.yaml} (relative to the working directory).
     *
     * @param mockHttpServletRequest the request that was just executed
     * @throws Exception if the handler lookup or writing the file fails
     */
    public void writeYaml(MockHttpServletRequest mockHttpServletRequest) throws Exception {

        SwaggerInfo info = new SwaggerInfo();
        info.setVersion("1.0.0");
        info.setDescription("test API Doc");
        info.setTitle("Swagger API Doc from TC");
        document.setInfo(info);

        document.setDefinitions(definitions);
        addModel(mockHttpServletRequest);

        YAML_MAPPER.writeValue(new File(OUTPUT_FILE_NAME), document);
    }

    /**
     * Registers the path of the handler in the document (if not present yet).
     *
     * @return the path under which the handler is registered
     */
    private String addPath(HandlerMethod handler) {
        String path = getPath(handler);
        if (document.getPaths() == null) {
            document.setPaths(new HashMap<>());
        }
        document.getPaths().computeIfAbsent(path, k -> new HashMap<>());
        return path;
    }

    /**
     * Resolves the controller method for the request and records its request/response
     * models, path, HTTP method, parameters and responses in the document.
     * Requests that do not resolve to a controller method are skipped with a warning.
     *
     * @param mockHttpServletRequest the request to document
     * @throws Exception if the handler lookup fails
     */
    public void addModel(MockHttpServletRequest mockHttpServletRequest) throws Exception {

        // getHandler(...) returns null when nothing is mapped to the request (e.g. a 404 test).
        Object resolvedHandler = Optional.ofNullable(handlerMapping.getHandler(mockHttpServletRequest))
                .map(chain -> chain.getHandler())
                .orElse(null);
        if (!(resolvedHandler instanceof HandlerMethod)) {
            log.warn("No controller method mapped for {} {}; nothing added to the Swagger document",
                    mockHttpServletRequest.getMethod(), mockHttpServletRequest.getRequestURI());
            return;
        }
        HandlerMethod handler = (HandlerMethod) resolvedHandler;

        addRequestBodyModel(handler);
        addResponseBodyModel(handler);
        String path = addPath(handler);
        String httpMethod = addHttpMethod(path, mockHttpServletRequest.getMethod().toLowerCase(Locale.ROOT));

        SwaggerPath swaggerPath = document.getPaths().get(path).get(httpMethod);
        swaggerPath.setParameters(getRequestParameters(mockHttpServletRequest, handler));
        swaggerPath.setResponses(getResponses(handler));
    }

    /**
     * Builds the response map: a single {@code "200"} entry that references the
     * definition of the handler's return type (no schema for {@code void} methods).
     */
    private Map<String, SwaggerResponse> getResponses(HandlerMethod handler) {
        Class<?> responseBodyClass = getResponseBodyClass(handler);

        SwaggerResponse successResponse = new SwaggerResponse();
        successResponse.setDescription("success");
        if (!isVoid(responseBodyClass)) {
            SwaggerSchema successResponseSchema = new SwaggerSchema();
            successResponseSchema.setRef(DEFINITION_REF_PREFIX + responseBodyClass.getSimpleName());
            successResponse.setSchema(successResponseSchema);
        }

        Map<String, SwaggerResponse> swaggerResponseMap = new HashMap<>();
        swaggerResponseMap.put("200", successResponse);
        return swaggerResponseMap;
    }

    /**
     * Collects the Swagger parameters of one call: path variables and query parameters
     * declared on the handler method, the headers present on the request, and the
     * request body (if the method declares one).
     */
    private List<SwaggerPrameter> getRequestParameters(MockHttpServletRequest request, HandlerMethod handler) {
        List<SwaggerPrameter> swaggerPrameters = new ArrayList<>();

        for (Parameter pathVariable : getPathVariable(handler)) {
            SwaggerPrameter pathVariableParameter = new SwaggerPrameter();
            pathVariableParameter.setIn(SWAGGER_IN_PATH);
            pathVariableParameter.setName(pathVariable.getName());
            pathVariableParameter.setRequired(true);
            pathVariableParameter.setType(toSwaggerType(pathVariable.getType()));
            swaggerPrameters.add(pathVariableParameter);
        }

        for (Parameter requestQuery : getRequestQuery(handler)) {
            SwaggerPrameter requestQueryParameter = new SwaggerPrameter();
            requestQueryParameter.setIn(SWAGGER_IN_QUERY);
            requestQueryParameter.setName(requestQuery.getName());
            requestQueryParameter.setType(toSwaggerType(requestQuery.getType()));
            swaggerPrameters.add(requestQueryParameter);
        }

        for (String headerName : Collections.list(request.getHeaderNames())) {
            SwaggerPrameter requestHeaderParameter = new SwaggerPrameter();
            requestHeaderParameter.setIn(SWAGGER_IN_HEADER);
            requestHeaderParameter.setName(headerName);
            // getHeader(...) always yields a String, so the type is always "string".
            requestHeaderParameter.setType(toSwaggerType(String.class));
            swaggerPrameters.add(requestHeaderParameter);
        }

        Parameter requestBodyClass = getRequestBodyClass(handler);
        if (requestBodyClass != null) {
            SwaggerPrameter requestBodyParameter = new SwaggerPrameter();
            requestBodyParameter.setIn(SWAGGER_IN_BODY);
            requestBodyParameter.setName(SWAGGER_IN_BODY);
            requestBodyParameter.setRequired(true);
            SwaggerSchema requestBodySchema = new SwaggerSchema();
            requestBodySchema.setRef(DEFINITION_REF_PREFIX + requestBodyClass.getType().getSimpleName());
            requestBodyParameter.setSchema(requestBodySchema);
            swaggerPrameters.add(requestBodyParameter);
        }

        return swaggerPrameters;
    }

    /**
     * Registers the HTTP method below the given path (if not present yet).
     *
     * @return the method name that was passed in
     */
    private String addHttpMethod(String path, String method) {
        document.getPaths().get(path).computeIfAbsent(method, k -> new SwaggerPath());
        return method;
    }

    /** Adds a definition for the handler's return type; {@code void} methods add nothing. */
    private void addResponseBodyModel(HandlerMethod handler) {
        Class<?> responseBody = getResponseBodyClass(handler);
        if (isVoid(responseBody)) {
            return;
        }
        addDefinition(responseBody);
    }

    /** Adds a definition for the {@code @RequestBody} type; methods without a body add nothing. */
    private void addRequestBodyModel(HandlerMethod handler) {
        Parameter requestBody = getRequestBodyClass(handler);
        if (requestBody == null) {
            // e.g. GET endpoints: there is no body model to describe.
            return;
        }
        addDefinition(requestBody.getType());
    }

    /**
     * Stores a definition for the class under its simple name, with one property per
     * declared non-static, non-synthetic field.
     */
    private void addDefinition(Class<?> modelClass) {
        SwaggerDefinition swaggerDefinition = new SwaggerDefinition();
        Map<String, PropertyMap> properties = new HashMap<>();

        for (Field declaredField : modelClass.getDeclaredFields()) {
            if (java.lang.reflect.Modifier.isStatic(declaredField.getModifiers()) || declaredField.isSynthetic()) {
                continue;
            }
            PropertyMap propertyMap = new PropertyMap();
            propertyMap.setType(toSwaggerType(declaredField.getType()));
            properties.put(declaredField.getName(), propertyMap);
        }

        // Properties are only set when at least one field was found.
        if (!properties.isEmpty()) {
            swaggerDefinition.setProperties(properties);
        }
        definitions.put(modelClass.getSimpleName(), swaggerDefinition);
    }

    /**
     * Maps a Java type to a Swagger primitive type name. Types without a dedicated
     * mapping fall back to their lower-cased simple name.
     */
    private static String toSwaggerType(Class<?> type) {
        if (INTEGER_TYPES.contains(type)) {
            return "integer";
        }
        if (NUMBER_TYPES.contains(type)) {
            return "number";
        }
        if (type == char.class || type == Character.class) {
            return "string";
        }
        // Locale.ROOT keeps the result stable regardless of the JVM default locale.
        return type.getSimpleName().toLowerCase(Locale.ROOT);
    }

    private static boolean isVoid(Class<?> type) {
        return type == void.class || type == Void.class;
    }

    /** Concatenates the first class-level mapping path and the first method-level mapping path. */
    private String getPath(HandlerMethod handler) {
        String prefix = firstPathOf(handler.getBeanType().getAnnotation(RequestMapping.class));
        String path = firstPathOf(getRequestMappingAnnotation(handler));
        return prefix + path;
    }

    /**
     * Looks up the request mapping of the handler method itself. The lookup is done on the
     * method rather than through its parameters, so methods without parameters are covered too.
     */
    private RequestMapping getRequestMappingAnnotation(HandlerMethod handler) {
        // NOTE(review): relies on HandlerMethod#getMethodAnnotation resolving composed
        // annotations (@GetMapping, @PostMapping, ...) to @RequestMapping — see the Spring API docs.
        return handler.getMethodAnnotation(RequestMapping.class);
    }

    /** Returns the first declared {@code value}/{@code path} of the mapping, or an empty string. */
    private static String firstPathOf(RequestMapping mapping) {
        if (mapping == null) {
            return "";
        }
        if (mapping.value().length > 0) {
            return mapping.value()[0];
        }
        if (mapping.path().length > 0) {
            return mapping.path()[0];
        }
        return "";
    }

    /** Parameters annotated with {@code @PathVariable}, in declaration order. */
    private Set<Parameter> getPathVariable(HandlerMethod handler) {
        Set<Parameter> pathVariables = new LinkedHashSet<>();
        for (MethodParameter methodParameter : handler.getMethodParameters()) {
            if (methodParameter.getParameterAnnotation(PathVariable.class) != null) {
                pathVariables.add(methodParameter.getParameter());
            }
        }
        return pathVariables;
    }

    /** Parameters annotated with {@code @RequestParam}, in declaration order. */
    private Set<Parameter> getRequestQuery(HandlerMethod handler) {
        Set<Parameter> requestQuery = new LinkedHashSet<>();
        for (MethodParameter methodParameter : handler.getMethodParameters()) {
            if (methodParameter.getParameterAnnotation(RequestParam.class) != null) {
                requestQuery.add(methodParameter.getParameter());
            }
        }
        return requestQuery;
    }

    /** The first parameter annotated with {@code @RequestBody}, or {@code null} if there is none. */
    private Parameter getRequestBodyClass(HandlerMethod handler) {
        for (MethodParameter methodParameter : handler.getMethodParameters()) {
            if (methodParameter.getParameterAnnotation(RequestBody.class) != null) {
                return methodParameter.getParameter();
            }
        }
        return null;
    }

    /** The declared return type of the handler method. */
    private Class<?> getResponseBodyClass(HandlerMethod handler) {
        return handler.getMethod().getReturnType();
    }
}
Test를 실행하였을 때 생성된 swagger.yaml 파일입니다. Swagger Editor를 통해 해당 yaml 파일을 렌더링할 수 있습니다.
# Sample swagger.yaml from the test run (the comments are explanatory, not generated).
# It is a Swagger 2.0 document describing one POST operation and two model definitions.
---
swagger: "2.0"
info:
  description: "test API Doc"
  version: "1.0.0"
  title: "Swagger API Doc from TC"
paths:
  /swagger/post/{id}:
    post:
      parameters:
      # path variable
      - name: "id"
        in: "path"
        type: "integer"
        required: true
      # query parameters
      - name: "first"
        in: "query"
        type: "string"
      - name: "latest"
        in: "query"
        type: "string"
      # request headers
      - name: "Content-Type"
        in: "header"
        type: "string"
      - name: "Content-Length"
        in: "header"
        type: "string"
      # request body, referencing the PostRequest definition below
      - name: "body"
        in: "body"
        required: true
        schema:
          $ref: "#/definitions/PostRequest"
      responses:
        "200":
          description: "success"
          schema:
            $ref: "#/definitions/PostResponse"
# Model definitions referenced by the $ref entries above.
definitions:
  PostRequest:
    type: "object"
    properties:
      name:
        type: "string"
  PostResponse:
    type: "object"
    properties:
      name:
        type: "string"
전체 소스코드는 아래에서 확인할 수 있습니다.
https://gitlab.com/shashaka/tc2swagger