Skip to content

Commit 9a7d518

Browse files
committed
WIP
1 parent 13be162 commit 9a7d518

File tree

6 files changed

+20
-40
lines changed

6 files changed

+20
-40
lines changed

apps/nccl/src/nccl.cu

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -258,28 +258,6 @@ static void registerCustomizedAlgo(ncclComm* commPtr) {
258258
collectionBuilder->addDefaultNativeAlgorithmBuilder("default_allreduce_packet",
259259
reinterpret_cast<uintptr_t>(commPtr->scratchBuffer_.get()),
260260
commPtr->scratchBufferSize_);
261-
262-
// std::shared_ptr<AllgatherAlgo6> allgatherAlgo6 = std::make_shared<AllgatherAlgo6>();
263-
// std::shared_ptr<AllgatherAlgo8> allgatherAlgo8 =
264-
// std::make_shared<AllgatherAlgo8>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
265-
// collectionBuilder->addAlgorithmBuilder(allgatherAlgo6);
266-
// // TODO(binyli): remove allgather8 algo, use nccl by default
267-
// collectionBuilder->addAlgorithmBuilder(allgatherAlgo8);
268-
269-
// std::shared_ptr<AllreducePacket> allreduceAllpairAlgo =
270-
// std::make_shared<AllreducePacket>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
271-
// std::shared_ptr<AllreduceNvls> allreduceNvlsAlgo = std::make_shared<AllreduceNvls>();
272-
// std::shared_ptr<AllreduceNvlsWithCopy> allreduceNvlsWithCopyAlgo =
273-
// std::make_shared<AllreduceNvlsWithCopy>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
274-
// std::shared_ptr<Allreduce8> allreduceAllreduce8Algo =
275-
// std::make_shared<Allreduce8>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
276-
// std::shared_ptr<AllreduceNvlsPacket> allreduceNvlsPacketAlgo =
277-
// std::make_shared<AllreduceNvlsPacket>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
278-
// collectionBuilder->addAlgorithmBuilder(allreduceAllpairAlgo);
279-
// collectionBuilder->addAlgorithmBuilder(allreduceNvlsAlgo);
280-
// collectionBuilder->addAlgorithmBuilder(allreduceNvlsWithCopyAlgo);
281-
// collectionBuilder->addAlgorithmBuilder(allreduceAllreduce8Algo);
282-
// collectionBuilder->addAlgorithmBuilder(allreduceNvlsPacketAlgo);
283261
}
284262

