@@ -258,28 +258,6 @@ static void registerCustomizedAlgo(ncclComm* commPtr) {
258258 collectionBuilder->addDefaultNativeAlgorithmBuilder (" default_allreduce_packet" ,
259259 reinterpret_cast <uintptr_t >(commPtr->scratchBuffer_ .get ()),
260260 commPtr->scratchBufferSize_ );
261-
262- // std::shared_ptr<AllgatherAlgo6> allgatherAlgo6 = std::make_shared<AllgatherAlgo6>();
263- // std::shared_ptr<AllgatherAlgo8> allgatherAlgo8 =
264- // std::make_shared<AllgatherAlgo8>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
265- // collectionBuilder->addAlgorithmBuilder(allgatherAlgo6);
266- // // TODO(binyli): remove allgather8 algo, use nccl by default
267- // collectionBuilder->addAlgorithmBuilder(allgatherAlgo8);
268-
269- // std::shared_ptr<AllreducePacket> allreduceAllpairAlgo =
270- // std::make_shared<AllreducePacket>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
271- // std::shared_ptr<AllreduceNvls> allreduceNvlsAlgo = std::make_shared<AllreduceNvls>();
272- // std::shared_ptr<AllreduceNvlsWithCopy> allreduceNvlsWithCopyAlgo =
273- // std::make_shared<AllreduceNvlsWithCopy>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
274- // std::shared_ptr<Allreduce8> allreduceAllreduce8Algo =
275- // std::make_shared<Allreduce8>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
276- // std::shared_ptr<AllreduceNvlsPacket> allreduceNvlsPacketAlgo =
277- // std::make_shared<AllreduceNvlsPacket>(commPtr->scratchBuffer_, commPtr->scratchBufferSize_);
278- // collectionBuilder->addAlgorithmBuilder(allreduceAllpairAlgo);
279- // collectionBuilder->addAlgorithmBuilder(allreduceNvlsAlgo);
280- // collectionBuilder->addAlgorithmBuilder(allreduceNvlsWithCopyAlgo);
281- // collectionBuilder->addAlgorithmBuilder(allreduceAllreduce8Algo);
282- // collectionBuilder->addAlgorithmBuilder(allreduceNvlsPacketAlgo);
283261}
284262
285263static std::pair<int , int > getDeviceComputeCapability () {
@@ -352,6 +330,9 @@ static std::shared_ptr<mscclpp::Algorithm> algoSelector(
352330 useNvls = false ;
353331 }
354332#endif
333+ if (messageSize <= (1 << 15 )) {
334+ return algoMapByCollective.at (collective).at (" default_allreduce_allpair_packet" );
335+ }
355336 if (messageSize <= (1 << 15 ) && useNvls) {
356337 return algoMapByCollective.at (collective).at (" default_allreduce_nvls_packet" );
357338 }
0 commit comments