View Javadoc
1   /**
2    * Copyright By Grandsoft Company Limited.  
3    * 2013-9-25 上午10:10:38
4    */
5   package gboat2.base.plugin.servlet.filter;
6   
import java.io.IOException;
import java.io.StringWriter;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.opensymphony.xwork2.util.TextParseUtil;
29  
30  /**
31   * 防止 SQL 注入的监听器。<br>
32   * 配置示例:<pre></code>
33   * &lt;filter&gt;
34   *     &lt;!-- 过滤需要防止xss、sql攻击的请求 ,此配置统一解决不了的需求,需自己有针对性的过滤  --&gt;
35   *     &lt;filter-name&gt;antiXssSqlInjectFilter&lt;/filter-name&gt;
36   *     &lt;filter-class&gt;gboat2.base.plugin.servlet.filter.AntiXssSqlInjectFilter&lt;/filter-class&gt;
37   *     &lt;init-param&gt;
38   *         &lt;!-- 以逗号分隔的需要过滤的请求列表,支持通配符 --&gt;
39   *         &lt;param-name&gt;xssIncludes&lt;/param-name&gt;
40   *         &lt;param-value&gt;
41   *             /register!registerSave.do,
42   *             &#42;/login!login.do,
43   *             &#42;/login!rolesBeforeLogin.do
44   *         &lt;/param-value&gt;
45   *     &lt;/init-param&gt;
46   *     &lt;init-param&gt;
47   *         &lt;!-- 以逗号分隔的需要排除的请求列表,支持通配符,优先级高于xssIncludes  --&gt;
48   *         &lt;param-name&gt;xssExcludes&lt;/param-name&gt;
49   *         &lt;param-value&gt;&lt;/param-value&gt;
50   *     &lt;/init-param&gt;
51   *     &lt;init-param&gt;
52   *         &lt;param-name&gt;sqlIncludes&lt;/param-name&gt;
53   *         &lt;param-value&gt;
54   *             /register!registerSave.do,
55   *             &#42;/login!login.do
56   *         &lt;/param-value&gt;
57   *     &lt;/init-param&gt;
58   *     &lt;init-param&gt;
59   *         &lt;param-name&gt;sqlExcludes&lt;/param-name&gt;
60   *         &lt;param-value&gt;&lt;/param-value&gt;
61   *     &lt;/init-param&gt;
62   * &lt;/filter&gt;
63   * &lt;filter-mapping&gt;
64   *     &lt;filter-name&gt;antiXssSqlInjectFilter&lt;/filter-name&gt;
65   *     &lt;url-pattern&gt;*.do&lt;/url-pattern&gt;
66   * &lt;/filter-mapping&gt;</code></pre>
67   * @date 2013-9-25
68   * @author tanxw
69   * @since 1.0
70   */
71  public class AntiXssSqlInjectFilter implements Filter {
72  	
73  	private static Logger logger = LoggerFactory.getLogger(AntiXssSqlInjectFilter.class);
74  	
75  	protected Set<String> xssExcludes = Collections.emptySet();
76  	
77  	protected Set<String> xssIncludes = Collections.emptySet();
78  	
79  	protected Set<String> sqlExcludes = Collections.emptySet();
80  	
81  	protected Set<String> sqlIncludes = Collections.emptySet();
82  	
83  	/*
84  	 * {@inheritDoc}   
85  	 * @see javax.servlet.Filter#init(javax.servlet.FilterConfig) 
86  	 */
87  	@Override
88  	public void init(FilterConfig filterConfig) throws ServletException {
89  		String xssExcludes = filterConfig.getInitParameter("xssExcludes");
90  		if (StringUtils.isNotEmpty(xssExcludes)) {
91  			this.xssExcludes = TextParseUtil.commaDelimitedStringToSet(xssExcludes);
92  		}
93  		
94  		String xssIncludes = filterConfig.getInitParameter("xssIncludes");
95  		if (StringUtils.isNotEmpty(xssIncludes)) {
96  			this.xssIncludes = TextParseUtil.commaDelimitedStringToSet(xssIncludes);
97  		}
98  		
99  		String sqlExcludes = filterConfig.getInitParameter("sqlExcludes");
100 		if (StringUtils.isNotEmpty(sqlExcludes)) {
101 			this.sqlExcludes = TextParseUtil.commaDelimitedStringToSet(sqlExcludes);
102 		}
103 		
104 		String sqlIncludes = filterConfig.getInitParameter("sqlIncludes");
105 		if (StringUtils.isNotEmpty(sqlIncludes)) {
106 			this.sqlIncludes = TextParseUtil.commaDelimitedStringToSet(sqlIncludes);
107 		}
108 	}
109 	
110 	/*
111 	 * {@inheritDoc}   
112 	 * @see javax.servlet.Filter#destroy() 
113 	 */
114 	@Override
115 	public void destroy() {
116 	}
117 	
118 	/*
119 	 * {@inheritDoc}   
120 	 * @see javax.servlet.Filter#doFilter(javax.servlet.ServletRequest, javax.servlet.ServletResponse, javax.servlet.FilterChain) 
121 	 */
122 	@Override
123 	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
124 		HttpServletRequest req = (HttpServletRequest) request;
125 		String uri = req.getRequestURI().replaceFirst(req.getContextPath(), "");
126 		boolean applyXss = isMatched(xssExcludes, xssIncludes, uri);
127 		if (logger.isDebugEnabled()) {
128 			if (applyXss) {
129 				logger.debug("anti xss in [" + uri + "] 's parameters.");
130 			}
131 		}
132 		
133 		boolean applySql = isMatched(sqlExcludes, sqlIncludes, uri);
134 		if (logger.isDebugEnabled()) {
135 			if (applySql) {
136 				logger.debug("anti sql inject in [" + uri + "] 's parameters.");
137 			}
138 		}
139 		
140 		boolean needClean = applyXss || applySql;
141 		if (needClean) {
142 			chain.doFilter(new RequestWrapper(req, applyXss, applySql), response);
143 		} else {
144 			chain.doFilter(request, response);
145 		}
146 	}
147 	
148 	private boolean isMatched(Set<String> excludes, Set<String> includes, String uri) {
149 		boolean incMatched = false;
150 		for (String inc : includes) {
151 			if (FilenameUtils.wildcardMatch(uri, inc)) {
152 				incMatched = true;
153 				break;
154 			}
155 		}
156 		
157 		if (!incMatched) {//不在包含名单里
158 			return false;
159 		}
160 		
161 		boolean excMatched = false;
162 		for (String ex : excludes) {
163 			if (FilenameUtils.wildcardMatch(uri, ex)) {
164 				excMatched = true;
165 				break;
166 			}
167 		}
168 		
169 		if (excMatched) {//在排除名单里,排除优先级高
170 			return false;
171 		}
172 		
173 		return true;
174 	}
175 	
176 	public static class RequestWrapper extends HttpServletRequestWrapper {
177 		
178 		private boolean needAntiXss;
179 		
180 		private boolean needAntiSql;
181 		
182 		//private HttpServletRequest orgRequest = null;
183 		
184 		private boolean mapEscaped = false;
185 		
186 		/**
187 		* The constructor
188 		* @param req The request
189 		*/
190 		public RequestWrapper(HttpServletRequest req, boolean needAntiXss, boolean needAntiSql) {
191 			super(req);
192 			this.needAntiXss = needAntiXss;
193 			this.needAntiSql = needAntiSql;
194 			//this.orgRequest = req;
195 		}
196 		
197 		private String escape(String value) {
198 			String escapedValue = value;
199 			if (needAntiXss) {
200 				//escapedValue = StringEscapeUtils.escapeHtml(StringEscapeUtils.escapeJavaScript(value));
201 				escapedValue = escapeHtml(value);
202 			}
203 			
204 			if (needAntiSql) {
205 				escapedValue = escapeSql(escapedValue);
206 			}
207 			return escapedValue;
208 		}
209 		
210 		private String escapeHtml(String str) {
211 			if (str == null) {
212 				return null;
213 			}
214 			StringWriter writer = new StringWriter((int) (str.length() * 1.5));
215 			int len = str.length();
216 			for (int i = 0; i < len; i++) {
217 				char c = str.charAt(i);
218 				switch (c) {
219 					case '>':
220 						writer.write("&gt;");
221 						break;
222 					case '<':
223 						writer.write("&lt;");
224 						break;
225 					case '&':
226 						writer.write("&amp;");
227 						break;
228 					default:
229 						writer.write(c);
230 						break;
231 				}
232 			}
233 			return writer.toString();
234 		}
235 		
236 	    private static String escapeSql(String str) {
237 	        if (str == null) {
238 	            return null;
239 	        }
240 	        return StringUtils.replace(str, "'", "''");
241 	    }
242 	    
243 		@Override
244 		public String getHeader(String name) {
245 			String value = super.getHeader(name);
246 			if (value != null) {
247 				value = escape(value);
248 			}
249 			return value;
250 		}
251 		
252 		@Override
253 		public String getParameter(String name) {
254 			String value = super.getParameter(name);
255 			if (value != null) {
256 				value = escape(value);
257 			}
258 			return value;
259 		}
260 		
261 		@Override
262 		public String[] getParameterValues(String name) {
263 			String[] values = super.getParameterValues(name);
264 			if (values == null) {
265 				return null;
266 			}
267 			int len = values.length;
268 			String[] escapedValues = new String[len];
269 			for (int i = 0; i < len; i++) {
270 				escapedValues[i] = escape(values[i]);
271 			}
272 			return escapedValues;
273 		}
274 		
275 		@SuppressWarnings("rawtypes")
276         @Override
277 		public Map getParameterMap() {
278 			Map<?, ?> map = super.getParameterMap();
279 			if (mapEscaped) {//have already clean
280 				return map;
281 			}
282 			
283 			Iterator<?> it = (map.keySet() != null) ? map.keySet().iterator() : null;
284 			String key = null;
285 			String[] values = null;
286 			if (it != null) {
287 				while (it.hasNext()) {
288 					key = (String) it.next();
289 					if (key != null) {
290 						values = (String[]) map.get(key);
291 						for (int i = 0; i < values.length; i++) {
292 							values[i] = escape(values[i]);
293 						}
294 					}
295 				}
296 			}
297 			
298 			mapEscaped = true;
299 			return map;
300 		}
301 	}
302 }