web/Spring

Spring Interceptor와 Filter에서 POST 방식으로 전달된 JSON 데이터 처리하기

반응형

이번 회사 프로젝트에서 진행할 때 parameter값이 아니라 Josn 데이터가 필요할 때가 있었다.

이를 위해서는 HttpServletRequest에서 InputStream으로 데이터를 추출해야한다.
하지만 HttpServletRequest에서 InputStream을 한번 추출하게되면, Controller에서 parameter를 매핑하려고 데이터를 바인딩할 때 다음과 같은 오류가 발생한다.

이는 톰캣에서 막아놓았기 때문이다.


[에러내용]


1
2
java.lang.IllegalStateException: getReader() has already been called for this request
org.springframework.http.converter.HttpMessageNotReadableException: Could not read JSON: Stream closed; nested exception is java.io.IOException: Stream closed
cs




이를 방지하고 JSON 데이터를 추출해서 사용하기 위해서는 HttpServletRequest를 wrapping 해서 사용해야 하는데 이를 HttpServletRequestWrapper을 확장하여 정의하면 된다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public class RereadableRequestWrapper extends HttpServletRequestWrapper {
 
    private final Charset encoding;
    private byte[] rawData;
 
    public RereadableRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
 
        String characterEncoding = request.getCharacterEncoding();
        if (StringUtils.isBlank(characterEncoding)) {
            characterEncoding = StandardCharsets.UTF_8.name();
        }
        this.encoding = Charset.forName(characterEncoding);
 
        // Convert InputStream data to byte array and store it to this wrapper instance.
        try {
            InputStream inputStream = request.getInputStream();
            this.rawData = IOUtils.toByteArray(inputStream);
        } catch (IOException e) {
            throw e;
        }
    }
 
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(this.rawData);
        ServletInputStream servletInputStream = new ServletInputStream() {
            public int read() throws IOException {
                return byteArrayInputStream.read();
            }
        };
        return servletInputStream;
    }
 
    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new InputStreamReader(this.getInputStream(), this.encoding));
    }
 
    @Override
    public ServletRequest getRequest() {
        return super.getRequest();
    }
}
cs



재정의 해서 만든 HttpServletRequestWrapper를 Filter로 통해서 사용하도록 지정해야한다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
// filter 클래스 정의
  @Override
  public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
     RereadableRequestWrapper rereadableRequestWrapper = new RereadableRequestWrapper((HttpServletRequest)request);
      ...
      chain.doFilter(rereadableRequestWrapper , response);
      ...
 
// web.xml에 정의
filter>
        <filter-name>requestFilter</filter-name>
        <filter-class>com.wedul.wedulpos.RequestFilter</filter-class>
    </filter>
    <filter-mapping>
        <filter-name>requestFilter</filter-name>
        <url-pattern>/*</url-pattern>
    </filter-mapping>
cs

 



하지만 이렇게 정의할 경우 

ResponseBody만 정의 하였기 때문에, RequestParam에 대한 처리를 해주지 않아서 문제가 발생한다. 


그렇기 때문에 RequestParam("application/x-www-form-urlencoded")에 대한 처리를 다음과 같이 해주어야 한다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
public class RereadableRequestWrapper extends HttpServletRequestWrapper {
 
     private boolean parametersParsed = false;
 
    private final Charset encoding;
    private final byte[] rawData;
    private final Map<String, ArrayList<String>> parameters = new LinkedHashMap<String, ArrayList<String>>();
    ByteChunk tmpName = new ByteChunk();
    ByteChunk tmpValue = new ByteChunk();
 
    private class ByteChunk {
 
        private byte[] buff;
        private int start = 0;
        private int end;
 
        public void setByteChunk(byte[] b, int off, int len) {
            buff = b;
            start = off;
            end = start + len;
        }
 
        public byte[] getBytes() {
            return buff;
        }
 
        public int getStart() {
            return start;
        }
 
        public int getEnd() {
            return end;
        }
 
        public void recycle() {
            buff = null;
            start = 0;
            end = 0;
        }
    }
 
    ...
 
    @Override
    public String getParameter(String name) {
        if (!parametersParsed) {
            parseParameters();
        }
        ArrayList<String> values = this.parameters.get(name);
        if (values == null || values.size() == 0)
            return null;
        return values.get(0);
    }
 
    public HashMap<StringString[]> getParameters() {
        if (!parametersParsed) {
            parseParameters();
        }
        HashMap<StringString[]> map = new HashMap<StringString[]>(this.parameters.size() * 2);
        for (String name : this.parameters.keySet()) {
            ArrayList<String> values = this.parameters.get(name);
            map.put(name, values.toArray(new String[values.size()]));
        }
        return map;
    }
 
    @SuppressWarnings("rawtypes")
    @Override
    public Map getParameterMap() {
        return getParameters();
    }
 
    @SuppressWarnings("rawtypes")
    @Override
    public Enumeration getParameterNames() {
        return new Enumeration<String>() {
            @SuppressWarnings("unchecked")
            private String[] arr = (String[])(getParameterMap().keySet().toArray(new String[0]));
            private int index = 0;
 
            @Override
            public boolean hasMoreElements() {
                return index < arr.length;
            }
 
            @Override
            public String nextElement() {
                return arr[index++];
            }
        };
    }
 
    @Override
    public String[] getParameterValues(String name) {
        if (!parametersParsed) {
            parseParameters();
        }
        ArrayList<String> values = this.parameters.get(name);
        String[] arr = values.toArray(new String[values.size()]);
        if (arr == null) {
            return null;
        }
        return arr;
    }
 
    private void parseParameters() {
        parametersParsed = true;
 
        if (!("application/x-www-form-urlencoded".equalsIgnoreCase(super.getContentType()))) {
            return;
        }
 
        int pos = 0;
        int end = this.rawData.length;
 
        while (pos < end) {
            int nameStart = pos;
            int nameEnd = -1;
            int valueStart = -1;
            int valueEnd = -1;
 
            boolean parsingName = true;
            boolean decodeName = false;
            boolean decodeValue = false;
            boolean parameterComplete = false;
 
            do {
                switch (this.rawData[pos]) {
                    case '=':
                        if (parsingName) {
                            // Name finished. Value starts from next character
                            nameEnd = pos;
                            parsingName = false;
                            valueStart = ++pos;
                        } else {
                            // Equals character in value
                            pos++;
                        }
                        break;
                    case '&':
                        if (parsingName) {
                            // Name finished. No value.
                            nameEnd = pos;
                        } else {
                            // Value finished
                            valueEnd = pos;
                        }
                        parameterComplete = true;
                        pos++;
                        break;
             }
 
                if (StringUtils.isNotBlank(name)) {
                    ArrayList<String> values = this.parameters.get(name);
                    if (values == null) {
                        values = new ArrayList<String>(1);
                        this.parameters.put(name, values);
                    }
                    if (StringUtils.isNotBlank(value)) {
                        values.add(value);
                    }
                }
            } catch (DecoderException e) {
                // ignore invalid chunk
            }
 
            tmpName.recycle();
            tmpValue.recycle();
        }
    }
}
cs



이 포스트를 보면서 회사에서 CSRF 공격 방어에 잘 사용하였다.
아주 감사한 포스트였다.

출저 : http://meetup.toast.com/posts/44

반응형