Skip to content

Commit ee75caf

Browse files
authored
Reduce memory usage for scratch buffer (#403)
In the executor, we allocate the scratch buffer based on `sendMemRange`. However, for certain execution plans this allocation may be unsuitable, because the plan does not support messages of that size. To avoid allocating too much memory and causing an OOM error, set the scratch buffer size to `min(scratchBufferSize(maxMessageSizeSupportedForPlan), scratchBufferSize(sendMemRange))`.
1 parent 01fd813 commit ee75caf

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

.azure-pipelines/nccl-api-test.yaml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,24 @@ jobs:
156156
mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/apps/nccl/libmscclpp_nccl.so -x NCCL_DEBUG=WARN -x MSCCLPP_EXECUTION_PLAN_DIR=/root/mscclpp/msccl-users/execution-files /root/nccl-tests/build/all_reduce_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20"'
157157
workingDirectory: '$(System.DefaultWorkingDirectory)'
158158

159+
- task: Bash@3
160+
name: RunNcclGatherTest
161+
displayName: Run NCCL Allgather Test
162+
inputs:
163+
targetType: 'inline'
164+
script: |
165+
set -e
166+
HOSTFILE=$(System.DefaultWorkingDirectory)/mscclpp/test/deploy/hostfile_ci
167+
ROOT_DIR=$(System.DefaultWorkingDirectory)/mscclpp
168+
SSH_OPTION="StrictHostKeyChecking=no"
169+
KeyFilePath=${SSHKEYFILE_SECUREFILEPATH}
170+
parallel-ssh -i -t 0 -h ${HOSTFILE} -x "-i ${KeyFilePath}" \
171+
-O $SSH_OPTION 'sudo docker exec -t mscclpp-test bash -c "\
172+
cd /root/mscclpp; \
173+
mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/apps/nccl/libmscclpp_nccl.so -x NCCL_DEBUG=WARN -x MSCCLPP_EXECUTION_PLAN_DIR=/root/mscclpp/msccl-users/execution-files /root/nccl-tests/build/all_gather_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20"'
174+
workingDirectory: '$(System.DefaultWorkingDirectory)'
175+
176+
159177
- task: AzureCLI@2
160178
name: StopVMSS
161179
displayName: Deallocate VMSS

src/executor/execution_plan.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,9 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
165165
}
166166
return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
167167
}
168+
168169
size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const {
169-
size_t sizePerRank;
170+
size_t sizePerRank = 0;
170171
if (this->inputChunks.at(rank) != 0)
171172
sizePerRank = inputSize / this->inputChunks.at(rank);
172173
else if (this->outputChunks.at(rank) != 0)
@@ -179,6 +180,23 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
179180
}
180181
return sizePerRank * this->scratchChunks.at(rank);
181182
}
183+
184+
// Returns an upper bound on the scratch buffer size for `rank`, derived from
// this plan's `maxMessageSize` rather than from the caller's buffer sizes.
//
// - If `maxMessageSize` equals the uint64_t maximum, the size_t maximum is
//   returned (presumably meaning the plan declares no message-size limit;
//   confirm against where `maxMessageSize` is populated).
// - Otherwise `maxMessageSize` is divided by the rank's input chunk count, or
//   by its output chunk count when the input chunk count is zero, and the
//   result is fed to getScratchBufferSize().
//
// Throws mscclpp::Error (ErrorCode::ExecutorError) when both the input and the
// output chunk counts for `rank` are zero.
size_t ExecutionPlan::Impl::getMaxScratchBufferSize(int rank) const {
185+
if (this->maxMessageSize == std::numeric_limits<uint64_t>::max()) {
186+
return std::numeric_limits<size_t>::max();
187+
}
188+
size_t sizePerChunk = 0;
189+
if (this->inputChunks.at(rank) != 0)
190+
sizePerChunk = maxMessageSize / this->inputChunks.at(rank);
191+
else if (this->outputChunks.at(rank) != 0)
192+
sizePerChunk = maxMessageSize / this->outputChunks.at(rank);
193+
else
194+
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);
195+
196+
// Rebuild input/output sizes as whole multiples of the per-chunk size (the
// integer division above rounds down) and delegate to getScratchBufferSize().
return this->getScratchBufferSize(rank, sizePerChunk * this->inputChunks.at(rank),
197+
sizePerChunk * this->outputChunks.at(rank));
198+
}
199+
182200
std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
183201
return this->operations.at(rank)[threadblock];
184202
}