285263
static std::pair<int, int> getDeviceComputeCapability() {
@@ -352,6 +330,9 @@ static std::shared_ptr<mscclpp::Algorithm> algoSelector(
352330
useNvls = false;
353331
}
354332
#endif
333+
if (messageSize <= (1 << 15)) {
334+
return algoMapByCollective.at(collective).at("default_allreduce_allpair_packet");
335+
}
355336
if (messageSize <= (1 << 15) && useNvls) {
356337
return algoMapByCollective.at(collective).at("default_allreduce_nvls_packet");
357338
}

src/algorithms/algorithm.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <mscclpp/algorithm.hpp>
55

6+
#include "algorithms/utils.hpp"
67
#include "logger.hpp"
78

89
namespace mscclpp {
@@ -111,9 +112,11 @@ void AlgorithmCollectionBuilder::addAlgorithmBuilder(std::shared_ptr<AlgorithmBu
111112
this->algoBuilders_.push_back(builder);
112113
}
113114

114-
// TODO (binyli) implement this
115115
void AlgorithmCollectionBuilder::addDefaultNativeAlgorithmBuilder(std::string algorithmName, uintptr_t scratchBuffer,
116-
size_t scratchBufferSize) {}
116+
size_t scratchBufferSize) {
117+
auto builder = algorithm::getDefaultNativeAlgorithmBuilder(algorithmName, scratchBuffer, scratchBufferSize);
118+
this->algoBuilders_.push_back(builder);
119+
}
117120

118121
void AlgorithmCollectionBuilder::setAlgorithmSelector(AlgoSelectFunc selector) { algoSelector_ = selector; }
119122

src/algorithms/allreduce/allreduce_allpair_packet.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
5959
if (blockIdx.x == 0 && threadIdx.x < gridDim.x) {
6060
flags[threadIdx.x].read(flag, -1);
6161
}
62-
__syncthreads();
62+
if (blockIdx.x == 0) {
63+
__syncthreads();
64+
}
6365
if (threadIdx.x == 0 && blockIdx.x == 0) {
6466
deviceFlag++;
6567
}

src/algorithms/allreduce/allreduce_packet.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ __global__ void __launch_bounds__(1024, 1)
129129
if (blockIdx.x == 0 && threadIdx.x < gridDim.x) {
130130
flags[threadIdx.x].read(flag, -1);
131131
}
132-
__syncthreads();
132+
if (blockIdx.x == 0) {
133+
__syncthreads();
134+
}
133135
#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY) && \
134136
defined(ENABLE_NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT)
135137
NpKit::CollectGpuEventShm(NPKIT_EVENT_KERNEL_ALLREDUCE_ENTRY, 0, 0, npkit_timestamp_entry, event_buffer,

src/algorithms/utils.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,15 @@ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>> setupBaseMemo
151151
return ptr;
152152
}
153153

154-
std::vector<std::shared_ptr<AlgorithmBuilder>> loadNativeAlgorithmBuilders(uintptr_t scratchBuffer,
155-
size_t scratchBufferSize) {
156-
std::vector<std::shared_ptr<AlgorithmBuilder>> builders;
157-
builders.push_back(std::make_shared<AllreducePacket>(scratchBuffer, scratchBufferSize));
158-
builders.push_back(std::make_shared<AllreduceAllpairPacket>(scratchBuffer, scratchBufferSize));
159-
return builders;
160-
}
161-
162154
std::shared_ptr<AlgorithmBuilder> getDefaultNativeAlgorithmBuilder(std::string algorithmName, uintptr_t scratchBuffer,
163155
size_t scratchBufferSize) {
164156
if (algorithmName == "default_allreduce_allpair_packet") {
165157
return std::make_shared<AllreduceAllpairPacket>(scratchBuffer, scratchBufferSize);
166158
}
167-
return nullptr;
159+
if (algorithmName == "default_allreduce_packet") {
160+
return std::make_shared<AllreducePacket>(scratchBuffer, scratchBufferSize);
161+
}
162+
throw std::runtime_error("Unsupported default native algorithm: " + algorithmName);
168163
}
169164
} // namespace algorithm
170165

src/include/algorithms/utils.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ constexpr int MAX_NRANKS_PER_NODE = 8;
2727
constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB
2828
static bool mscclppDisableChannelCache = env()->disableChannelCache;
2929

30-
__device__ DeviceSyncer deviceSyncer;
31-
__constant__ DeviceSemaphore deviceSemaphore[NUM_SEMAPHORES];
30+
static __device__ DeviceSyncer deviceSyncer;
31+
static __constant__ DeviceSemaphore deviceSemaphore[NUM_SEMAPHORES];
3232

3333
std::vector<RegisteredMemory> setupRemoteMemories(std::shared_ptr<Communicator> comm, int rank,
3434
RegisteredMemory localMemory);
@@ -61,9 +61,6 @@ std::vector<BaseMemoryChannel> setupBaseMemoryChannels(
6161
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> setupBaseMemoryChannelDeviceHandles(
6262
const std::vector<BaseMemoryChannel>& baseMemoryChannels);
6363

64-
std::vector<std::shared_ptr<AlgorithmBuilder>> loadNativeAlgorithmBuilders(uintptr_t scratchBuffer,
65-
size_t scratchBufferSize);
66-
6764
std::shared_ptr<AlgorithmBuilder> getDefaultNativeAlgorithmBuilder(std::string algorithmName, uintptr_t scratchBuffer,
6865
size_t scratchBufferSize);
6966
} // namespace algorithm

0 commit comments

Comments
 (0)