Skip to content

Commit 13be162

Browse files
committed
WIP
1 parent 98dec89 commit 13be162

File tree

4 files changed

+34
-19
lines changed

4 files changed

+34
-19
lines changed

src/algorithms/algorithm_utils.cc renamed to src/algorithms/utils.cc

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,13 @@
1+
#include "algorithms/utils.hpp"
2+
13
#include <algorithm>
4+
#include <mscclpp/algorithm.hpp>
25
#include <mscclpp/core.hpp>
36
#include <mscclpp/memory_channel.hpp>
47
#include <mscclpp/switch_channel.hpp>
58

69
#include "algorithms/allreduce/allreduce_allpair_packet.hpp"
7-
#include "algorithms/utils.hpp"
10+
#include "algorithms/allreduce/allreduce_packet.hpp"
811

912
namespace mscclpp {
1013
namespace algorithm {
@@ -148,8 +151,12 @@ std::shared_ptr<mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>> setupBaseMemo
148151
return ptr;
149152
}
150153

151-
std::vector<std::shared_ptr<AlgorithmBuilder>> loadNativeAlgorithmBuilders() {
152-
return std::vector<std::shared_ptr<AlgorithmBuilder>>();
154+
std::vector<std::shared_ptr<AlgorithmBuilder>> loadNativeAlgorithmBuilders(uintptr_t scratchBuffer,
155+
size_t scratchBufferSize) {
156+
std::vector<std::shared_ptr<AlgorithmBuilder>> builders;
157+
builders.push_back(std::make_shared<AllreducePacket>(scratchBuffer, scratchBufferSize));
158+
builders.push_back(std::make_shared<AllreduceAllpairPacket>(scratchBuffer, scratchBufferSize));
159+
return builders;
153160
}
154161

155162
std::shared_ptr<AlgorithmBuilder> getDefaultNativeAlgorithmBuilder(std::string algorithmName, uintptr_t scratchBuffer,

src/include/algorithms/allreduce/allreduce_packet.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,8 @@ class AllreducePacket : public AlgorithmBuilder {
2222
size_t, DataType);
2323
AlgorithmCtxKey generateAllreduceContextKey(const void*, void*, size_t, DataType);
2424

25-
size_t scratchBufferSize_;
2625
void* scratchBuffer_;
26+
size_t scratchBufferSize_;
2727
const int nSegmentsForScratchBuffer_ = 2;
2828
const int maxBlockNum_ = 28;
2929
std::vector<Connection> conns_;

src/include/algorithms/allreduce/common.hpp

Lines changed: 18 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -281,14 +281,17 @@ __forceinline__ __device__ __fp8x2_e4m3 min_elements(__fp8x2_e4m3 a, __fp8x2_e4m
281281
// FP8 E4M3 vectorized min for 4 elements
282282
__forceinline__ __device__ __fp8x4_e4m3 min_elements(__fp8x4_e4m3 a, __fp8x4_e4m3 b) {
283283
// Process as two __fp8x2_e4m3 using min_elements for 2 elements
284-
__fp8x2_e4m3* a_pair = reinterpret_cast<__fp8x2_e4m3*>(&a);
285-
__fp8x2_e4m3* b_pair = reinterpret_cast<__fp8x2_e4m3*>(&b);
284+
union {
285+
__fp8x4_e4m3 vec4;
286+
__fp8x2_e4m3 vec2[2];
287+
} ua, ub, uresult;
288+
ua.vec4 = a;
289+
ub.vec4 = b;
286290

287-
__fp8x2_e4m3 result[2];
288-
result[0] = min_elements(a_pair[0], b_pair[0]);
289-
result[1] = min_elements(a_pair[1], b_pair[1]);
291+
uresult.vec2[0] = min_elements(ua.vec2[0], ub.vec2[0]);
292+
uresult.vec2[1] = min_elements(ua.vec2[1], ub.vec2[1]);
290293

291-
return *reinterpret_cast<__fp8x4_e4m3*>(result);
294+
return uresult.vec4;
292295
}
293296

294297
// FP8 E5M2 min operation (single element)
@@ -310,14 +313,17 @@ __forceinline__ __device__ __fp8x2_e5m2 min_elements(__fp8x2_e5m2 a, __fp8x2_e5m
310313
// FP8 E5M2 vectorized min for 4 elements (CUDA only)
311314
__forceinline__ __device__ __fp8x4_e5m2 min_elements(__fp8x4_e5m2 a, __fp8x4_e5m2 b) {
312315
// Process as two __fp8x2_e5m2 using min_elements for 2 elements
313-
__fp8x2_e5m2* a_pair = reinterpret_cast<__fp8x2_e5m2*>(&a);
314-
__fp8x2_e5m2* b_pair = reinterpret_cast<__fp8x2_e5m2*>(&b);
316+
union {
317+
__fp8x4_e5m2 vec4;
318+
__fp8x2_e5m2 vec2[2];
319+
} ua, ub, uresult;
320+
ua.vec4 = a;
321+
ub.vec4 = b;
315322

316-
__fp8x2_e5m2 result[2];
317-
result[0] = min_elements(a_pair[0], b_pair[0]);
318-
result[1] = min_elements(a_pair[1], b_pair[1]);
323+
uresult.vec2[0] = min_elements(ua.vec2[0], ub.vec2[0]);
324+
uresult.vec2[1] = min_elements(ua.vec2[1], ub.vec2[1]);
319325

320-
return *reinterpret_cast<__fp8x4_e5m2*>(result);
326+
return uresult.vec4;
321327
}
322328
#endif // !defined(__HIP_PLATFORM_AMD__)
323329
#endif // __FP8_TYPES_EXIST__

src/include/algorithms/utils.hpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
#ifndef ALGORITHM_UTILS_HPP
22
#define ALGORITHM_UTILS_HPP
33

4-
#include <mscclpp/algorithm.hpp>
54
#include <mscclpp/concurrency_device.hpp>
65
#include <mscclpp/core.hpp>
76
#include <mscclpp/env.hpp>
@@ -17,6 +16,8 @@
1716
#endif
1817

1918
namespace mscclpp {
19+
20+
class AlgorithmBuilder;
2021
namespace algorithm {
2122
constexpr int NUM_NVLS_CONNECTION = 8;
2223
constexpr int NUM_SEMAPHORES = 64;
@@ -60,10 +61,11 @@ std::vector<BaseMemoryChannel> setupBaseMemoryChannels(
6061
std::shared_ptr<DeviceHandle<BaseMemoryChannel>> setupBaseMemoryChannelDeviceHandles(
6162
const std::vector<BaseMemoryChannel>& baseMemoryChannels);
6263

63-
std::vector<std::shared_ptr<AlgorithmBuilder>> loadNativeAlgorithmBuilders();
64+
std::vector<std::shared_ptr<AlgorithmBuilder>> loadNativeAlgorithmBuilders(uintptr_t scratchBuffer,
65+
size_t scratchBufferSize);
6466

6567
std::shared_ptr<AlgorithmBuilder> getDefaultNativeAlgorithmBuilder(std::string algorithmName, uintptr_t scratchBuffer,
66-
size_t scratchBufferSize);
68+
size_t scratchBufferSize);
6769
} // namespace algorithm
6870
} // namespace mscclpp
6971
#endif // ALGORITHM_UTILS_HPP

0 commit comments

Comments
 (0)