1
2
3
4
5 package gboat2.base.plugin.servlet.filter;
6
import java.io.IOException;
import java.io.StringWriter;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.opensymphony.xwork2.util.TextParseUtil;
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71 public class AntiXssSqlInjectFilter implements Filter {
72
73 private static Logger logger = LoggerFactory.getLogger(AntiXssSqlInjectFilter.class);
74
75 protected Set<String> xssExcludes = Collections.emptySet();
76
77 protected Set<String> xssIncludes = Collections.emptySet();
78
79 protected Set<String> sqlExcludes = Collections.emptySet();
80
81 protected Set<String> sqlIncludes = Collections.emptySet();
82
83
84
85
86
87 @Override
88 public void init(FilterConfig filterConfig) throws ServletException {
89 String xssExcludes = filterConfig.getInitParameter("xssExcludes");
90 if (StringUtils.isNotEmpty(xssExcludes)) {
91 this.xssExcludes = TextParseUtil.commaDelimitedStringToSet(xssExcludes);
92 }
93
94 String xssIncludes = filterConfig.getInitParameter("xssIncludes");
95 if (StringUtils.isNotEmpty(xssIncludes)) {
96 this.xssIncludes = TextParseUtil.commaDelimitedStringToSet(xssIncludes);
97 }
98
99 String sqlExcludes = filterConfig.getInitParameter("sqlExcludes");
100 if (StringUtils.isNotEmpty(sqlExcludes)) {
101 this.sqlExcludes = TextParseUtil.commaDelimitedStringToSet(sqlExcludes);
102 }
103
104 String sqlIncludes = filterConfig.getInitParameter("sqlIncludes");
105 if (StringUtils.isNotEmpty(sqlIncludes)) {
106 this.sqlIncludes = TextParseUtil.commaDelimitedStringToSet(sqlIncludes);
107 }
108 }
109
110
111
112
113
114 @Override
115 public void destroy() {
116 }
117
118
119
120
121
122 @Override
123 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
124 HttpServletRequest req = (HttpServletRequest) request;
125 String uri = req.getRequestURI().replaceFirst(req.getContextPath(), "");
126 boolean applyXss = isMatched(xssExcludes, xssIncludes, uri);
127 if (logger.isDebugEnabled()) {
128 if (applyXss) {
129 logger.debug("anti xss in [" + uri + "] 's parameters.");
130 }
131 }
132
133 boolean applySql = isMatched(sqlExcludes, sqlIncludes, uri);
134 if (logger.isDebugEnabled()) {
135 if (applySql) {
136 logger.debug("anti sql inject in [" + uri + "] 's parameters.");
137 }
138 }
139
140 boolean needClean = applyXss || applySql;
141 if (needClean) {
142 chain.doFilter(new RequestWrapper(req, applyXss, applySql), response);
143 } else {
144 chain.doFilter(request, response);
145 }
146 }
147
148 private boolean isMatched(Set<String> excludes, Set<String> includes, String uri) {
149 boolean incMatched = false;
150 for (String inc : includes) {
151 if (FilenameUtils.wildcardMatch(uri, inc)) {
152 incMatched = true;
153 break;
154 }
155 }
156
157 if (!incMatched) {
158 return false;
159 }
160
161 boolean excMatched = false;
162 for (String ex : excludes) {
163 if (FilenameUtils.wildcardMatch(uri, ex)) {
164 excMatched = true;
165 break;
166 }
167 }
168
169 if (excMatched) {
170 return false;
171 }
172
173 return true;
174 }
175
176 public static class RequestWrapper extends HttpServletRequestWrapper {
177
178 private boolean needAntiXss;
179
180 private boolean needAntiSql;
181
182
183
184 private boolean mapEscaped = false;
185
186
187
188
189
190 public RequestWrapper(HttpServletRequest req, boolean needAntiXss, boolean needAntiSql) {
191 super(req);
192 this.needAntiXss = needAntiXss;
193 this.needAntiSql = needAntiSql;
194
195 }
196
197 private String escape(String value) {
198 String escapedValue = value;
199 if (needAntiXss) {
200
201 escapedValue = escapeHtml(value);
202 }
203
204 if (needAntiSql) {
205 escapedValue = escapeSql(escapedValue);
206 }
207 return escapedValue;
208 }
209
210 private String escapeHtml(String str) {
211 if (str == null) {
212 return null;
213 }
214 StringWriter writer = new StringWriter((int) (str.length() * 1.5));
215 int len = str.length();
216 for (int i = 0; i < len; i++) {
217 char c = str.charAt(i);
218 switch (c) {
219 case '>':
220 writer.write(">");
221 break;
222 case '<':
223 writer.write("<");
224 break;
225 case '&':
226 writer.write("&");
227 break;
228 default:
229 writer.write(c);
230 break;
231 }
232 }
233 return writer.toString();
234 }
235
236 private static String escapeSql(String str) {
237 if (str == null) {
238 return null;
239 }
240 return StringUtils.replace(str, "'", "''");
241 }
242
243 @Override
244 public String getHeader(String name) {
245 String value = super.getHeader(name);
246 if (value != null) {
247 value = escape(value);
248 }
249 return value;
250 }
251
252 @Override
253 public String getParameter(String name) {
254 String value = super.getParameter(name);
255 if (value != null) {
256 value = escape(value);
257 }
258 return value;
259 }
260
261 @Override
262 public String[] getParameterValues(String name) {
263 String[] values = super.getParameterValues(name);
264 if (values == null) {
265 return null;
266 }
267 int len = values.length;
268 String[] escapedValues = new String[len];
269 for (int i = 0; i < len; i++) {
270 escapedValues[i] = escape(values[i]);
271 }
272 return escapedValues;
273 }
274
275 @SuppressWarnings("rawtypes")
276 @Override
277 public Map getParameterMap() {
278 Map<?, ?> map = super.getParameterMap();
279 if (mapEscaped) {
280 return map;
281 }
282
283 Iterator<?> it = (map.keySet() != null) ? map.keySet().iterator() : null;
284 String key = null;
285 String[] values = null;
286 if (it != null) {
287 while (it.hasNext()) {
288 key = (String) it.next();
289 if (key != null) {
290 values = (String[]) map.get(key);
291 for (int i = 0; i < values.length; i++) {
292 values[i] = escape(values[i]);
293 }
294 }
295 }
296 }
297
298 mapEscaped = true;
299 return map;
300 }
301 }
302 }