`
wuhuajun
  • 浏览: 93838 次
  • 性别: Icon_minigender_1
  • 来自: 上海
社区版块
存档分类
最新评论

thrift-transport层

    博客分类:
  • java
 
阅读更多

TTransport=>TIOStreamTransport=>TSocket

重要参数设置:

 

socket_.setSoLinger(false, 0); // SO_LINGER disabled: close() uses the OS default (no lingering)

socket_.setTcpNoDelay(true); // TCP_NODELAY on: disables Nagle's algorithm so small writes go out immediately

socket_.setSoTimeout(timeout_); // client-side read timeout

socket_.connect(new InetSocketAddress(host_, port_), timeout_); // client-side connect timeout

 

 

 

/**
 * Abstract base of every Thrift transport: a byte stream that can be opened, read, written,
 * flushed and closed.
 *
 * <p>Concrete transports supply the primitive operations. This class adds {@link #readAll},
 * the whole-array {@link #write(byte[])} overload, a no-op {@link #flush()}, and
 * "no buffer available" defaults for the optional direct read-buffer hooks.
 */
public abstract class TTransport {

  /**
   * Tells whether the transport is open.
   *
   * @return {@code true} if the transport is open
   */
  public abstract boolean isOpen();

  /**
   * Tells whether a read may succeed. The base implementation simply reports {@link #isOpen()}.
   *
   * @return {@code true} if there may be data to read
   */
  public boolean peek() {
    return isOpen();
  }

  /**
   * Opens the transport (for a socket transport, this is where the connection is made).
   *
   * @throws TTransportException if the transport cannot be opened
   */
  public abstract void open() throws TTransportException;

  /** Closes the transport. */
  public abstract void close();

  /**
   * Reads up to {@code len} bytes into {@code buf}, starting at offset {@code off}.
   *
   * @return the number of bytes actually read, which may be less than {@code len}
   * @throws TTransportException on read failure
   */
  public abstract int read(byte[] buf, int off, int len) throws TTransportException;

  /**
   * Reads exactly {@code len} bytes into {@code buf} starting at {@code off}, calling
   * {@link #read} as many times as needed.
   *
   * @return the total number of bytes read ({@code len} for a well-behaved {@link #read},
   *     0 when {@code len <= 0})
   * @throws TTransportException if a {@link #read} call returns 0 or a negative count before
   *     {@code len} bytes have arrived, or if {@link #read} itself throws
   */
  public int readAll(byte[] buf, int off, int len) throws TTransportException {
    int got = 0;
    while (got < len) {
      final int chunk = read(buf, off + got, len - got);
      if (chunk <= 0) {
        // A non-positive count means no further progress is possible: treat it as remote close.
        throw new TTransportException(shortReadMessage(len, got));
      }
      got += chunk;
    }
    return got;
  }

  /** Builds the error message for a read that stopped before {@code wanted} bytes arrived. */
  private static String shortReadMessage(int wanted, int got) {
    return "Cannot read. Remote side has closed. Tried to read "
        + wanted
        + " bytes, but only got "
        + got
        + " bytes. (This is often indicative of an internal error on the server side. Please check your server logs.)";
  }

  /**
   * Writes the whole array; equivalent to {@code write(buf, 0, buf.length)}.
   *
   * @throws TTransportException on write failure
   */
  public void write(byte[] buf) throws TTransportException {
    write(buf, 0, buf.length);
  }

  /**
   * Writes {@code len} bytes of {@code buf}, starting at offset {@code off}.
   *
   * @throws TTransportException on write failure
   */
  public abstract void write(byte[] buf, int off, int len) throws TTransportException;

  /**
   * Pushes any buffered output to its destination. The base implementation does nothing.
   *
   * @throws TTransportException on flush failure (never thrown by the base implementation)
   */
  public void flush() throws TTransportException {
    // Nothing is buffered at this level.
  }

  // ---- Optional direct access to a transport's internal read buffer ----
  // Buffered transports (such as TFramedTransport) override these hooks; the defaults below
  // report that no buffer is available.

  /** @return the internal read buffer, or {@code null} when there is none (the default) */
  public byte[] getBuffer() {
    return null;
  }

  /** @return the current read position within {@link #getBuffer()}; 0 by default */
  public int getBufferPosition() {
    return 0;
  }

  /** @return the number of unread bytes left in {@link #getBuffer()}, or -1 by default (no buffer) */
  public int getBytesRemainingInBuffer() {
    return -1;
  }

  /** Marks {@code len} bytes of the internal read buffer as consumed; a no-op by default. */
  public void consumeBuffer(int len) {
  }
}

 TTransport=>TNonblockingTransport=>TNonblockingSocket 

 

重要参数:

Selector selector = SelectorProvider.provider().openSelector(); // create the selector

SocketChannel socketChannel = SocketChannel.open(); // open a SocketChannel; like a Socket, it wraps an underlying Socket

socketChannel.configureBlocking(false); // switch the channel to non-blocking mode

 

Socket socket = socketChannel.socket(); // get the wrapped Socket to tune its low-level options

socket.setSoLinger(false, 0); // SO_LINGER disabled (OS default close behaviour)

socket.setTcpNoDelay(true); // TCP_NODELAY on: disable Nagle's algorithm

setTimeout(timeout); // NOTE(review): presumably TNonblockingSocket.setTimeout(int) — check what it configures

socketChannel_.register(selector,SelectionKey.OP_CONNECT); // register interest in the connect event; NOTE(review): socketChannel_ is presumably the field holding the same channel as socketChannel above

 

 

public abstract class TNonblockingTransport extends TTransport {

 

  //底层实现socketChannel_.connect(socketAddress_)

  public abstract boolean startConnect() throws IOException;

 

  //底层实现socketChannel_.finishConnect()

  public abstract boolean finishConnect() throws IOException;//完成后可以注册读.写事件

  

  public abstract SelectionKey registerSelector(Selector selector, int interests) throws IOException;

 

  public abstract int read(ByteBuffer buffer) throws IOException;

 

  public abstract int write(ByteBuffer buffer) throws IOException;

}

 

NIO客户端Demo

package sampleNio; 

import java.io.IOException; 
import java.net.InetAddress; 
import java.net.InetSocketAddress; 
import java.nio.ByteBuffer; 
import java.nio.channels.SelectionKey; 
import java.nio.channels.Selector; 
import java.nio.channels.SocketChannel; 
import java.nio.channels.spi.SelectorProvider; 
import java.nio.charset.StandardCharsets;
import java.util.Iterator; 

/**
 * Minimal NIO (non-blocking) client demo.
 *
 * <p>Connects to {@code hostAddress:port} and writes the UTF-8 bytes of "nice to meet you". It
 * then prints the first chunk of the reply and closes the connection. A single
 * {@link Selector}, driven from {@link #run()}, sequences the steps: connect, then write, then
 * read.
 *
 * @author jason
 */
public class NioClient implements Runnable {
    private final InetAddress hostAddress;
    private final int port;
    private Selector selector;
    /** Receive buffer, cleared before every read. */
    private final ByteBuffer readBuffer = ByteBuffer.allocate(8192);
    /** Request payload; its position tracks how much has been sent across partial writes. */
    private final ByteBuffer outBuffer = ByteBuffer.wrap("nice to meet you"
            .getBytes(StandardCharsets.UTF_8));

    /**
     * Opens the selector and starts a non-blocking connect to the given address.
     *
     * @param hostAddress server address
     * @param port        server port
     * @throws IOException if the selector or channel cannot be opened, or the connect cannot be started
     */
    public NioClient(InetAddress hostAddress, int port) throws IOException {
        this.hostAddress = hostAddress;
        this.port = port;
        initSelector();
    }

    /** Demo entry point: talks to localhost:9090 on a background thread. */
    public static void main(String[] args) {
        try {
            NioClient client = new NioClient(
                    InetAddress.getByName("localhost"), 9090);
            new Thread(client).start();

        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Event loop: waits for readiness events and dispatches each selected key to the
     * connect/read/write handler. Errors are printed and the loop keeps going while the selector
     * is open.
     */
    @Override
    public void run() {
        while (selector.isOpen()) {
            try {
                selector.select();

                Iterator<SelectionKey> selectedKeys = selector.selectedKeys().iterator();
                while (selectedKeys.hasNext()) {
                    SelectionKey key = selectedKeys.next();
                    // The selector never removes keys from its selected-key set itself.
                    selectedKeys.remove();

                    if (!key.isValid()) {
                        continue;
                    }

                    if (key.isConnectable()) {
                        finishConnection(key);
                    } else if (key.isReadable()) {
                        read(key);
                    } else if (key.isWritable()) {
                        write(key);
                    }
                }

            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Opens the selector and a non-blocking channel, starts connecting, and registers the
     * channel for the first event it has to wait for.
     */
    private void initSelector() throws IOException {
        // Create the selector
        selector = SelectorProvider.provider().openSelector();
        // Open the SocketChannel and make it non-blocking
        SocketChannel socketChannel = SocketChannel.open();
        socketChannel.configureBlocking(false);
        // In non-blocking mode connect() normally just starts the handshake and returns false,
        // but it returns true when the connection is established immediately (this can happen
        // for local connections). A completed connect is not reported as OP_CONNECT, so in that
        // case wait for OP_WRITE directly instead of waiting forever.
        boolean connected = socketChannel
                .connect(new InetSocketAddress(this.hostAddress, this.port));
        socketChannel.register(selector,
                connected ? SelectionKey.OP_WRITE : SelectionKey.OP_CONNECT);
    }

    /**
     * Completes the pending connect. On success the key switches to OP_WRITE; on failure the
     * key is cancelled and the channel closed so the socket does not leak.
     */
    private void finishConnection(SelectionKey key) throws IOException {
        SocketChannel socketChannel = (SocketChannel) key.channel();
        boolean connected;
        try {
            // Throws if the connection attempt failed
            connected = socketChannel.finishConnect();
        } catch (IOException e) {
            e.printStackTrace();
            key.cancel();
            socketChannel.close();
            return;
        }
        if (!connected) {
            // Still pending: keep waiting for OP_CONNECT.
            return;
        }
        // Connected: now wait until the request can be written
        key.interestOps(SelectionKey.OP_WRITE);
    }

    /**
     * Handles read readiness: reads one chunk and passes it on. End-of-stream or a read error
     * closes the connection.
     *
     * @param key the readable key
     * @throws IOException if closing the channel fails
     */
    private void read(SelectionKey key) throws IOException {
        SocketChannel socketChannel = (SocketChannel) key.channel();
        readBuffer.clear();
        int numRead;
        try {
            numRead = socketChannel.read(readBuffer);
        } catch (IOException e) {
            key.cancel();
            socketChannel.close();
            return;
        }
        if (numRead == -1) {
            // -1 signals end-of-stream: the server closed the connection.
            System.out.println("close connection");
            socketChannel.close();
            key.cancel();
            return;
        }
        if (numRead == 0) {
            // Nothing available yet; wait for the next read event.
            return;
        }
        handleResponse(socketChannel, readBuffer.array(), numRead);
    }

    /**
     * Prints the first {@code numRead} bytes of {@code data} as UTF-8 text, then ends the
     * exchange by cancelling the key and closing the channel.
     *
     * @param socketChannel the connection the data came from
     * @param data          receive buffer contents (valid from index 0)
     * @param numRead       number of valid bytes in {@code data}; must be positive
     * @throws IOException if closing the channel fails
     */
    private void handleResponse(SocketChannel socketChannel, byte[] data,
            int numRead) throws IOException {
        // Only this chunk is decoded; a multi-byte character split across reads would be garbled.
        System.out.println(new String(data, 0, numRead, StandardCharsets.UTF_8));
        SelectionKey key = socketChannel.keyFor(selector);
        if (key != null) {
            key.cancel();
        }
        socketChannel.close();
    }

    /**
     * Handles write readiness: sends as much of the request as the socket accepts and switches
     * to OP_READ once everything has been written.
     *
     * @param key the writable key
     * @throws IOException on write failure
     */
    private void write(SelectionKey key) throws IOException {
        SocketChannel socketChannel = (SocketChannel) key.channel();
        socketChannel.write(outBuffer);
        if (outBuffer.remaining() > 0) {
            // Partial write: wait for the next OP_WRITE to send the rest.
            return;
        }
        // Request fully sent: now wait for the response
        key.interestOps(SelectionKey.OP_READ);
    }

}

 TServerTransport=>TNonblockingServerTransport=>TNonblockingServerSocket

/**
 * Base class for server-side transports that listen for, and hand out, client connections.
 */
public abstract class TServerTransport {

  /**
   * Puts the transport into the listening state.
   *
   * @throws TTransportException if listening cannot be started
   */
  public abstract void listen() throws TTransportException;

  /**
   * Accepts one client connection.
   *
   * <p>Delegates to {@link #acceptImpl()} and refuses a {@code null} result, so callers never
   * receive {@code null}.
   *
   * @return the accepted client transport, never {@code null}
   * @throws TTransportException if {@code acceptImpl()} fails or returns {@code null}
   */
  public final TTransport accept() throws TTransportException {
    final TTransport client = acceptImpl();
    if (client != null) {
      return client;
    }
    throw new TTransportException("accept() may not return NULL");
  }

  /** Stops listening and releases the transport's resources. */
  public abstract void close();

  /**
   * Subclass hook that performs the actual accept. A {@code null} return is turned into an
   * exception by {@link #accept()}.
   *
   * @return the accepted client transport
   * @throws TTransportException if accepting fails
   */
  protected abstract TTransport acceptImpl() throws TTransportException;

  /**
   * Optional hook that asks the transport to break out of any {@code accept()} or
   * {@code listen()} it is currently blocked in. Implementations MUST be thread-safe, because
   * this may be invoked from a thread other than the one using the other methods. The default
   * implementation does nothing.
   */
  public void interrupt() {
    // Intentionally empty: interruption is optional.
  }
}

 

/**
 * Server transport whose listening channel can be multiplexed on a {@link Selector}.
 */
public abstract class TNonblockingServerTransport extends TServerTransport {

  /**
   * Registers this transport's listening channel with {@code selector}, so that pending
   * connections are reported by the selector.
   *
   * @param selector selector that will report incoming connections
   */
  public abstract void registerSelector(Selector selector);
}

 

public class TNonblockingServerSocket extends TNonblockingServerTransport {
  private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingServerTransport.class.getName());

  /**
   * This channel is where all the nonblocking magic happens.
   */
  private ServerSocketChannel serverSocketChannel = null;

  /**
   * Underlying ServerSocket object
   */
  private ServerSocket serverSocket_ = null;

  /**
   * Timeout for client sockets from accept
   */
  private int clientTimeout_ = 0;

  /**
   * Creates just a port listening server socket
   */
  public TNonblockingServerSocket(int port) throws TTransportException {
    this(port, 0);
  }

  /**
   * Creates just a port listening server socket
   */
  public TNonblockingServerSocket(int port, int clientTimeout) throws TTransportException {
    this(new InetSocketAddress(port), clientTimeout);
  }

  public TNonblockingServerSocket(InetSocketAddress bindAddr) throws TTransportException {
    this(bindAddr, 0);
  }

  public TNonblockingServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTransportException {
    clientTimeout_ = clientTimeout;
    try {
      serverSocketChannel = ServerSocketChannel.open();
      serverSocketChannel.configureBlocking(false);

      // Make server socket
      serverSocket_ = serverSocketChannel.socket();
      // Prevent 2MSL delay problem on server restarts
      serverSocket_.setReuseAddress(true);
      // Bind to listening port
      serverSocket_.bind(bindAddr);
    } catch (IOException ioe) {
      serverSocket_ = null;
      throw new TTransportException("Could not create ServerSocket on address " + bindAddr.toString() + ".");
    }
  }

  public void listen() throws TTransportException {
    // Make sure not to block on accept
    if (serverSocket_ != null) {
      try {
        serverSocket_.setSoTimeout(0);
      } catch (SocketException sx) {
        sx.printStackTrace();
      }
    }
  }

  protected TNonblockingSocket acceptImpl() throws TTransportException {
    if (serverSocket_ == null) {
      throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket.");
    }
    try {
      SocketChannel socketChannel = serverSocketChannel.accept();
      if (socketChannel == null) {
        return null;
      }

      TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel);
      tsocket.setTimeout(clientTimeout_);
      return tsocket;
    } catch (IOException iox) {
      throw new TTransportException(iox);
    }
  }
  
  //注册连接接收事件
  public void registerSelector(Selector selector) {
    try {
      // Register the server socket channel, indicating an interest in
      // accepting new connections
      serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT);
    } catch (ClosedChannelException e) {
      // this shouldn't happen, ideally...
      // TODO: decide what to do with this.
    }
  }

  public void close() {
    if (serverSocket_ != null) {
      try {
        serverSocket_.close();
      } catch (IOException iox) {
        LOGGER.warn("WARNING: Could not close server socket: " + iox.getMessage());
      }
      serverSocket_ = null;
    }
  }
  
   //可能存在线程安全问题 虽然java文档声称线程安全的
  public void interrupt() {
    // The thread-safeness of this is dubious, but Java documentation suggests
    // that it is safe to do this from a different thread context
    close();
  }

}

 

 TTransport=>TFramedTransport

//封装消息体之前带的帧大小4个字节。

public class TFramedTransport extends TTransport {

 

protected static final int DEFAULT_MAX_LENGTH = 16384000;

 

private int maxLength_;

 

/**

* Underlying transport

*/

private TTransport transport_ = null;

 

/**

* Buffer for output

*/

private final TByteArrayOutputStream writeBuffer_ = new TByteArrayOutputStream(

1024);

 

/**

* Buffer for input

*/

private TMemoryInputTransport readBuffer_ = new TMemoryInputTransport(

new byte[0]);

 

public static class Factory extends TTransportFactory {

private int maxLength_;

 

public Factory() {

maxLength_ = TFramedTransport.DEFAULT_MAX_LENGTH;

}

 

public Factory(int maxLength) {

maxLength_ = maxLength;

}

 

@Override

public TTransport getTransport(TTransport base) {

return new TFramedTransport(base, maxLength_);

}

}

 

/**

* Constructor wraps around another transport

*/

public TFramedTransport(TTransport transport, int maxLength) {

transport_ = transport;

maxLength_ = maxLength;

}

 

public TFramedTransport(TTransport transport) {

transport_ = transport;

maxLength_ = TFramedTransport.DEFAULT_MAX_LENGTH;

}

 

public void open() throws TTransportException {

transport_.open();

}

 

public boolean isOpen() {

return transport_.isOpen();

}

 

public void close() {

transport_.close();

}

 

public int read(byte[] buf, int off, int len) throws TTransportException {

if (readBuffer_ != null) {

int got = readBuffer_.read(buf, off, len);

if (got > 0) {

return got;

}

}

 

// Read another frame of data

readFrame();

 

return readBuffer_.read(buf, off, len);

}

 

@Override

public byte[] getBuffer() {

return readBuffer_.getBuffer();

}

 

@Override

public int getBufferPosition() {

return readBuffer_.getBufferPosition();

}

 

@Override

public int getBytesRemainingInBuffer() {

return readBuffer_.getBytesRemainingInBuffer();

}

 

@Override

public void consumeBuffer(int len) {

readBuffer_.consumeBuffer(len);

}

 

private final byte[] i32buf = new byte[4];

 

private void readFrame() throws TTransportException {

transport_.readAll(i32buf, 0, 4);

int size = decodeFrameSize(i32buf);

 

if (size < 0) {

throw new TTransportException("Read a negative frame size (" + size

+ ")!");

}

 

if (size > maxLength_) {

throw new TTransportException("Frame size (" + size

+ ") larger than max length (" + maxLength_ + ")!");

}

 

byte[] buff = new byte[size];

transport_.readAll(buff, 0, size);

readBuffer_.reset(buff);

}

 

public void write(byte[] buf, int off, int len) throws TTransportException {

writeBuffer_.write(buf, off, len);

}

 

@Override

public void flush() throws TTransportException {

byte[] buf = writeBuffer_.get();

int len = writeBuffer_.len();

writeBuffer_.reset();

 

encodeFrameSize(len, i32buf);

transport_.write(i32buf, 0, 4);

transport_.write(buf, 0, len);

transport_.flush();

}

 

public static final void encodeFrameSize(final int frameSize,

final byte[] buf) {

buf[0] = (byte) (0xff & (frameSize >> 24));

buf[1] = (byte) (0xff & (frameSize >> 16));

buf[2] = (byte) (0xff & (frameSize >> 8));

buf[3] = (byte) (0xff & (frameSize));

}

 

public static final int decodeFrameSize(final byte[] buf) {

return ((buf[0] & 0xff) << 24) | ((buf[1] & 0xff) << 16)

| ((buf[2] & 0xff) << 8) | ((buf[3] & 0xff));

}

 

public static void main(String[] args) {

int number = 99999;

byte[] buf = new byte[4];

encodeFrameSize(number,buf);

System.out.println(decodeFrameSize(buf));

}

}

 

 

 

 

 

 

 

 

 

 

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics