Skip to content
24 changes: 11 additions & 13 deletions src/executor/execution_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ namespace mscclpp {

template <typename PacketType>
void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag) {
switch (dataType) {
case DataType::INT32:
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -23,7 +23,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::UINT32:
executionKernel<uint32_t><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -33,7 +33,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::FLOAT16:
executionKernel<half><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -43,7 +43,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::FLOAT32:
executionKernel<float><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -53,7 +53,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -65,12 +65,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
}

template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
void* scratch, size_t scratchSize, DataType dataType,
DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag);
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
void* scratch, size_t scratchSize, DataType dataType,
DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag);
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
} // namespace mscclpp
#endif
65 changes: 49 additions & 16 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ auto getOpType = [](const std::string& str) {
return mscclpp::OperationType::WAIT;
} else if (str == "flush") {
return mscclpp::OperationType::FLUSH;
} else if (str == "re") {
} else if (str == "reduce") {
return mscclpp::OperationType::REDUCE;
} else if (str == "rs") {
return mscclpp::OperationType::REDUCE_SEND;
Expand Down Expand Up @@ -100,7 +100,7 @@ std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, BufferTy
}

std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfosByDstRank(int rank, BufferType bufferType) const {
auto pred = [rank, bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; };
auto pred = [bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; };
return filter(this->channelInfosByDstRank.at(rank), pred);
}

Expand Down Expand Up @@ -148,7 +148,8 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
}
return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
}
size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const {

void ExecutionPlan::Impl::calcScratchBufferSizeAndOffset(int rank, size_t inputSize, size_t outputSize, int flag) {
size_t sizePerRank;
if (this->inputChunks.at(rank) != 0)
sizePerRank = inputSize / this->inputChunks.at(rank);
Expand All @@ -157,11 +158,18 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);

this->scratchBufferSize = sizePerRank * this->scratchChunks.at(rank);
if (this->isUsingPacket) {
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
this->scratchBufferSize *= 2; /* data + flag */
}
if (this->isUsingDoubleScratchBuffer) {
this->scratchBufferSize *= 2; /* double buffer */
}
return sizePerRank * this->scratchChunks.at(rank);
this->scratchBufferOffset = (this->isUsingDoubleScratchBuffer && (flag % 2) == 0) ? (this->scratchBufferSize / 2) : 0;
}

size_t ExecutionPlan::Impl::getScratchBufferSize() const { return this->scratchBufferSize; }

std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
return this->operations.at(rank)[threadblock];
}
Expand All @@ -170,8 +178,9 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper

int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }

void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
size_t constDstOffset, int selfRank, size_t inputBufferSize,
size_t outputBufferSize, int flag) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
Expand All @@ -182,6 +191,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
this->isUsingPacket = true;
}
this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand All @@ -195,11 +205,13 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,

this->inputSize = inputSize;
this->outputSize = outputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
this->setupOperations(gpus, constSrcOffset, constDstOffset);
}

void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
size_t constDstOffset, int selfRank, size_t inputBufferSize,
size_t outputBufferSize, int flag) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
Expand All @@ -209,6 +221,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output
if (protocol == "LL") {
this->isUsingPacket = true;
}
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand All @@ -221,7 +234,8 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output

this->inputSize = inputSize;
this->outputSize = outputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
this->setupOperations(gpus, constSrcOffset, constDstOffset);
}

// Construct the channel info. Step 1. Flatten SM and PROXY channels into separate vectors.
Expand Down Expand Up @@ -291,7 +305,16 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) {
}
}

void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffset, size_t constDstOffset) {
void ExecutionPlan::Impl::checkChannelsPerOperation(int channels) {
if (channels > MAX_CHANNEL_PER_OPERATION) {
throw Error("Executor plan has " + std::to_string(channels) +
" channels per operation, exceeding executor support (" +
std::to_string(MAX_CHANNEL_PER_OPERATION) + ")",
ErrorCode::ExecutorError);
}
}

void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffset, size_t constDstOffset) {
// setup threadblocks and operations
for (const auto& gpu : gpus) {
int rank = gpu["id"];
Expand All @@ -318,6 +341,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
}
if (op.contains("i_cids")) {
operation.nInputs = op["i_cids"].size();
checkChannelsPerOperation(operation.nInputs);
for (int i = 0; i < operation.nInputs; i++) {
BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]);
BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]);
Expand All @@ -326,42 +350,45 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]];
operation.inputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) +
(srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
(srcBufferType != BufferType::SCRATCH ? constSrcOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]);
}
}
// will have either srcs or i_cids
if (op.contains("srcs")) {
operation.nInputs = op["srcs"].size();
checkChannelsPerOperation(operation.nInputs);
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
for (int i = 0; i < operation.nInputs; i++) {
operation.inputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcs"][i]["off"]) +
(operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
(operation.inputBufferType != BufferType::SCRATCH ? constSrcOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]);
}
}
if (op.contains("o_cids")) {
operation.nOutputs = op["o_cids"].size();
checkChannelsPerOperation(operation.nOutputs);
for (int i = 0; i < operation.nOutputs; i++) {
BufferType srcBufferType = convertToBufferType(op["o_buff"]["src"]);
BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
operation.outputChannelIndexes[i] =
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]];
operation.outputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) +
(dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
(dstBufferType != BufferType::SCRATCH ? constDstOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]);
}
}
// will have either dsts or o_cids
if (op.contains("dsts")) {
operation.nOutputs = op["dsts"].size();
checkChannelsPerOperation(operation.nOutputs);
operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
for (int i = 0; i < operation.nOutputs; i++) {
operation.outputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dsts"][i]["off"]) +
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]);
}
}
Expand All @@ -370,13 +397,19 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
}
if (op.contains("srcoff")) {
operation.srcOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcoff"]);
if (operation.srcBufferType == BufferType::SCRATCH) {
operation.srcOffset += this->scratchBufferOffset;
}
chunkIndexes.push_back((uint32_t)op["srcoff"]);
}
if (op.contains("dstbuff")) {
operation.dstBufferType = convertToBufferType(op["dstbuff"]);
}
if (op.contains("dstoff")) {
operation.dstOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dstoff"]);
if (operation.dstBufferType == BufferType::SCRATCH) {
operation.dstOffset += this->scratchBufferOffset;
}
chunkIndexes.push_back((uint32_t)op["dstoff"]);
}
if (op.contains("cnt")) {
Expand Down
Loading
Loading