diff --git a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java index af2bc6321cd..e32ddfee63b 100644 --- a/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java +++ b/lib/java/src/main/java/org/apache/thrift/server/AbstractNonblockingServer.java @@ -290,10 +290,12 @@ public FrameBuffer( selectThread_ = selectThread; buffer_ = ByteBuffer.allocate(4); - frameTrans_ = new TMemoryInputTransport(); + frameTrans_ = new TMemoryInputTransport(trans_.getConfiguration()); response_ = new TByteArrayOutputStream(); inTrans_ = inputTransportFactory_.getTransport(frameTrans_); - outTrans_ = outputTransportFactory_.getTransport(new TIOStreamTransport(response_)); + outTrans_ = + outputTransportFactory_.getTransport( + new TIOStreamTransport(trans_.getConfiguration(), response_)); inProt_ = inputProtocolFactory_.getProtocol(inTrans_); outProt_ = outputProtocolFactory_.getProtocol(outTrans_); diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java index 6026390413c..36d7dd84ff0 100644 --- a/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java +++ b/lib/java/src/main/java/org/apache/thrift/transport/TEndpointTransport.java @@ -35,6 +35,10 @@ public void setMaxFrameSize(int maxFrameSize) { getConfiguration().setMaxFrameSize(maxFrameSize); } + public void setMaxMessageSize(int maxMessageSize) { + getConfiguration().setMaxMessageSize(maxMessageSize); + } + protected long knownMessageSize; protected long remainingMessageSize; diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java index 998379746e3..0bcf601f81f 100644 --- a/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java +++ b/lib/java/src/main/java/org/apache/thrift/transport/TNonblockingServerSocket.java @@ -49,6 +49,9 @@ public class TNonblockingServerSocket extends TNonblockingServerTransport { /** Limit for client sockets request size */ private int maxFrameSize_ = 0; + /** Max message size */ + private int maxMessageSize_ = 0; + public static class NonblockingAbstractServerSocketArgs extends AbstractServerTransportArgs {} @@ -93,6 +96,7 @@ public TNonblockingServerSocket(NonblockingAbstractServerSocketArgs args) throws TTransportException { clientTimeout_ = args.clientTimeout; maxFrameSize_ = args.maxFrameSize; + maxMessageSize_ = args.maxMessageSize; try { serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel.configureBlocking(false); @@ -135,6 +139,7 @@ public TNonblockingSocket accept() throws TTransportException { TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel); tsocket.setTimeout(clientTimeout_); tsocket.setMaxFrameSize(maxFrameSize_); + tsocket.setMaxMessageSize(maxMessageSize_); return tsocket; } catch (IOException iox) { throw new TTransportException(iox); diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java index e1056623e57..59cef201ede 100644 --- a/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java +++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerSocket.java @@ -38,6 +38,9 @@ public class TServerSocket extends TServerTransport { /** Timeout for client sockets from accept */ private int clientTimeout_ = 0; + /** Max message size */ + private int maxMessageSize_ = 0; + public static class ServerSocketTransportArgs extends AbstractServerTransportArgs { ServerSocket serverSocket; @@ -78,6 +81,7 @@ public TServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTran public TServerSocket(ServerSocketTransportArgs args) throws TTransportException { clientTimeout_ = args.clientTimeout; + maxMessageSize_ = args.maxMessageSize; if (args.serverSocket != null) { this.serverSocket_ = args.serverSocket; return; @@ -123,6 +127,7 @@ public TSocket accept() throws TTransportException { } TSocket socket = new TSocket(result); socket.setTimeout(clientTimeout_); + socket.setMaxMessageSize(maxMessageSize_); return socket; } diff --git a/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java index 47fa2513ca1..05a3f09b89f 100644 --- a/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java +++ b/lib/java/src/main/java/org/apache/thrift/transport/TServerTransport.java @@ -32,6 +32,7 @@ public abstract static class AbstractServerTransportArgs< int clientTimeout = 0; InetSocketAddress bindAddr; int maxFrameSize = TConfiguration.DEFAULT_MAX_FRAME_SIZE; + int maxMessageSize = TConfiguration.DEFAULT_MAX_MESSAGE_SIZE; public T backlog(int backlog) { this.backlog = backlog; @@ -57,6 +58,11 @@ public T maxFrameSize(int maxFrameSize) { this.maxFrameSize = maxFrameSize; return (T) this; } + + public T maxMessageSize(int maxMessageSize) { + this.maxMessageSize = maxMessageSize; + return (T) this; + } } public abstract void listen() throws TTransportException; diff --git a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java index 66a1e5f3b86..491f01b1d3a 100644 --- a/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java +++ b/lib/java/src/main/java/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java @@ -320,7 +320,8 @@ private void executeProcessing() { byte[] inputPayload = requestReader.getPayload(); requestReader.clear(); byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload; - TMemoryTransport memoryTransport = new TMemoryTransport(rawInput); + TMemoryTransport memoryTransport = + new TMemoryTransport(underlyingTransport.getConfiguration(), rawInput); TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport); TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport); diff --git a/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java b/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java index 74205c73592..c16f59cef75 100644 --- a/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java +++ b/lib/java/src/test/java/org/apache/thrift/server/TestThreadPoolServer.java @@ -23,10 +23,12 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.atomic.AtomicReference; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.transport.TServerSocket; import org.apache.thrift.transport.TServerTransport; import org.apache.thrift.transport.TSocket; +import org.apache.thrift.transport.TTransportException; import org.junit.jupiter.api.Test; import thrift.test.ThriftTest; @@ -35,7 +37,20 @@ public class TestThreadPoolServer { /** Test server is shut down properly even with some open clients. */ @Test public void testStopServerWithOpenClient() throws Exception { - TServerSocket serverSocket = new TServerSocket(0, 3000); + AtomicReference ref = new AtomicReference<>(); + TServerSocket serverSocket = + new TServerSocket( + new TServerSocket.ServerSocketTransportArgs() + .port(0) + .clientTimeout(3000) + .maxMessageSize(51200)) { + @Override + public TSocket accept() throws TTransportException { + TSocket socket = super.accept(); + ref.set(socket); + return socket; + } + }; TThreadPoolServer server = buildServer(serverSocket); Thread serverThread = new Thread(server::serve); serverThread.start(); @@ -44,6 +59,7 @@ public void testStopServerWithOpenClient() throws Exception { Thread.sleep(1000); // There is a thread listening to the client assertEquals(1, ((ThreadPoolExecutor) server.getExecutorService()).getActiveCount()); + assertEquals(51200, ref.get().getConfiguration().getMaxMessageSize()); // Trigger the server to stop, but it does not wait server.stop();