src/executor/executor.cc

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ struct Executor::Impl {
140140

141141
ExecutionContext setupExecutionContext(int rank, void* sendbuff, void* recvbuff, size_t inputMessageSize,
142142
size_t outputMessageSize, size_t constSrcOffset, size_t constDstOffset,
143-
size_t sendBufferSize, size_t recvBufferSize, const ExecutionPlan& plan) {
144-
ExecutionContextKey key = {sendbuff, recvbuff, sendBufferSize, recvBufferSize, plan.impl_->name};
143+
size_t sendMemRange, size_t recvMemRange, const ExecutionPlan& plan) {
144+
ExecutionContextKey key = {sendbuff, recvbuff, sendMemRange, recvMemRange, plan.impl_->name};
145145
DeviceExecutionPlanKey devicePlanKey = {inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset};
146146
if (this->contexts.find(key) != this->contexts.end()) {
147147
auto& devicePlans = this->contexts[key].deviceExecutionPlans;
@@ -167,7 +167,9 @@ struct Executor::Impl {
167167
plan.impl_->loadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset);
168168

169169
ExecutionContext context;
170-
size_t scratchBufferSize = plan.impl_->getScratchBufferSize(rank, sendBufferSize, recvBufferSize);
170+
size_t maxScratchBufferSize = plan.impl_->getMaxScratchBufferSize(rank);
171+
size_t scratchBufferSize =
172+
std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange, recvMemRange), maxScratchBufferSize);
171173
std::shared_ptr<char> scratchBuffer;
172174
if (isNvlsSupported()) {
173175
scratchBuffer = allocSharedPhysicalCuda<char>(scratchBufferSize);
@@ -179,8 +181,8 @@ struct Executor::Impl {
179181
context.proxyService = std::make_shared<ProxyService>();
180182
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
181183
this->setupConnections(context, rank, plan);
182-
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
183-
this->setupChannels(context, sendbuff, recvbuff, sendBufferSize, recvBufferSize, rank, plan);
184+
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
185+
this->setupChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
184186
this->setupNvlsChannels(context, sendbuff, recvbuff, rank, plan);
185187
this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan);
186188
context.deviceExecutionPlansBuffers[devicePlanKey] =
@@ -438,16 +440,16 @@ Executor::Executor(std::shared_ptr<Communicator> comm) : impl_(std::make_unique<
438440
void Executor::execute(int rank, void* sendbuff, void* recvbuff, size_t sendBuffSize,
439441
[[maybe_unused]] size_t recvBuffSize, DataType dataType, const ExecutionPlan& plan,
440442
cudaStream_t stream, PacketType packetType) {
441-
size_t sendBytes, recvBytes;
443+
size_t sendMemRange, recvMemRange;
442444
CUdeviceptr sendBasePtr, recvBasePtr;
443-
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
444-
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
445+
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendMemRange, (CUdeviceptr)sendbuff));
446+
MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvMemRange, (CUdeviceptr)recvbuff));
445447
size_t offsetIn = (char*)sendbuff - (char*)sendBasePtr;
446448
size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
447449

448450
ExecutionContext context =
449451
this->impl_->setupExecutionContext(rank, (void*)sendBasePtr, (void*)recvBasePtr, sendBuffSize, recvBuffSize,
450-
offsetIn, offsetOut, sendBytes, recvBytes, plan);
452+
offsetIn, offsetOut, sendMemRange, recvMemRange, plan);
451453
this->impl_->launchKernel(context, rank, sendbuff, recvbuff, dataType, stream, packetType);
452454
}
453455

src/include/execution_plan.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ struct ExecutionPlan::Impl {
7373
std::vector<int> getConnectedPeers(int rank) const;
7474
std::vector<BufferType> getConnectedBufferTypes(int rank) const;
7575
size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;
76+
size_t getMaxScratchBufferSize(int rank) const;
7677
std::vector<Operation> getOperations(int rank, int threadblock) const;
7778
int getThreadblockCount(int rank) const;
7879
int getNThreadsPerBlock() const;

0 commit comments

Comments
 (0)