`

防止重复提交和 Response 重写

阅读更多

1. 统一防止重复提交的 Filter 实现

 

package filter;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;


/**
 * Servlet filter that rejects duplicate POST submissions using one-time tokens.
 *
 * <p>Non-POST requests: the response body is captured through {@link MyResponseWrapper}. If the
 * captured text contains {@code </form>} and no {@code name='submitToken'} field, a random UUID
 * is stored in the session attribute {@code submitToken} and a hidden {@code submitToken} input
 * is inserted before every {@code </form>}. The captured text and bytes are then written to the
 * real response.
 *
 * <p>POST requests: a token found in the session is removed from it and added to the set of
 * outstanding tokens for the request URI. The request's {@code submitToken} parameter must then
 * be consumed from that set. If it cannot be, the chain is skipped and a "do not resubmit"
 * message is written instead.
 */
public class PreSubmitFilter implements Filter {

    private static final Logger logger = LoggerFactory.getLogger(PreSubmitFilter.class);

    /** Closing tag before which the hidden token input is inserted. */
    private static final String FORM_TAG = "</form>";

    /**
     * Outstanding tokens keyed by request URI.
     *
     * <p>Must be a concurrent map: {@link #getTokens(String)} reads it without holding a lock,
     * which is unsafe on a plain HashMap that another thread may be resizing.
     */
    Map<String, Set<String>> currentTokens = new ConcurrentHashMap<String, Set<String>>();

    /**
     * Returns the outstanding-token set for {@code url}, creating an empty one on first use.
     * Lookups are lock-free; creation re-checks under a lock so only one set is installed per URL.
     */
    private Set<String> getTokens(String url) {
        Set<String> tokens = currentTokens.get(url);
        if (tokens == null) {
            synchronized (this) {
                tokens = currentTokens.get(url);
                if (tokens == null) {
                    tokens = new CopyOnWriteArraySet<String>();
                    currentTokens.put(url, tokens);
                }
            }
        }
        return tokens;
    }

    /** Records {@code token} as outstanding for {@code url}; a null token is ignored. */
    private void putToken(String url, String token) {
        if (token != null) {
            getTokens(url).add(token);
        }
    }

    /**
     * Atomically consumes {@code token} for {@code url}.
     *
     * <p>Check and removal are a single {@code Set.remove} call. This guarantees that when two
     * requests carry the same token concurrently, exactly one of them succeeds.
     *
     * @return true if the token was outstanding and has now been removed; false for a null token
     *     or one that is not (or no longer) outstanding
     */
    private boolean consumeToken(String url, String token) {
        if (token == null) {
            return false;
        }
        return getTokens(url).remove(token);
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        String url = request.getRequestURI();
        HttpSession session = request.getSession();
        logger.info("url = {},contentType:{}", url, response.getContentType());

        if ("POST".equals(request.getMethod())) {
            // Move the token issued with the last rendered form into the outstanding set.
            String tokenInSession = (String) session.getAttribute("submitToken");
            if (tokenInSession != null) {
                session.removeAttribute("submitToken");
                putToken(url, tokenInSession);
            }
            if (consumeToken(url, request.getParameter("submitToken"))) {
                filterChain.doFilter(request, response);
                logger.info("验证通过");
            } else {
                logger.info("验证不通过");
                response.setContentType("text/html; charset=UTF-8");
                response.getWriter().write("请不要重复提交页面。");
            }
            return;
        }

        MyResponseWrapper responseWrapper = new MyResponseWrapper(response);
        filterChain.doFilter(request, responseWrapper);

        String content = responseWrapper.getMyWriter().toString();
        if (content.length() > 0) {
            // contains() instead of indexOf(...) > 0, which missed a tag at offset 0.
            if (content.contains(FORM_TAG) && !content.contains("name='submitToken'")) {
                String uuid = UUID.randomUUID().toString();
                session.setAttribute("submitToken", uuid);
                String tokenInput = "<input type='hidden' name='submitToken' value='" + uuid + "'/>";
                // NOTE(review): if the servlet set a Content-Length header (delegated straight to
                // the real response), the injected input makes the body longer than declared — confirm.
                content = content.replace(FORM_TAG, tokenInput + FORM_TAG);
            }
            response.getWriter().write(content);
        }

        byte[] bytes = responseWrapper.getMyByte();
        if (bytes != null && bytes.length > 0) {
            response.getOutputStream().write(bytes);
        }
    }

    @Override
    public void destroy() {

    }
}

 

 

2. 重写 ServletOutputStream，用于在内存中缓存二进制内容

 

package filter;



import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

/**
 * {@link ServletOutputStream} that buffers everything written to it in memory instead of sending
 * it. The owner reads the bytes back with {@link #getByte()} and forwards them itself.
 *
 * <p>NOTE(review): Servlet 3.1+ declares {@code isReady()} and {@code setWriteListener(...)}
 * abstract on ServletOutputStream. This class does not implement them, so it only compiles
 * against an older servlet API — confirm the target version.
 */
public class MyOutputStream extends ServletOutputStream {
    /** The response being wrapped; not written to by this class (see {@link #close()}). */
    private HttpServletResponse outputStream;
    /** Accumulates every byte written. Package-private, so other classes in this package may access it. */
    ByteArrayOutputStream byteArrayOutputStream;

    /**
     * @param outputStream the response whose output is being captured
     */
    public MyOutputStream(HttpServletResponse outputStream) {
        this.outputStream = outputStream;
        byteArrayOutputStream = new ByteArrayOutputStream();
    }

    @Override
    public void write(int b) throws IOException {
        byteArrayOutputStream.write(b);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        byteArrayOutputStream.write(b, off, len);
    }

    @Override
    public void write(byte[] b) throws IOException {
        byteArrayOutputStream.write(b);
    }

    @Override
    public void flush() throws IOException {
        byteArrayOutputStream.flush();
    }

    /**
     * Ends writing without touching the wrapped response.
     *
     * <p>This previously copied the buffer to the real response and closed it. The same bytes are
     * also returned by {@link #getByte()} for the caller to forward (PreSubmitFilter does this).
     * A servlet that closed its stream therefore caused the body to be sent twice, or to be
     * written to an already-closed stream.
     */
    @Override
    public void close() throws IOException {
        // ByteArrayOutputStream.close() has no effect; called for symmetry only.
        byteArrayOutputStream.close();
    }

    /** Returns a copy of all bytes written so far. */
    public byte[] getByte() {
        return byteArrayOutputStream.toByteArray();
    }
}

3. 继承 HttpServletResponseWrapper，用于缓存响应内容

 

 

package filter;

import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.PrintWriter;

/**
 * Response wrapper that captures the body in memory instead of sending it, so a filter can rewrite it.
 *
 * <p>Text written through {@link #getWriter()} accumulates in a {@link CharArrayWriter} (read
 * with {@link #getMyWriter()}). Bytes written through {@link #getOutputStream()} accumulate in a
 * {@link MyOutputStream} (read with {@link #getMyByte()}). Everything else (headers, status,
 * flushBuffer, ...) is delegated to the wrapped response by {@link HttpServletResponseWrapper}.
 */
public class MyResponseWrapper extends HttpServletResponseWrapper {

    /** Lazily created byte buffer; null until {@link #getOutputStream()} is first called. */
    private MyOutputStream bout;

    private CharArrayWriter charArrayWriter = new CharArrayWriter();
    private HttpServletResponse httpServletResponse;

    public MyResponseWrapper(HttpServletResponse response) {
        super(response);
        this.httpServletResponse = response;
    }

    /** Returns a writer for character content that appends to the in-memory char buffer. */
    @Override
    public PrintWriter getWriter() throws IOException {
        return new PrintWriter(charArrayWriter);
    }

    /** Returns the in-memory stream for binary content, creating it on first use. */
    @Override
    public ServletOutputStream getOutputStream() throws IOException {
        if (bout == null) {
            bout = new MyOutputStream(httpServletResponse);
        }
        return bout;
    }

    /**
     * Resets the wrapped response's buffer and also discards locally captured body content.
     * Without the local reset, content written before the reset would still be emitted later.
     * If the wrapped response is already committed, super throws IllegalStateException and
     * nothing is cleared.
     */
    @Override
    public void resetBuffer() {
        super.resetBuffer();
        clearCapturedBody();
    }

    /** Resets the wrapped response (headers and status) and discards locally captured body content. */
    @Override
    public void reset() {
        super.reset();
        clearCapturedBody();
    }

    /** Empties both capture buffers in place, so writers/streams already handed out stay valid. */
    private void clearCapturedBody() {
        charArrayWriter.reset();
        if (bout != null) {
            bout.byteArrayOutputStream.reset();
        }
    }

    /** Returns the buffer holding all character content written so far. */
    public CharArrayWriter getMyWriter() {
        return charArrayWriter;
    }

    /** Returns a copy of the captured bytes, or null if {@link #getOutputStream()} was never called. */
    public byte[] getMyByte() {
        return bout != null ? bout.getByte() : null;
    }
}

 

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics