-// #include "algorithm_utils.hpp"
-// #include "allreduce_common.hpp"
-// #include "allreduce_nvls_packet.hpp"
-// #include "debug.h"
+#include "algorithms/allreduce/allreduce_nvls_packet.hpp"
+#include "algorithms/allreduce/common.hpp"
+#include "algorithms/utils.hpp"
+#include "debug.h"

-// namespace mscclpp {
+namespace mscclpp {
+namespace algorithm {

-// inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize) {
-//   int blockNum = 8;
-//   int threadNum = 1024;
-//   if (inputSize <= (1 << 13)) {
-//     blockNum = 4;
-//     threadNum = 512;
+__device__ uint32_t deviceFlag = 1;
+template <Algorithm::Op OpType, typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduceNvlsPacket([[maybe_unused]] const T* input, [[maybe_unused]] T* scratch, [[maybe_unused]] T* output,
+                        [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::SwitchChannel>* multicast,
+                        [[maybe_unused]] size_t nelems, [[maybe_unused]] size_t scratchBufferSize,
+                        [[maybe_unused]] int rank, [[maybe_unused]] int worldSize, [[maybe_unused]] LL8Packet* flags) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  uint32_t flag = deviceFlag;
+  // __syncthreads();
+  // if (threadIdx.x == 0) {
+  //   flags[blockIdx.x].write(0, flag);
  // }
-//   return {blockNum, threadNum};
-// }

-// template <Op OpType, typename T>
-// struct AllreduceNvlsPacketAdapter {
-//   static cudaError_t call(const void* input, void* scratch, void* output, void*, void*,
-//                           mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsChannels,
-//                           mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t scratchBufferSize,
-//                           int rank, int, int worldSize, size_t inputSize, cudaStream_t stream, uint32_t* deviceFlag,
-//                           uint32_t*, uint32_t*, uint32_t, int nBlocks, int nThreadsPerBlock) {
-//     allreduceNvlsPacket<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
-//         (const T*)input, (T*)scratch, (T*)output, nvlsChannels, inputSize / sizeof(T), scratchBufferSize, rank,
-//         worldSize, deviceFlag);
-//     return cudaGetLastError();
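+  // The flag's parity selects which half of the scratch buffer this invocation uses,
+  // double-buffering packets across consecutive calls.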
+  size_t scratchBaseOffset = (flag % 2) ? scratchBufferSize / 2 : 0;
+  uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+  uint32_t nPktPerRank = nelems / worldSize / (sizeof(mscclpp::LL8Packet::Payload) / sizeof(T));
+  mscclpp::LL8Packet* multiPkt =
+      (mscclpp::LL8Packet*)((char*)multicast->mcPtr + scratchBaseOffset) + rank * worldSize * nPktPerRank;
+  uint* src = (uint*)(input);
+  uint* dst = (uint*)(output);
+  mscclpp::LL8Packet* scratchPkt = (mscclpp::LL8Packet*)((char*)scratch + scratchBaseOffset);
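+  // Phase 1: pack the whole local input into LL8 packets (4-byte payload + flag) and
+  // multimem-store them through the NVLS multicast pointer, so they land in this
+  // rank's slot of the scratch buffer on every rank.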
+  for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
+    mscclpp::LL8Packet pkt(src[i], flag);
+    mscclpp::SwitchChannelDeviceHandle::multimemStore(*(mscclpp::f32x2*)(&pkt), multiPkt + i);
+  }
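+  // Phase 2: reduce locally; read(flag) returns a peer's payload once a packet tagged
+  // with the current flag has arrived in local scratch.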
+  for (uint32_t i = tid; i < nPktPerRank * worldSize; i += blockDim.x * gridDim.x) {
+    uint data = src[i];
+    for (int peer = 0; peer < worldSize; peer++) {
+      if (peer == rank) {
+        continue;
+      }
+      uint val = scratchPkt[peer * worldSize * nPktPerRank + i].read(flag);
+      data = cal_vectors<T, OpType>(data, val);
+    }
+    dst[i] = data;
+  }
+  // if (blockIdx.x == 0 && threadIdx.x < gridDim.x) {
+  //   flags[threadIdx.x].read(flag, -1);
  // }
-// };
-
-// void AllreduceNvlsPacket::initialize(std::shared_ptr<mscclpp::Communicator>) {
-//   deviceFlag_ = mscclpp::detail::gpuCallocShared<uint32_t>(16);
-//   std::vector<uint32_t> initFlag(16);
-//   for (int i = 0; i < 16; ++i) {
-//     initFlag[i] = 1;
+  // if (blockIdx.x == 0) {
+  //   __syncthreads();
  // }
-//   mscclpp::gpuMemcpy<uint32_t>(deviceFlag_.get(), initFlag.data(), 16, cudaMemcpyHostToDevice);
-// }
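+  // One thread advances the device-side flag so the next call targets the other
+  // scratch half and uses a fresh packet flag.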
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    deviceFlag++;
+  }
+#endif
+}

-// mscclpp::AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t,
-//                                                                           mscclpp::DataType) {
-//   return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
-// }
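+// Default launch shape: 8 blocks x 1024 threads, shrunk to 4 x 512 for inputs of
+// 8 KiB or less.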
+inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize) {
+  int blockNum = 8;
+  int threadNum = 1024;
+  if (inputSize <= (1 << 13)) {
+    blockNum = 4;
+    threadNum = 512;
+  }
+  return {blockNum, threadNum};
+}

-// std::shared_ptr<mscclpp::AlgorithmCtx> AllreduceNvlsPacket::initAllreduceContext(
-//     std::shared_ptr<mscclpp::Communicator> comm, const void*, void*, size_t, mscclpp::DataType) {
-//   auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
-//   ctx->rank = comm->bootstrap()->getRank();
-//   ctx->workSize = comm->bootstrap()->getNranks();
-//   ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
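+// Adapter mapping the generic AllreduceFunc argument list onto this kernel's
+// parameters; arguments this algorithm does not use are left unnamed.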
+template <Op OpType, typename T>
+struct AllreduceNvlsPacketAdapter {
+  static cudaError_t call(const void* input, void* scratch, void* output, void*, void*,
+                          mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsChannels,
+                          mscclpp::DeviceHandle<mscclpp::SwitchChannel>*, size_t, size_t, size_t scratchBufferSize,
+                          int rank, int, int worldSize, size_t inputSize, cudaStream_t stream, LL8Packet* flags,
+                          uint32_t, int nBlocks, int nThreadsPerBlock) {
+    allreduceNvlsPacket<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>((const T*)input, (T*)scratch, (T*)output,
+                                                                             nvlsChannels, inputSize / sizeof(T),
+                                                                             scratchBufferSize, rank, worldSize, flags);
+    return cudaGetLastError();
+  }
+};

-//   // setup channels
-//   int nSwitchChannels = 1;
-//   ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
-//   ctx->switchChannels = setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_.lock().get(),
-//                                           this->scratchBufferSize_, nSwitchChannels);
-//   ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
-//   return ctx;
-// }
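+// flags_ is passed to the kernel but is only referenced by the currently
+// commented-out synchronization code.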
+void AllreduceNvlsPacket::initialize(std::shared_ptr<mscclpp::Communicator>) {
+  flags_ = mscclpp::detail::gpuCallocShared<LL8Packet>(16);
+}

-// CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
-//                                                     void* output, size_t inputSize, mscclpp::DataType dtype,
-//                                                     cudaStream_t stream,
-//                                                     std::unordered_map<std::string, uintptr_t>& extra) {
-//   int op = *reinterpret_cast<int*>(extra.at("op"));
-//   std::pair<int, int> blockAndThreadNum = getBlockNumAndThreadNum(extra);
-//   if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
-//     blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize);
-//   }
-//   if (blockAndThreadNum.first > maxBlockNum_) {
-//     WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_);
-//     return CommResult::commInvalidArgument;
-//   }
-//   AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(static_cast<Algorithm::Op>(op), dtype);
-//   if (!allreduce) {
-//     WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
-//     return CommResult::commInvalidArgument;
-//   }
-//   cudaError_t error = allreduce(
-//       input, this->scratchBuffer_.lock().get(), output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(),
-//       nullptr, 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream,
-//       this->deviceFlag_.get(), nullptr, nullptr, 0, blockAndThreadNum.first, blockAndThreadNum.second);
-//   if (error != cudaSuccess) {
-//     WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error));
-//     return CommResult::commUnhandledCudaError;
-//   }
-//   return CommResult::commSuccess;
-// }
+mscclpp::AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t,
+                                                                          mscclpp::DataType) {
+  return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, 0};
+}
+
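+// Builds the per-call context: rank/world bookkeeping plus a single NVLS switch
+// channel bound to the scratch buffer.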
+std::shared_ptr<mscclpp::AlgorithmCtx> AllreduceNvlsPacket::initAllreduceContext(
+    std::shared_ptr<mscclpp::Communicator> comm, const void*, void*, size_t, mscclpp::DataType) {
+  auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
+  ctx->rank = comm->bootstrap()->getRank();
+  ctx->workSize = comm->bootstrap()->getNranks();
+  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
+
+  // setup channels
+  int nSwitchChannels = 1;
+  ctx->nvlsConnections = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels);
+  ctx->switchChannels =
+      setupNvlsChannels(ctx->nvlsConnections, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels);
+  ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels);
+  return ctx;
+}
+
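+// Host-side launcher: picks a launch shape, dispatches on op/dtype through the
+// adapter, and launches the kernel on the given stream.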
+CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
+                                                    void* output, size_t inputSize, mscclpp::DataType dtype,
+                                                    cudaStream_t stream,
+                                                    std::unordered_map<std::string, uintptr_t>& extra) {
+  int op = *reinterpret_cast<int*>(extra.at("op"));
+  std::pair<int, int> blockAndThreadNum = getBlockNumAndThreadNum(extra);
+  if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
+    blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize);
+  }
+  if (blockAndThreadNum.first > maxBlockNum_) {
+    WARN("Block number %d exceeds the maximum limit %d", blockAndThreadNum.first, maxBlockNum_);
+    return CommResult::commInvalidArgument;
+  }
+  AllreduceFunc allreduce = dispatch<AllreduceNvlsPacketAdapter>(static_cast<Algorithm::Op>(op), dtype);
+  if (!allreduce) {
+    WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast<int>(dtype));
+    return CommResult::commInvalidArgument;
+  }
+  cudaError_t error = allreduce(
+      input, this->scratchBuffer_, output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(),
+      nullptr, 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream,
+      this->flags_.get(), 0, blockAndThreadNum.first, blockAndThreadNum.second);
+  if (error != cudaSuccess) {
+    WARN("AllreduceNvlsPacket failed with error: %s", cudaGetErrorString(error));
+    return CommResult::commUnhandledCudaError;
+  }
+  return CommResult::commSuccess;
+}

-// std::shared_ptr<mscclpp::Algorithm> AllreduceNvlsPacket::build() {
-//   auto self = std::make_shared<AllreduceNvlsPacket>(scratchBuffer_.lock(), scratchBufferSize_);
-//   return std::make_shared<mscclpp::NativeAlgorithm>(
-//       "default_allreduce_nvls_packet", "allreduce",
-//       [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
-//       [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
-//              [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, cudaStream_t stream,
-//              std::unordered_map<std::string, uintptr_t>& extras) {
-//         return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, stream, extras);
-//       },
-//       [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
-//              [[maybe_unused]] size_t outputSize,
-//              mscclpp::DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
-//       [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize,
-//              mscclpp::DataType dtype) { return self->generateAllreduceContextKey(input, output, inputSize, dtype); });
-// }
-// }  // namespace mscclpp
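+// Wraps this implementation in a NativeAlgorithm named "default_allreduce_nvls_packet",
+// wiring up the init, kernel, context-init, and context-key callbacks.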
+std::shared_ptr<mscclpp::Algorithm> AllreduceNvlsPacket::build() {
+  auto self = std::make_shared<AllreduceNvlsPacket>((uintptr_t)scratchBuffer_, scratchBufferSize_);
+  return std::make_shared<mscclpp::NativeAlgorithm>(
+      "default_allreduce_nvls_packet", "allreduce",
+      [self](std::shared_ptr<mscclpp::Communicator> comm) { self->initialize(comm); },
+      [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize, mscclpp::DataType dtype, cudaStream_t stream,
+             std::unordered_map<std::string, uintptr_t>& extras) {
+        return self->allreduceKernelFunc(ctx, input, output, inputSize, dtype, stream, extras);
+      },
+      [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t inputSize,
+             [[maybe_unused]] size_t outputSize,
+             mscclpp::DataType dtype) { return self->initAllreduceContext(comm, input, output, inputSize, dtype); },
+      [self](const void* input, void* output, size_t inputSize, [[maybe_unused]] size_t outputSize,
+             mscclpp::DataType dtype) { return self->generateAllreduceContextKey(input, output, inputSize, dtype); });
+}
+}  // namespace algorithm
+}  // namespace mscclpp