Commit 4bd1633 (WIP)

1 parent: 9a7d518

5 files changed (+189, -169 lines)


apps/nccl/src/nccl.cu

Lines changed: 6 additions & 3 deletions
```diff
@@ -258,6 +258,9 @@ static void registerCustomizedAlgo(ncclComm* commPtr) {
   collectionBuilder->addDefaultNativeAlgorithmBuilder("default_allreduce_packet",
                                                       reinterpret_cast<uintptr_t>(commPtr->scratchBuffer_.get()),
                                                       commPtr->scratchBufferSize_);
+  collectionBuilder->addDefaultNativeAlgorithmBuilder("default_allreduce_nvls_packet",
+                                                      reinterpret_cast<uintptr_t>(commPtr->scratchBuffer_.get()),
+                                                      commPtr->scratchBufferSize_);
 }

 static std::pair<int, int> getDeviceComputeCapability() {
@@ -330,12 +333,12 @@ static std::shared_ptr<mscclpp::Algorithm> algoSelector(
     useNvls = false;
   }
 #endif
-  if (messageSize <= (1 << 15)) {
-    return algoMapByCollective.at(collective).at("default_allreduce_allpair_packet");
-  }
   if (messageSize <= (1 << 15) && useNvls) {
     return algoMapByCollective.at(collective).at("default_allreduce_nvls_packet");
   }
+  if (messageSize <= (1 << 14)) {
+    return algoMapByCollective.at(collective).at("default_allreduce_allpair_packet");
+  }
   if (messageSize <= (1 << 16) || (messageSize <= (1 << 20) && !useNvlsWithZeroCopy)) {
     return algoMapByCollective.at(collective).at("default_allreduce_packet");
   }
```
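After this change the selector prefers the NVLS packet algorithm for messages up to 32 KiB whenever NVLS is usable, falls back to the all-pair packet algorithm below 16 KiB (down from 32 KiB), and then to the plain packet algorithm. A standalone sketch of the resulting ladder, with hypothetical stand-ins (`pick`, the boolean flags) for the surrounding selector state:

```cpp
#include <cstddef>
#include <string>

// Illustrative stand-in for the selection ladder in algoSelector after this
// commit; the function name and flags are hypothetical, the thresholds are
// taken from the diff above.
std::string pick(size_t messageSize, bool useNvls, bool useNvlsWithZeroCopy) {
  if (messageSize <= (1 << 15) && useNvls)   // <= 32 KiB, NVLS available
    return "default_allreduce_nvls_packet";
  if (messageSize <= (1 << 14))              // <= 16 KiB
    return "default_allreduce_allpair_packet";
  if (messageSize <= (1 << 16) ||            // <= 64 KiB, or
      (messageSize <= (1 << 20) && !useNvlsWithZeroCopy))  // <= 1 MiB, no zero-copy NVLS
    return "default_allreduce_packet";
  return "other";  // larger sizes are handled by later branches of the selector
}
```
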
src/algorithms/allreduce/allreduce_nvls_packet.cu

Lines changed: 142 additions & 96 deletions
```diff
@@ -1,107 +1,153 @@
-// #include "algorithm_utils.hpp"
-// #include "allreduce_common.hpp"
-// #include "allreduce_nvls_packet.hpp"
-// #include "debug.h"
+#include "algorithms/allreduce/allreduce_nvls_packet.hpp"
+#include "algorithms/allreduce/common.hpp"
+#include "algorithms/utils.hpp"
+#include "debug.h"

-// namespace mscclpp {
+namespace mscclpp {
+namespace algorithm {

-// inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize) {
-//   int blockNum = 8;
-//   int threadNum = 1024;
-//   if (inputSize <= (1 << 13)) {
-//     blockNum = 4;
-//     threadNum = 512;
+__device__ uint32_t deviceFlag = 1;
+template <Algorithm::Op OpType, typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduceNvlsPacket([[maybe_unused]] const T* input, [[maybe_unused]] T* scratch, [[maybe_unused]] T* output,
+                        [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::SwitchChannel>* multicast,
+                        [[maybe_unused]] size_t nelems, [[maybe_unused]] size_t scratchBufferSize,
+                        [[maybe_unused]] int rank, [[maybe_unused]] int worldSize, [[maybe_unused]] LL8Packet* flags) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  uint32_t flag = deviceFlag;
+  // __syncthreads();
+  // if (threadIdx.x == 0) {
+  //   flags[blockIdx.x].write(0, flag);
   // }
-//   return {blockNum, threadNum};
-// }

-// template <Op OpType, typename T>
-// struct AllreduceNvlsPacketAdapter {
-//   static cudaError_t call(const void* input, void* scratch, void* output, void*, void*,
-//                           mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsChannels,
-//                           mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t scratchBufferSize,
-//                           int rank, int, int worldSize, size_t inputSize, cudaStream_t stream, uint32_t* deviceFlag,
-//                           uint32_t*, uint32_t*, uint32_t, int nBlocks, int nThreadsPerBlock) {
-//     allreduceNvlsPacket<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
-//         (const T*)input, (T*)scratch, (T*)output, nvlsChannels, inputSize / sizeof(T), scratchBufferSize, rank,
-//         worldSize, deviceFlag);
-//     return cudaGetLastError();
+  size_t scratchBaseOffset = (flag % 2) ? scratchBufferSize / 2 : 0;
+  uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+  uint32_t nPktPerRank = nelems / worldSize / (sizeof(mscclpp::LL8Packet::Payload) / sizeof(T));
+  mscclpp::LL8Packet* multiPkt =
+      (mscclpp::LL8Packet*)((char*)multicast->mcPtr + scratchBaseOffset) + rank * worldSize * nPktPerRank;
+  uint* src = (uint*)(input);
+  uint* dst = (uint*)(output);
+  mscclpp::LL8Packet* scratchPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchBaseOffset);
+  for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
+    mscclpp::LL8Packet pkt(src[i], flag);
+    mscclpp::SwitchChannelDeviceHandle::multimemStore(*(mscclpp::f32x2*)(&pkt), multiPkt + i);
+  }
+  for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
+    uint data = src[i];
+    for (int peer = 0; peer < worldSize; peer++) {
+      if (peer == rank) {
+        continue;
+      }
+      uint val = scratchPkt[peer * worldSize * nPktPerRank + i].read(flag);
+      data = cal_vectors<T, OpType>(data, val);
+    }
+    dst[i] = data;
+  }
+  // if (blockIdx.x == 0 && threadIdx.x < gridDim.x) {
+  //   flags[threadIdx.x].read(flag, -1);
   // }
-// };
-
-// void AllreduceNvlsPacket::initialize(std::shared_ptr<mscclpp::Communicator>) {
-//   deviceFlag_ = mscclpp::detail::gpuCallocShared<uint32_t>(16);
-//   std::vector<uint32_t> initFlag(16);
-//   for (int i = 0; i < 16; ++i) {
-//     initFlag[i] = 1;
+  // if (blockIdx.x == 0) {
+  //   __syncthreads();
   // }
-//   mscclpp::gpuMemcpy<uint32_t>(deviceFlag_.get(), initFlag.data(), 16, cudaMemcpyHostToDevice);
-// }
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    deviceFlag++;
+  }
+#endif
+}

-// mscclpp::AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t,
-//                                                                           mscclpp::DataType) {
-//   return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
-// }
+inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize) {
+  int blockNum = 8;
+  int threadNum = 1024;
+  if (inputSize <= (1 << 13)) {
+    blockNum = 4;
+    threadNum = 512;
+  }
+  return {blockNum, threadNum};
+}

-// std::shared_ptr<mscclpp::AlgorithmCtx> AllreduceNvlsPacket::initAllreduceContext(
-//     std::shared_ptr<mscclpp::Communicator> comm, const void*, void*, size_t, mscclpp::DataType) {
-//   auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
-//   ctx->rank = comm->bootstrap()->getRank();
-//   ctx->workSize = comm->bootstrap()->getNranks();
-//   ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
+template <Op OpType, typename T>
+struct AllreduceNvlsPacketAdapter {
+  static cudaError_t call(const void* input, void* scratch, void* output, void*, void*,
+                          mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsChannels,
+                          mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t scratchBufferSize,
+                          int rank, int, int worldSize, size_t inputSize, cudaStream_t stream, LL8Packet* flags,
+                          uint32_t, int nBlocks, int nThreadsPerBlock) {
+    allreduceNvlsPacket<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>((const T*)input, (T*)scratch, (T*)output,
+                                                                             nvlsChannels, inputSize / sizeof(T),
+                                                                             scratchBufferSize, rank, worldSize, flags);
+    return cudaGetLastError();
+  }
+};

-//   // setup channels
-//   int nSwitchChannels = 1;
-//   ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
-//   ctx->switchChannels = setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_.lock().get(),
-//                                           this->scratchBufferSize_, nSwitchChannels);
-//   ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
-//   return ctx;
-// }
+void AllreduceNvlsPacket::initialize(std::shared_ptr<mscclpp::Communicator>) {
+  flags_ = mscclpp::detail::gpuCallocShared<LL8Packet>(16);
+}

-// CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
-//                                                     void* output, size_t inputSize, mscclpp::DataType dtype,
-//                                                     cudaStream_t stream,
-//                                                     std::unordered_map<std::string, uintptr_t>& extra) {
-//   int op = *reinterpret_cast<int*>(extra.at("op"));
-//   std::pair<int, int> blockAndThreadNum = getBlockNumAndThreadNum(extra);
-//   if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
-//     blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize);
-//   }
-//   if (blockAndThreadNum.first > maxBlockNum_) {
-//     WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_);
-//     return CommResult::commInvalidArgument;
-//   }
-//   AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(static_cast<Algorithm::Op>(op), dtype);
-//   if (!allreduce) {
-//     WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
-//     return CommResult::commInvalidArgument;
-//   }
-//   cudaError_t error = allreduce(
-//       input, this->scratchBuffer_.lock().get(), output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(),
-//       nullptr, 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream,
-//       this->deviceFlag_.get(), nullptr, nullptr, 0, blockAndThreadNum.first, blockAndThreadNum.second);
-//   if (error != cudaSuccess) {
-//     WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error));
-//     return CommResult::commUnhandledCudaError;
-//   }
-//   return CommResult::commSuccess;
-// }
+mscclpp::AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t,
+                                                                          mscclpp::DataType) {
+  return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
+}
+
+std::shared_ptr<mscclpp::AlgorithmCtx> AllreduceNvlsPacket::initAllreduceContext(
+    std::shared_ptr<mscclpp::Communicator> comm, const void*, void*, size_t, mscclpp::DataType) {
+  auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
+  ctx->rank = comm->bootstrap()->getRank();
+  ctx->workSize = comm->bootstrap()->getNranks();
+  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
+
+  // setup channels
+  int nSwitchChannels = 1;
+  ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
+  ctx->switchChannels =
+      setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
+  ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
+  return ctx;
+}
+
+CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
+                                                    void* output, size_t inputSize, mscclpp::DataType dtype,
+                                                    cudaStream_t stream,
+                                                    std::unordered_map<std::string, uintptr_t>& extra) {
+  int op = *reinterpret_cast<int*>(extra.at("op"));
+  std::pair<int, int> blockAndThreadNum = getBlockNumAndThreadNum(extra);
+  if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
+    blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize);
+  }
+  if (blockAndThreadNum.first > maxBlockNum_) {
+    WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_);
+    return CommResult::commInvalidArgument;
+  }
+  AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(static_cast<Algorithm::Op>(op), dtype);
+  if (!allreduce) {
+    WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
+    return CommResult::commInvalidArgument;
+  }
+  cudaError_t error = allreduce(
+      input, this->scratchBuffer_, output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(),
+      nullptr, 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream,
+      this->flags_.get(), 0, blockAndThreadNum.first, blockAndThreadNum.second);
+  if (error != cudaSuccess) {
+    WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error));
+    return CommResult::commUnhandledCudaError;
+  }
+  return CommResult::commSuccess;
+}

-// std::shared_ptr<mscclpp::Algorithm> AllreduceNvlsPacket::build() {
-//   auto self = std::make_shared<AllreduceNvlsPacket>(scratchBuffer_.lock(), scratchBufferSize_);
-//   return std::make_shared<mscclpp::NativeAlgorithm>(
-//       "default_allreduce_nvls_packet", "allreduce",
-//       [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
-//       [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
-//              [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, cudaStream_t stream,
-//              std::unordered_map<std::string, uintptr_t>& extras) {
-//         return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, stream, extras);
-//       },
-//       [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
-//              [[maybe_unused]] size_t outputSize,
-//              mscclpp::DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
-//       [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize,
-//              mscclpp::DataType dtype) { return self->generateAllreduceContextKey(input, output, inputSize, dtype); });
-// }
-// } // namespace mscclpp
+std::shared_ptr<mscclpp::Algorithm> AllreduceNvlsPacket::build() {
+  auto self = std::make_shared<AllreduceNvlsPacket>((uintptr_t)scratchBuffer_, scratchBufferSize_);
+  return std::make_shared<mscclpp::NativeAlgorithm>(
+      "default_allreduce_nvls_packet", "allreduce",
+      [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
+      [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, cudaStream_t stream,
+             std::unordered_map<std::string, uintptr_t>& extras) {
+        return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, stream, extras);
+      },
+      [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize,
+             mscclpp::DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
+      [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize,
+             mscclpp::DataType dtype) { return self->generateAllreduceContextKey(input, output, inputSize, dtype); });
+}
+}  // namespace algorithm
+}  // namespace mscclpp
```
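The new kernel makes two passes over LL8 packets (a 4-byte payload paired with a 4-byte flag). In the first pass each rank wraps its input words in packets and multimem-stores them through the NVLS switch channel's multicast pointer, which replicates each store into every rank's scratch buffer. In the second pass each rank spin-reads its peers' packets out of its own scratch copy with `read(flag)` and folds them in with `cal_vectors`. The flag doubles as the arrival signal, so no extra barrier is needed; its parity selects which half of the scratch buffer the round uses (`scratchBaseOffset`), so the next call cannot clobber packets a slower peer is still reading, and a single thread bumps `deviceFlag` to open the next round. A minimal sketch of the LL-packet primitive this relies on, assuming an 8-byte {data, flag} pair written with one volatile 64-bit store (illustrative only, not mscclpp's actual `LL8Packet` definition):

```cuda
#include <cstdint>

// Sketch of the LL ("low-latency") packet idea: payload and flag travel in a
// single 8-byte store, so a reader that observes the expected flag is also
// guaranteed to observe the payload written with it.
struct LLPacketSketch {
  union {
    struct {
      uint32_t data;
      uint32_t flag;
    };
    uint64_t raw;
  };

  __device__ void write(uint32_t value, uint32_t f) {
    uint64_t v = ((uint64_t)f << 32) | value;
    // one volatile 8-byte store: payload and flag become visible together
    asm volatile("st.volatile.global.u64 [%0], %1;" ::"l"(&raw), "l"(v));
  }

  __device__ uint32_t read(uint32_t expected) {
    uint32_t d, f;
    do {  // spin until the writer's flag for this round shows up
      asm volatile("ld.volatile.global.v2.u32 {%0,%1}, [%2];"
                   : "=r"(d), "=r"(f)
                   : "l"(&raw));
    } while (f != expected);
    return d;
  }
};
```

Because each packet is touched with a single 8-byte access on both sides, the reader never sees a payload without its flag, which is what lets the kernel skip a separate synchronization step between the store and reduce phases.
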

src/algorithms/allreduce/allreduce_packet.cu

Lines changed: 8 additions & 5 deletions
```diff
@@ -163,11 +163,14 @@ struct PacketAdapter {
   }
 };

-inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int worldSize) {
-  if (inputSize < worldSize * sizeof(int)) {
-    return {worldSize - 1, 32};
+inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int nRanksPerNode, int worldSize) {
+  int nBlocks = (nRanksPerNode - 1) * 4;
+  int nThreadsPerBlock = 1024;
+  if (inputSize >= 16384) {
+    nBlocks = (worldSize - 1) * 8;
+    nThreadsPerBlock = (inputSize <= 153600) ? 512 : 1024;
   }
-  return {(worldSize - 1) * 4, 512};
+  return {nBlocks, nThreadsPerBlock};
 }

 void AllreducePacket::initialize(std::shared_ptr<Communicator> comm) {
@@ -186,7 +189,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<AlgorithmC
   Algorithm::Op op = *reinterpret_cast<Algorithm::Op*>(extras.at("op"));
   std::pair<int, int> blockAndThreadNum = getBlockNumAndThreadNum(extras);
   if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
-    blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->workSize);
+    blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerNode, ctx->workSize);
   }

   size_t sendBytes;
```
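The packet algorithm's default launch configuration now keys off both rank counts: below 16 KiB it launches `(nRanksPerNode - 1) * 4` blocks of 1024 threads, and from 16 KiB up it widens to `(worldSize - 1) * 8` blocks, with 512 threads per block up to 150 KiB (153600 bytes) and 1024 beyond. A standalone copy of the heuristic for sanity-checking the thresholds (the `defaultLaunch` name and the driver loop are illustrative, not library API):

```cpp
#include <cstddef>
#include <cstdio>
#include <utility>

// Mirrors the new getDefaultBlockNumAndThreadNum above.
std::pair<int, int> defaultLaunch(size_t inputSize, int nRanksPerNode, int worldSize) {
  int nBlocks = (nRanksPerNode - 1) * 4;
  int nThreadsPerBlock = 1024;
  if (inputSize >= 16384) {
    nBlocks = (worldSize - 1) * 8;
    nThreadsPerBlock = (inputSize <= 153600) ? 512 : 1024;
  }
  return {nBlocks, nThreadsPerBlock};
}

int main() {
  // With 8 ranks on one node: 8 KiB -> 28x1024, 64 KiB -> 56x512, 1 MiB -> 56x1024.
  for (size_t bytes : {size_t(8192), size_t(65536), size_t(1048576)}) {
    auto [blocks, threads] = defaultLaunch(bytes, /*nRanksPerNode=*/8, /*worldSize=*/8);
    std::printf("%8zu B -> %2d blocks x %4d threads\n", bytes, blocks, threads);
  }
}
```
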

src/algorithms/utils.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <mscclpp/switch_channel.hpp>
88

99
#include "algorithms/allreduce/allreduce_allpair_packet.hpp"
10+
#include "algorithms/allreduce/allreduce_nvls_packet.hpp"
1011
#include "algorithms/allreduce/allreduce_packet.hpp"
1112

1213
namespace mscclpp {
@@ -159,6 +160,9 @@ std::shared_ptr<AlgorithmBuilder> getDefaultNativeAlgorithmBuilder(std::string a
159160
if (algorithmName == "default_allreduce_packet") {
160161
return std::make_shared<AllreducePacket>(scratchBuffer, scratchBufferSize);
161162
}
163+
if (algorithmName == "default_allreduce_nvls_packet") {
164+
return std::make_shared<AllreduceNvlsPacket>(scratchBuffer, scratchBufferSize);
165+
}
162166
throw std::runtime_error("Unsupported default native algorithm: " + algorithmName);
163167
}
164168
} // namespace algorithm
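This lookup is the hook that makes the `addDefaultNativeAlgorithmBuilder("default_allreduce_nvls_packet", ...)` registration in apps/nccl/src/nccl.cu resolve: the algorithm name plus the communicator's scratch buffer selects a concrete builder. A reduced sketch of the dispatch, with stand-in types for the mscclpp builder classes (illustrative only):

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>

// Stand-ins for the mscclpp builder hierarchy; only the dispatch shape matters.
struct AlgorithmBuilder { virtual ~AlgorithmBuilder() = default; };
struct AllreducePacket : AlgorithmBuilder { AllreducePacket(uintptr_t, size_t) {} };
struct AllreduceNvlsPacket : AlgorithmBuilder { AllreduceNvlsPacket(uintptr_t, size_t) {} };

// Name -> builder factory, as getDefaultNativeAlgorithmBuilder does above.
std::shared_ptr<AlgorithmBuilder> builderFor(const std::string& name, uintptr_t scratch, size_t scratchSize) {
  if (name == "default_allreduce_packet") return std::make_shared<AllreducePacket>(scratch, scratchSize);
  if (name == "default_allreduce_nvls_packet") return std::make_shared<AllreduceNvlsPacket>(scratch, scratchSize);
  throw std::runtime_error("Unsupported default native algorithm: " + name);
}
```
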
