@@ -140,8 +140,8 @@ struct Executor::Impl {
140140
141141 ExecutionContext setupExecutionContext (int rank, void * sendbuff, void * recvbuff, size_t inputMessageSize,
142142 size_t outputMessageSize, size_t constSrcOffset, size_t constDstOffset,
143- size_t sendBufferSize , size_t recvBufferSize , const ExecutionPlan& plan) {
144- ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize , plan.impl_ ->name };
143+ size_t sendMemRange , size_t recvMemRange , const ExecutionPlan& plan) {
144+ ExecutionContextKey key = {sendbuff, recvbuff, sendMemRange, recvMemRange , plan.impl_ ->name };
145145 DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset};
146146 if (this ->contexts .find (key) != this ->contexts .end ()) {
147147 auto & devicePlans = this ->contexts [key].deviceExecutionPlans ;
@@ -167,7 +167,9 @@ struct Executor::Impl {
167167 plan.impl_ ->loadExecutionPlan (inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
168168
169169 ExecutionContext context;
170- size_t scratchBufferSize = plan.impl_ ->getScratchBufferSize (rank, sendBufferSize, recvBufferSize);
170+ size_t maxScratchBufferSize = plan.impl_ ->getMaxScratchBufferSize (rank);
171+ size_t scratchBufferSize =
172+ std::min (plan.impl_ ->getScratchBufferSize (rank, sendMemRange, recvMemRange), maxScratchBufferSize);
171173 std::shared_ptr<char > scratchBuffer;
172174 if (isNvlsSupported ()) {
173175 scratchBuffer = allocSharedPhysicalCuda<char >(scratchBufferSize);
@@ -179,8 +181,8 @@ struct Executor::Impl {
179181 context.proxyService = std::make_shared<ProxyService>();
180182 context.nthreadsPerBlock = plan.impl_ ->getNThreadsPerBlock ();
181183 this ->setupConnections (context, rank, plan);
182- this ->setupRegisteredMemories (context, sendbuff, recvbuff, sendBufferSize, recvBufferSize , rank, plan);
183- this ->setupChannels (context, sendbuff, recvbuff, sendBufferSize, recvBufferSize , rank, plan);
184+ this ->setupRegisteredMemories (context, sendbuff, recvbuff, sendMemRange, recvMemRange , rank, plan);
185+ this ->setupChannels (context, sendbuff, recvbuff, sendMemRange, recvMemRange , rank, plan);
184186 this ->setupNvlsChannels (context, sendbuff, recvbuff, rank, plan);
185187 this ->setupDeviceExecutionPlan (context, devicePlanKey, rank, plan);
186188 context.deviceExecutionPlansBuffers [devicePlanKey] =
@@ -438,16 +440,16 @@ Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<
438440void Executor::execute (int rank, void * sendbuff, void * recvbuff, size_t sendBuffSize,
439441 [[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
440442 cudaStream_t stream, PacketType packetType) {
441- size_t sendBytes, recvBytes ;
443+ size_t sendMemRange, recvMemRange ;
442444 CUdeviceptr sendBasePtr, recvBasePtr;
443- MSCCLPP_CUTHROW (cuMemGetAddressRange (&sendBasePtr, &sendBytes , (CUdeviceptr)sendbuff));
444- MSCCLPP_CUTHROW (cuMemGetAddressRange (&recvBasePtr, &recvBytes , (CUdeviceptr)recvbuff));
445+ MSCCLPP_CUTHROW (cuMemGetAddressRange (&sendBasePtr, &sendMemRange , (CUdeviceptr)sendbuff));
446+ MSCCLPP_CUTHROW (cuMemGetAddressRange (&recvBasePtr, &recvMemRange , (CUdeviceptr)recvbuff));
445447 size_t offsetIn = (char *)sendbuff - (char *)sendBasePtr;
446448 size_t offsetOut = (char *)recvbuff - (char *)recvBasePtr;
447449
448450 ExecutionContext context =
449451 this ->impl_ ->setupExecutionContext (rank, (void *)sendBasePtr, (void *)recvBasePtr, sendBuffSize, recvBuffSize,
450- offsetIn, offsetOut, sendBytes, recvBytes , plan);
452+ offsetIn, offsetOut, sendMemRange, recvMemRange , plan);
451453 this ->impl_ ->launchKernel (context, rank, sendbuff, recvbuff, dataType, stream, packetType);
452454}
453455
0 commit comments