Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 82 additions & 28 deletions python/mscclpp/language/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class MemoryChannel:
"""

_channel_counts = defaultdict(int)
_channel_peer_counts = defaultdict(int)

def __init__(self, dst_rank: int, src_rank: int):
"""Initialize a new MemoryChannel.
Expand All @@ -47,6 +48,8 @@ def __init__(self, dst_rank: int, src_rank: int):

self.channel_id = MemoryChannel._channel_counts[src_rank]
MemoryChannel._channel_counts[src_rank] += 1
self.channel_peer_id = MemoryChannel._channel_peer_counts[(src_rank, dst_rank)]
MemoryChannel._channel_peer_counts[(src_rank, dst_rank)] += 1

self.dst_rank = dst_rank
self.src_rank = src_rank
Expand All @@ -71,7 +74,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = F
>>> channel.signal(tb=0, data_sync=SyncType.before)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)

def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
Expand All @@ -92,7 +95,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = Fal
>>> channel.wait(tb=0, data_sync=SyncType.after)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)

def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
Expand Down Expand Up @@ -133,21 +136,29 @@ def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = GetOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
tbg=(
tb_group
if tb_group is not None
else None
),
)
get_program().add_operation(self.src_rank, tb_id, op)
operations.append(op)

if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Send data from local memory to remote memory.
Expand Down Expand Up @@ -192,21 +203,29 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = PutOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
tbg=(
tb_group
if tb_group is not None
else None
),
)
get_program().add_operation(self.src_rank, tb_id, op)
operations.append(op)

if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Transfer data in packet format from local to remote scratch buffer.
Expand Down Expand Up @@ -256,23 +275,31 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, t
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = PutOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
tbg=(
tb_group
if tb_group is not None
else None
),
from_packet=True,
to_packet=True,
)
get_program().add_operation(self.src_rank, tb_id, op)
operations.append(op)

if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Transfer data from local buffer to remote scratch buffer in packet format.
Expand Down Expand Up @@ -320,24 +347,31 @@ def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_gro
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
tb_channel_ids = get_program().setup_channel(tb_id, self)
op = PutOperation(
rank=self.src_rank,
threadblock=tb_id,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
tbg=(
tb_group
if tb_group is not None
else None
),
from_packet=False,
to_packet=True,
)
operations.append(op)

get_program().add_operation(self.src_rank, tb_id, op)
if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)

def reduce(
self,
Expand Down Expand Up @@ -400,6 +434,7 @@ def reduce(
"Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
)

operations = []
for tb_id in tb_list:
remote_chunks = [
RemoteChunk(
Expand All @@ -418,21 +453,27 @@ def reduce(
tb_channel_ids = get_program().setup_channel(tb_id, self)

op = ReduceOperation(
rank=self.src_rank,
threadblock=tb_id,
local_src_buff=[LocalChunk(local_src_chunk.buffer, local_src_chunk.index, local_src_chunk.size)],
local_dst_buff=[LocalChunk(local_dst_chunk.buffer, local_dst_chunk.index, local_dst_chunk.size)],
remote_src_buff=remote_chunks,
remote_dst_buff=[],
channel_ids=tb_channel_ids,
channel_type=self.channel_type,
tbg_info=(
ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
tbg=(
tb_group
if tb_group is not None
else None
),
reduce_operation=reduce_op,
)
operations.append(op)

get_program().add_operation(self.src_rank, tb_id, op)
if tb_group is None:
get_program().add_operation(self.src_rank, tb_id, operations[0])
else:
get_program().add_tbg_operation(operations)


@dataclass
Expand All @@ -452,6 +493,7 @@ class PortChannel:
"""

_channel_counts = defaultdict(int)
_channel_peer_counts = defaultdict(int)

def __init__(self, dst_rank: int, src_rank: int):
"""Initialize a new PortChannel.
Expand All @@ -474,6 +516,8 @@ def __init__(self, dst_rank: int, src_rank: int):

self.channel_id = PortChannel._channel_counts[src_rank]
PortChannel._channel_counts[src_rank] += 1
self.channel_peer_id = PortChannel._channel_peer_counts[(src_rank, dst_rank)]
PortChannel._channel_peer_counts[(src_rank, dst_rank)] += 1

self.dst_rank = dst_rank
self.src_rank = src_rank
Expand All @@ -496,7 +540,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both):
>>> channel.signal(tb=0, data_sync=SyncType.before)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)

def wait(self, tb: int, data_sync: SyncType = SyncType.both):
Expand All @@ -515,7 +559,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both):
>>> channel.wait(tb=0, data_sync=SyncType.after)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)

def flush(self, tb: int, data_sync: SyncType = SyncType.both):
Expand All @@ -534,7 +578,7 @@ def flush(self, tb: int, data_sync: SyncType = SyncType.both):
>>> channel.flush(tb=0, data_sync=SyncType.after)
"""
tb_channel_ids = get_program().setup_channel(tb, self)
op = FlushOperation(tb_channel_ids, self.channel_type, data_sync)
op = FlushOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)

def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
Expand Down Expand Up @@ -573,6 +617,8 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -618,6 +664,8 @@ def put_with_signal(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -663,6 +711,8 @@ def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int)
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -713,6 +763,8 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
tb_channel_ids = get_program().setup_channel(tb, self)

op = PutOperation(
rank=self.src_rank,
threadblock=tb,
src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
channel_ids=tb_channel_ids,
Expand Down Expand Up @@ -836,13 +888,15 @@ def reduce(self, rank, buffer_offset, size, dst_chunk: Chunk, tb, reduce_op=Redu

tb_channel_ids = get_program().setup_channel(tb, self)
op = GroupLoadReduce(
self.buffer_type,
buffer_offset,
size,
dst_chunk,
tb_channel_ids,
self.channel_type,
reduce_op,
rank=self.src_rank,
threadblock=tb,
buffer_type=self.buffer_type,
buffer_offset=buffer_offset,
size=size,
dst_chunk=dst_chunk,
tb_channel_ids=tb_channel_ids,
channel_type=self.channel_type,
reduce_operation=reduce_op,
)
get_program().add_operation(self.src_rank, tb, op)

Expand Down Expand Up @@ -886,7 +940,7 @@ def broadcast(self, rank, src_chunk: Chunk, buffer_offset, size, tb):
)

tb_channel_ids = get_program().setup_channel(tb, self)
op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
op = GroupStore(self.src_rank, tb, src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
get_program().add_operation(self.src_rank, tb, op)

class SwitchChannelRankView:
Expand Down
Loading
Loading