diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py
index b413a7728..453024c17 100644
--- a/python/mscclpp/language/channel.py
+++ b/python/mscclpp/language/channel.py
@@ -25,6 +25,7 @@ class MemoryChannel:
     """
 
     _channel_counts = defaultdict(int)
+    _channel_peer_counts = defaultdict(int)
 
     def __init__(self, dst_rank: int, src_rank: int):
         """Initialize a new MemoryChannel.
@@ -47,6 +48,8 @@ def __init__(self, dst_rank: int, src_rank: int):
 
         self.channel_id = MemoryChannel._channel_counts[src_rank]
         MemoryChannel._channel_counts[src_rank] += 1
+        self.channel_peer_id = MemoryChannel._channel_peer_counts[(src_rank, dst_rank)]
+        MemoryChannel._channel_peer_counts[(src_rank, dst_rank)] += 1
         self.dst_rank = dst_rank
         self.src_rank = src_rank
@@ -71,7 +74,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = F
             >>> channel.signal(tb=0, data_sync=SyncType.before)
         """
         tb_channel_ids = get_program().setup_channel(tb, self)
-        op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
+        op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed)
         get_program().add_operation(self.src_rank, tb, op)
 
     def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
@@ -92,7 +95,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = Fal
             >>> channel.wait(tb=0, data_sync=SyncType.after)
         """
         tb_channel_ids = get_program().setup_channel(tb, self)
-        op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
+        op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync, relaxed)
         get_program().add_operation(self.src_rank, tb, op)
 
     def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
@@ -133,21 +136,29 @@ def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre
                 "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
             )
 
+        operations = []
         for tb_id in tb_list:
             tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
             tb_channel_ids = get_program().setup_channel(tb, self)
             op = GetOperation(
+                rank=self.src_rank,
+                threadblock=tb_id,
                 src_buff=[RemoteChunk(src_chunk.buffer, src_chunk.index, src_chunk.size, tb_chunk_id)],
                 dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
                 channel_ids=tb_channel_ids,
                 channel_type=self.channel_type,
-                tbg_info=(
-                    ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
+                tbg=(
+                    tb_group
                     if tb_group is not None
                     else None
                 ),
             )
-            get_program().add_operation(self.src_rank, tb_id, op)
+            operations.append(op)
+
+        if tb_group is None:
+            get_program().add_operation(self.src_rank, tb_id, operations[0])
+        else:
+            get_program().add_tbg_operation(operations)
 
     def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
         """Send data from local memory to remote memory.
@@ -192,21 +203,29 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: Thre
                 "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
             )
 
+        operations = []
         for tb_id in tb_list:
             tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
             tb_channel_ids = get_program().setup_channel(tb_id, self)
             op = PutOperation(
+                rank=self.src_rank,
+                threadblock=tb_id,
                 src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
                 dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
                 channel_ids=tb_channel_ids,
                 channel_type=self.channel_type,
-                tbg_info=(
-                    ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
+                tbg=(
+                    tb_group
                     if tb_group is not None
                     else None
                 ),
             )
-            get_program().add_operation(self.src_rank, tb_id, op)
+            operations.append(op)
+
+        if tb_group is None:
+            get_program().add_operation(self.src_rank, tb_id, operations[0])
+        else:
+            get_program().add_tbg_operation(operations)
 
     def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
         """Transfer data in packet format from local to remote scratch buffer.
@@ -256,23 +275,31 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, t
                 "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
             )
 
+        operations = []
         for tb_id in tb_list:
             tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
             tb_channel_ids = get_program().setup_channel(tb_id, self)
             op = PutOperation(
+                rank=self.src_rank,
+                threadblock=tb_id,
                 src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
                 dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
                 channel_ids=tb_channel_ids,
                 channel_type=self.channel_type,
-                tbg_info=(
-                    ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
+                tbg=(
+                    tb_group
                     if tb_group is not None
                     else None
                 ),
                 from_packet=True,
                 to_packet=True,
             )
-            get_program().add_operation(self.src_rank, tb_id, op)
+            operations.append(op)
+
+        if tb_group is None:
+            get_program().add_operation(self.src_rank, tb_id, operations[0])
+        else:
+            get_program().add_tbg_operation(operations)
 
     def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
         """Transfer data from local buffer to remote scratch buffer in packet format.
@@ -320,24 +347,31 @@ def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_gro
                 "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
             )
 
+        operations = []
         for tb_id in tb_list:
             tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb_id, remote_chunk, self.channel_type)
             tb_channel_ids = get_program().setup_channel(tb_id, self)
             op = PutOperation(
+                rank=self.src_rank,
+                threadblock=tb_id,
                 src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
                 dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
                 channel_ids=tb_channel_ids,
                 channel_type=self.channel_type,
-                tbg_info=(
-                    ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
+                tbg=(
+                    tb_group
                     if tb_group is not None
                     else None
                 ),
                 from_packet=False,
                 to_packet=True,
            )
+            operations.append(op)
 
-            get_program().add_operation(self.src_rank, tb_id, op)
+        if tb_group is None:
+            get_program().add_operation(self.src_rank, tb_id, operations[0])
+        else:
+            get_program().add_tbg_operation(operations)
 
     def reduce(
         self,
@@ -400,6 +434,7 @@ def reduce(
                 "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
             )
 
+        operations = []
         for tb_id in tb_list:
             remote_chunks = [
                 RemoteChunk(
@@ -418,21 +453,27 @@ def reduce(
 
             tb_channel_ids = get_program().setup_channel(tb_id, self)
             op = ReduceOperation(
+                rank=self.src_rank,
+                threadblock=tb_id,
                 local_src_buff=[LocalChunk(local_src_chunk.buffer, local_src_chunk.index, local_src_chunk.size)],
                 local_dst_buff=[LocalChunk(local_dst_chunk.buffer, local_dst_chunk.index, local_dst_chunk.size)],
                 remote_src_buff=remote_chunks,
                 remote_dst_buff=[],
                 channel_ids=tb_channel_ids,
                 channel_type=self.channel_type,
-                tbg_info=(
-                    ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
+                tbg=(
+                    tb_group
                     if tb_group is not None
                     else None
                 ),
                 reduce_operation=reduce_op,
             )
+            operations.append(op)
 
-            get_program().add_operation(self.src_rank, tb_id, op)
+        if tb_group is None:
+            get_program().add_operation(self.src_rank, tb_id, operations[0])
+        else:
+            get_program().add_tbg_operation(operations)
 
 
 @dataclass
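A note on the pattern introduced above (and repeated in `put`, `read_put_packets`, `put_packets`, `reduce`, and in `rank.py` below): one operation object is now built per participating thread block, and the whole batch is registered in a single call when a `ThreadBlockGroup` is used, so the scheduler can later treat the group as one aggregate DAG node. A minimal sketch of the dispatch logic, with `program` and `make_op` as illustrative stand-ins rather than real mscclpp names:

```python
def dispatch(program, rank, tb, tb_group, make_op):
    # Build one operation per participating thread block.
    tb_list = [tb] if tb_group is None else list(tb_group.tb_list)
    operations = [make_op(tb_id) for tb_id in tb_list]
    if tb_group is None:
        # Single thread block: plain per-threadblock registration.
        program.add_operation(rank, tb_list[0], operations[0])
    else:
        # Grouped thread blocks: register the batch as one aggregate DAG node.
        program.add_tbg_operation(operations)
```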
@@ -452,6 +493,7 @@ class PortChannel:
     """
 
     _channel_counts = defaultdict(int)
+    _channel_peer_counts = defaultdict(int)
 
     def __init__(self, dst_rank: int, src_rank: int):
         """Initialize a new PortChannel.
@@ -474,6 +516,8 @@ def __init__(self, dst_rank: int, src_rank: int):
 
         self.channel_id = PortChannel._channel_counts[src_rank]
         PortChannel._channel_counts[src_rank] += 1
+        self.channel_peer_id = PortChannel._channel_peer_counts[(src_rank, dst_rank)]
+        PortChannel._channel_peer_counts[(src_rank, dst_rank)] += 1
         self.dst_rank = dst_rank
         self.src_rank = src_rank
@@ -496,7 +540,7 @@ def signal(self, tb: int, data_sync: SyncType = SyncType.both):
             >>> channel.signal(tb=0, data_sync=SyncType.before)
         """
         tb_channel_ids = get_program().setup_channel(tb, self)
-        op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
+        op = SignalOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
         get_program().add_operation(self.src_rank, tb, op)
 
     def wait(self, tb: int, data_sync: SyncType = SyncType.both):
@@ -515,7 +559,7 @@ def wait(self, tb: int, data_sync: SyncType = SyncType.both):
             >>> channel.wait(tb=0, data_sync=SyncType.after)
         """
         tb_channel_ids = get_program().setup_channel(tb, self)
-        op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
+        op = WaitOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
         get_program().add_operation(self.src_rank, tb, op)
 
     def flush(self, tb: int, data_sync: SyncType = SyncType.both):
@@ -534,7 +578,7 @@ def flush(self, tb: int, data_sync: SyncType = SyncType.both):
             >>> channel.flush(tb=0, data_sync=SyncType.after)
         """
         tb_channel_ids = get_program().setup_channel(tb, self)
-        op = FlushOperation(tb_channel_ids, self.channel_type, data_sync)
+        op = FlushOperation(self.src_rank, tb, tb_channel_ids, self.channel_type, data_sync)
         get_program().add_operation(self.src_rank, tb, op)
 
     def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
@@ -573,6 +617,8 @@ def put(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
 
         tb_channel_ids = get_program().setup_channel(tb, self)
         op = PutOperation(
+            rank=self.src_rank,
+            threadblock=tb,
             src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
             dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
             channel_ids=tb_channel_ids,
@@ -618,6 +664,8 @@ def put_with_signal(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
 
         tb_channel_ids = get_program().setup_channel(tb, self)
         op = PutOperation(
+            rank=self.src_rank,
+            threadblock=tb,
             src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
             dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
             channel_ids=tb_channel_ids,
@@ -663,6 +711,8 @@ def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int)
 
         tb_channel_ids = get_program().setup_channel(tb, self)
         op = PutOperation(
+            rank=self.src_rank,
+            threadblock=tb,
             src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
             dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
             channel_ids=tb_channel_ids,
@@ -713,6 +763,8 @@ def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
 
         tb_channel_ids = get_program().setup_channel(tb, self)
         op = PutOperation(
+            rank=self.src_rank,
+            threadblock=tb,
             src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
             dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
             channel_ids=tb_channel_ids,
@@ -836,13 +888,15 @@ def reduce(self, rank, buffer_offset, size, dst_chunk: Chunk, tb, reduce_op=Redu
 
         tb_channel_ids = get_program().setup_channel(tb, self)
         op = GroupLoadReduce(
-            self.buffer_type,
-            buffer_offset,
-            size,
-            dst_chunk,
-            tb_channel_ids,
-            self.channel_type,
-            reduce_op,
+            rank=self.src_rank,
+            threadblock=tb,
+            buffer_type=self.buffer_type,
+            buffer_offset=buffer_offset,
+            size=size,
+            dst_chunk=dst_chunk,
+            tb_channel_ids=tb_channel_ids,
+            channel_type=self.channel_type,
+            reduce_operation=reduce_op,
         )
         get_program().add_operation(self.src_rank, tb, op)
 
@@ -886,7 +940,7 @@ def broadcast(self, rank, src_chunk: Chunk, buffer_offset, size, tb):
             )
 
         tb_channel_ids = get_program().setup_channel(tb, self)
-        op = GroupStore(src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
+        op = GroupStore(self.src_rank, tb, src_chunk, self.buffer_type, buffer_offset, size, tb_channel_ids, self.channel_type)
         get_program().add_operation(self.src_rank, tb, op)
 
 class SwitchChannelRankView:
diff --git a/python/mscclpp/language/internal/buffer_access.py b/python/mscclpp/language/internal/buffer_access.py
index ab8a7bdcc..7c5f3d023 100644
--- a/python/mscclpp/language/internal/buffer_access.py
+++ b/python/mscclpp/language/internal/buffer_access.py
@@ -3,92 +3,138 @@
 from sortedcontainers import SortedDict
 from typing import List
-from mscclpp.language.internal.types import BufferType, DataAccessType
+from mscclpp.language.internal.types import *
 from mscclpp.language.internal.operations import *
 from enum import Enum
 
 
 class BuffersAccess:
-    def __init__(self):
-        self.intervals = {
-            BufferType.input: SortedDict(),
-            BufferType.output: SortedDict(),
-            BufferType.scratch: SortedDict(),
-        }
+    def __init__(self, num_ranks):
+        self.rank_intervals = [
+            {
+                BufferType.input: SortedDict(),
+                BufferType.output: SortedDict(),
+                BufferType.scratch: SortedDict(),
+            }
+            for _ in range(num_ranks)
+        ]
+        self.track_sync = {}
+        self.track_barrier = {}
 
     def process_operations(self, operations):
         result_operations = []
-        for operation in operations:
+        for i in range(len(operations)):
+            operation = operations[i]
             if operation.name == Instruction.nop or operation.name == Instruction.barrier:
-                self.clear_data_access()
+                self.track_sync[operation.rank, operation.threadblock] = i
+                if operation.name == Instruction.barrier:
+                    self.update_barrier(operation, i)
+            elif operation.name == Instruction.sem_acquire:
+                self.update_semaphore(operation, i)
             else:
                 if operation.name == Instruction.pipeline:
-                    pipeline_buffer_access = BuffersAccess()
+                    pipeline_buffer_access = BuffersAccess(len(self.rank_intervals))
                     pipeline_result_operations = pipeline_buffer_access.process_operations(operation.operations)
                     operation.operations = pipeline_result_operations
-                data_access = operation.local_data_access()
-                sync_added = False
+                data_access = operation.local_data_access(i)
+                data_access_conflict = DataAccessConflict(operation.rank)
                 for data_access_element in data_access:
-                    if self.compute_data_access(data_access_element) and not sync_added:
-                        result_operations.append(SyncOperation())
-                        sync_added = True
+                    data_access_conflict = data_access_conflict + self.compute_data_access(data_access_element)
+                fix_operations = self.resolve_conflicts(operation.rank, operation.threadblock, i, data_access_conflict)
+                result_operations.extend(fix_operations)
             result_operations.append(operation)
         return result_operations
 
+    def update_barrier(self, operation, order_id):
+        for tb in operation.barrier_info.tb_list:
+            if operation.threadblock != tb:
+                self.track_barrier[operation.rank, operation.threadblock, tb] = order_id
+        self.track_sync[operation.rank, operation.threadblock] = order_id
+
+    def update_semaphore(self, operation, order_id):
+        for tb in operation.tb_sync:
+            if operation.threadblock != tb:
+                self.track_barrier[operation.rank, operation.threadblock, tb] = order_id
 
-    def compute_data_access(self, data_access: DataAccess) -> bool:
-        keys = self.intervals[data_access.buffer_type].keys()
+    def compute_data_access(self, data_access: DataAccess) -> DataAccessConflict:
+        intervals = self.rank_intervals[data_access.rank]
+        keys = intervals[data_access.buffer_type].keys()
         idx = self.lower_bound(0, len(keys) - 1, keys, data_access)
-        conflict = False
+        conflict = DataAccessConflict(data_access.rank)
 
         while len(keys) > 0 and data_access.overlaps(keys[idx]):
             conflict_data_access = keys[idx]
-            conflict_operation_type = self.intervals[data_access.buffer_type][conflict_data_access]
-            if data_access.check_conflict(conflict_data_access):
-                self.clear_data_access()
-                conflict = True
-                break
+            conflict_operation_type = intervals[data_access.buffer_type][conflict_data_access]
+            conflict = conflict + data_access.check_conflict(conflict_data_access)
 
-            self.intervals[data_access.buffer_type].pop(conflict_data_access)
+            intervals[data_access.buffer_type].pop(conflict_data_access)
             if conflict_data_access.end > data_access.end:
-                self.intervals[data_access.buffer_type][
+                intervals[data_access.buffer_type][
                     DataAccess(
-                        conflict_data_access.operation_id,
-                        data_access.end + 1,
+                        conflict_data_access.rank,
+                        conflict_data_access.threadblock,
+                        conflict_data_access.operation_global_id,
+                        conflict_data_access.operation_order_id,
+                        data_access.end,
                         conflict_data_access.end,
                         conflict_data_access.buffer_type,
                         conflict_operation_type,
+                        conflict_data_access.tb_group,
                     )
                 ] = conflict_operation_type
             if conflict_data_access.start < data_access.start:
-                self.intervals[data_access.buffer_type][
+                intervals[data_access.buffer_type][
                     DataAccess(
-                        conflict_data_access.operation_id,
+                        conflict_data_access.rank,
+                        conflict_data_access.threadblock,
+                        conflict_data_access.operation_global_id,
+                        conflict_data_access.operation_order_id,
                         conflict_data_access.start,
-                        data_access.start - 1,
+                        data_access.start,
                         conflict_data_access.buffer_type,
                         conflict_operation_type,
+                        conflict_data_access.tb_group,
                     )
                 ] = conflict_operation_type
 
-            keys = self.intervals[data_access.buffer_type].keys()
+            keys = intervals[data_access.buffer_type].keys()
             idx = self.lower_bound(0, len(keys) - 1, keys, data_access)
 
-        self.intervals[data_access.buffer_type][data_access] = data_access.data_access_type
+        intervals[data_access.buffer_type][data_access] = data_access.data_access_type
         return conflict
 
+    def resolve_conflicts(self, rank, threadblock, order_id, data_access_conflict: DataAccessConflict):
+        fix_operations = []
+        if data_access_conflict.conflict_type == DataAccessConflictType.intra_threadblock:
+            for tb in data_access_conflict.threadblocks:
+                if (rank, threadblock) not in self.track_sync or tb[1] > self.track_sync[(rank, threadblock)]:
+                    fix_operations.append(SyncOperation(rank, threadblock))
+                    self.track_sync[(rank, threadblock)] = order_id
+                    break
+        if data_access_conflict.conflict_type == DataAccessConflictType.inter_threadblock:
+            conflict_tb = set([threadblock])
+            for tb in data_access_conflict.threadblocks:
+                if threadblock != tb[0] and (
+                    (rank, threadblock, tb[0]) not in self.track_barrier
+                    or self.track_barrier[(rank, threadblock, tb[0])] < tb[1]
+                ):
+                    if not tb[2]:
+                        raise RuntimeError("Operations order not defined.")
+                    conflict_tb.add(tb[0])
+            if len(conflict_tb) > 1:
+                for tb in conflict_tb:
+                    op = BarrierOperation(rank, tb, conflict_tb)
+                    self.update_barrier(op, order_id)
+                    fix_operations.append(op)
 
-    def clear_data_access(self):
-        self.intervals[BufferType.input].clear()
-        self.intervals[BufferType.output].clear()
-        self.intervals[BufferType.scratch].clear()
+        return fix_operations
 
     def lower_bound(self, init_pos, final_pos, data_access_list, data_access):
         if init_pos >= final_pos:
             return init_pos
 
         mid_pos = (init_pos + final_pos) // 2
-        if data_access.start <= data_access_list[mid_pos].end:
+        if data_access.lower_overlaps(data_access_list[mid_pos]):
             final_pos = mid_pos
         else:
             init_pos = mid_pos + 1
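The intent of `resolve_conflicts` above, stated compactly: a read/write overlap between two operations on the same thread block is repaired with a single sync (nop), an overlap across thread blocks requires a barrier over every involved thread block, and it is an error if the relative order of the two conflicting operations is not even defined. A toy restatement of that decision table (plain strings instead of the real enum and operation classes):

```python
def pick_fix(conflict_type, threadblocks, order_defined=True):
    # Same-threadblock hazard: one sync instruction is enough.
    if conflict_type == "intra_tb":
        return ["sync"]
    # Cross-threadblock hazard: barrier over all conflicting thread blocks.
    if conflict_type == "inter_tb":
        if not order_defined:
            raise RuntimeError("Operations order not defined.")
        return [("barrier", tb) for tb in sorted(threadblocks)]
    return []

assert pick_fix("intra_tb", {1}) == ["sync"]
assert pick_fix("inter_tb", {0, 1}) == [("barrier", 0), ("barrier", 1)]
```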
+ """ + rank = operation.rank + threadblock = operation.threadblock + node = self.Node(operation) + if agg_node is not None: + agg_node.add_node(node) + + if isinstance(operation, BarrierOperation): + if (rank, threadblock, operation.barrier_id) not in self.tb_barriers: + self.tb_barriers[(rank, threadblock, operation.barrier_id)] = 0 + if (rank, operation.barrier_id) not in self.barrier_nodes: + self.barrier_nodes[(rank, operation.barrier_id)] = [] + + barrier_count = self.tb_barriers[(rank, threadblock, operation.barrier_id)] + if barrier_count > len(self.barrier_nodes[(rank, operation.barrier_id)]): + raise RuntimeError(f"Barrier node not create correctly for rank {rank}, threadblock {threadblock}, barrier_id {operation.barrier_id}.") + elif barrier_count == len(self.barrier_nodes[(rank, operation.barrier_id)]): + agg_node = self.AggregateNode() + self.barrier_nodes[(rank, operation.barrier_id)].append(agg_node) + else: + agg_node = self.barrier_nodes[(rank, operation.barrier_id)][barrier_count] + + self.tb_barriers[(rank, threadblock, operation.barrier_id)] += 1 + agg_node.add_node(node) + node = agg_node + + self.node_list.append(node) + if (rank, threadblock) not in self.last_node: + self.last_node[(rank, threadblock)] = node + if node.get_input() == 0: + self.root_nodes.add(node) + else: + prev_node = self.last_node[(rank, threadblock)] + if prev_node is not node: + prev_node.next_nodes.append(node) + node.previous_nodes.append(prev_node) + node.add_input() + self.last_node[(rank, threadblock)] = node + if node in self.root_nodes: + self.root_nodes.remove(node) + + if isinstance(operation, SignalOperation) or (isinstance(operation, PutOperation) and (operation.with_signal or operation.with_signal_and_flush)): + for tb_channel_id in operation.channel_ids: + channel = ChannelRegister.get_channel(rank, threadblock, tb_channel_id) + op_info = (channel.src_rank, channel.dst_rank, channel.channel_peer_id) + if op_info not in self.waiting or self.waiting[op_info].empty(): + if op_info not in self.signalling: + self.signalling[op_info] = Queue() + self.signalling[op_info].put(node) + else: + waiting_node = self.waiting[op_info].get() + node.next_nodes.append(waiting_node) + waiting_node.previous_nodes.append(node) + waiting_node.add_input() + + if isinstance(operation, WaitOperation): + for tb_channel_id in operation.channel_ids: + channel = ChannelRegister.get_channel(rank, threadblock, tb_channel_id) + op_info = (channel.dst_rank, channel.src_rank, channel.channel_peer_id) + if op_info not in self.signalling or self.signalling[op_info].empty(): + if op_info not in self.waiting: + self.waiting[op_info] = Queue() + self.waiting[op_info].put(node) + else: + signalling_node = self.signalling[op_info].get() + signalling_node.next_nodes.append(node) + node.previous_nodes.append(signalling_node) + node.add_input() + + return node + + def add_tbg_operation(self, operations): + agg_node = self.AggregateNode() + for operation in operations: + self.add_operation(operation, agg_node) + + def add_semaphore_dependency(self): + queue = Queue() + sem_rel = {} + sem_acq = {} + sem_val = {} + + self.reset() + + def compute_sem_op(sem_op, node): + for id in node.operation.semaphore_ids: + if (node.operation.rank, id) not in sem_op: + sem_op[(node.operation.rank, id)] = [] + sem_val[(node.operation.rank, id)] = SemaphoreRegister.get_semaphore(node.operation.rank, id).initial_value + sem_op[(node.operation.rank, id)].append(node) + + def process_node(node): + for next_node in node.next_nodes: + 
+                next_node.add_reach()
+                if next_node.get_reach() == next_node.get_input():
+                    if isinstance(next_node, self.Node) and next_node.agg_node is not None:
+                        for sub_node in next_node.agg_node.nodes:
+                            queue.put(sub_node)
+                    else:
+                        queue.put(next_node)
+
+        for node in self.root_nodes:
+            queue.put(node)
+
+        while True:
+            while not queue.empty():
+                node = queue.get()
+                if isinstance(node, self.Node) and isinstance(node.operation, SemaphoreReleaseOperation):
+                    compute_sem_op(sem_rel, node)
+                elif isinstance(node, self.Node) and isinstance(node.operation, SemaphoreAcquireOperation):
+                    compute_sem_op(sem_acq, node)
+                else:
+                    process_node(node)
+
+            if not sem_rel and not sem_acq:
+                break
+            else:
+                removed_keys = []
+                for key in sem_acq.keys():
+                    if key in sem_rel:
+                        if len(sem_acq[key]) > 1 or sem_val[key] != len(sem_rel[key]) - len(sem_acq[key]):
+                            raise RuntimeError(f"Undefined behaviour for semaphore id {key[1]}.")
+                        else:
+                            sem_acq_node = sem_acq[key][0]
+                            sem_val[key] = 0
+                            if sem_acq_node in self.root_nodes:
+                                self.root_nodes.remove(sem_acq_node)
+                            process_node(sem_acq_node)
+                            for sem_rel_node in sem_rel[key]:
+                                process_node(sem_rel_node)
+                                sem_rel_node.next_nodes.append(sem_acq_node)
+                                sem_acq_node.operation.add_tb_sync(sem_rel_node.operation.threadblock)
+                                sem_acq_node.previous_nodes.append(sem_rel_node)
+                                sem_acq_node.add_input()
+
+                            removed_keys.append(key)
+
+                for key in removed_keys:
+                    sem_rel.pop(key)
+                    sem_acq.pop(key)
+
+                if len(removed_keys) == 0 and (len(sem_rel) > 0 or len(sem_acq) > 0):
+                    raise RuntimeError("Undefined semaphore behaviour.")
+
+    def reset(self):
+        for node in self.node_list:
+            node.reset()
+
+    def print(self):
+        """
+        Prints the traversal order of operations in the DAG, for debugging.
+        """
+        self.reset()
+        self.check()
+
+        queue = Queue()
+        for node in self.root_nodes:
+            queue.put(node)
+
+        while not queue.empty():
+            node = queue.get()
+            print(f"node {node.print()}")
+            for next_node in node.next_nodes:
+                next_node.add_reach()
+                print(f"next_node {next_node.print()}")
+                if next_node.get_reach() == next_node.get_input():
+                    if isinstance(next_node, self.Node) and next_node.agg_node is not None:
+                        for sub_node in next_node.agg_node.nodes:
+                            queue.put(sub_node)
+                    else:
+                        queue.put(next_node)
+            print()
+
+    def check(self):
+        """
+        Validates the DAG structure, ensuring all nodes are reachable and dependencies are correctly set.
+        """
+        if len(self.signalling) > 0:
+            for key, queue in self.signalling.items():
+                if not queue.empty():
+                    raise RuntimeError(f"Signal from rank {key[0]} to rank {key[1]} on channel {key[2]} has no matching wait operation.")
+        if len(self.waiting) > 0:
+            for key, queue in self.waiting.items():
+                if not queue.empty():
+                    raise RuntimeError(f"Wait on rank {key[1]} for rank {key[0]} on channel {key[2]} has no matching signal operation.")
+
+    def get_execution_order(self):
+        """
+        Returns the order of operations in the DAG.
+ """ + self.reset() + self.check() + + order = [] + queue = Queue() + for node in self.root_nodes: + queue.put(node) + + while not queue.empty(): + node = queue.get() + order.extend(node.get_operations()) + for next_node in node.next_nodes: + next_node.add_reach() + if next_node.get_reach() == next_node.get_input(): + if isinstance(next_node, self.Node) and next_node.agg_node is not None: + for sub_node in next_node.agg_node.nodes: + queue.put(sub_node) + else: + queue.put(next_node) + + return order + + class BaseNode(): + def __init__(self): + self.previous_nodes = [] + self.next_nodes = [] + self.input = 0 + self.reach = 0 + + def add_input(self): + self.input += 1 + + def add_reach(self): + self.reach += 1 + + def get_input(self): + return self.input + + def get_reach(self): + return self.reach + + def reset(self): + self.reach = 0 + + class Node(BaseNode): + def __init__(self, operation): + self.operation = operation + self.agg_node = None + super().__init__() + + def get_operations(self): + return [self.operation] + + def add_input(self): + if self.agg_node is not None: + self.agg_node.input += 1 + else: + self.input += 1 + + def add_reach(self): + if self.agg_node is not None: + self.agg_node.reach += 1 + else: + self.reach += 1 + + def get_input(self): + if self.agg_node is not None: + return self.agg_node.input + else: + return self.input + + def get_reach(self): + if self.agg_node is not None: + return self.agg_node.reach + else: + return self.reach + + def reset(self): + if self.agg_node is not None: + self.agg_node.reset() + else: + self.reach = 0 + + def print(self): + return f"rank {self.operation.rank} tb {self.operation.threadblock} {self.operation.name}" + + + class AggregateNode(BaseNode): + def __init__(self): + self.nodes = [] + super().__init__() + + def add_node(self, node): + self.nodes.append(node) + node.agg_node = self + + def get_operations(self): + operations = [] + for node in self.nodes: + operations.append(node) + return operations + + def print(self): + return f"rank {self.operations[0].rank} tb {self.operations[0].threadblock} {self.operations[0].name}" \ No newline at end of file diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py index 1083c45bd..c7b731374 100644 --- a/python/mscclpp/language/internal/operations.py +++ b/python/mscclpp/language/internal/operations.py @@ -11,6 +11,7 @@ DataAccess, DataAccessType, ) +from mscclpp.language.thread_block_group import ThreadBlockGroup from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import List @@ -32,6 +33,8 @@ class BaseOperation(ABC): """ id: uuid.UUID = field(default_factory=uuid.uuid4, init=False) + rank: int + threadblock: int name: str def local_data_access(self, sync_purpose=True): @@ -112,18 +115,9 @@ def to_dict(self): return {"buffer_id": self.buffer_id, "index": self.index, "size": self.size} -@dataclass -class ThreadBlockGroupInfo: - tb_id: int - tbg_size: int - - def to_dict(self): - return {"tb_id": self.tb_id, "tbg_size": self.tbg_size} - - class SyncOperation(BaseOperation): - def __init__(self): - super().__init__(Instruction.nop) + def __init__(self, rank: int, threadblock: int): + super().__init__(rank, threadblock, Instruction.nop) def __add__(self, other): fused_operation = None @@ -146,36 +140,58 @@ def to_dict(self): class CopyOperation(BaseOperation): def __init__( self, + rank, + threadblock: int, src_buff: List[LocalChunk], dst_buff: List[LocalChunk], - tbg_info: ThreadBlockGroupInfo = 
diff --git a/python/mscclpp/language/internal/operations.py b/python/mscclpp/language/internal/operations.py
index 1083c45bd..c7b731374 100644
--- a/python/mscclpp/language/internal/operations.py
+++ b/python/mscclpp/language/internal/operations.py
@@ -11,6 +11,7 @@
     DataAccess,
     DataAccessType,
 )
+from mscclpp.language.thread_block_group import ThreadBlockGroup
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import List
@@ -32,6 +33,8 @@ class BaseOperation(ABC):
     """
 
     id: uuid.UUID = field(default_factory=uuid.uuid4, init=False)
+    rank: int
+    threadblock: int
     name: str
 
-    def local_data_access(self, sync_purpose=True):
+    def local_data_access(self, order_id, sync_purpose=True):
@@ -112,18 +115,9 @@ def to_dict(self):
         return {"buffer_id": self.buffer_id, "index": self.index, "size": self.size}
 
 
-@dataclass
-class ThreadBlockGroupInfo:
-    tb_id: int
-    tbg_size: int
-
-    def to_dict(self):
-        return {"tb_id": self.tb_id, "tbg_size": self.tbg_size}
-
-
 class SyncOperation(BaseOperation):
-    def __init__(self):
-        super().__init__(Instruction.nop)
+    def __init__(self, rank: int, threadblock: int):
+        super().__init__(rank, threadblock, Instruction.nop)
 
     def __add__(self, other):
         fused_operation = None
@@ -146,36 +140,58 @@ def to_dict(self):
 class CopyOperation(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         src_buff: List[LocalChunk],
         dst_buff: List[LocalChunk],
-        tbg_info: ThreadBlockGroupInfo = None,
+        tbg: ThreadBlockGroup = None,
        from_packet: bool = False,
        to_packet: bool = False,
     ):
         if from_packet and to_packet:
             raise RuntimeError(f"Copy Operation from Packet to Packet is not Supported.")
         elif from_packet:
-            super().__init__(Instruction.unpack_packet)
+            super().__init__(rank, threadblock, Instruction.copy_packet)
         elif to_packet:
-            super().__init__(Instruction.copy_packet)
+            super().__init__(rank, threadblock, Instruction.transform_to_packet)
         else:
-            super().__init__(Instruction.copy)
+            super().__init__(rank, threadblock, Instruction.copy)
 
         self.src_buff = src_buff
         self.dst_buff = dst_buff
-        self.tbg_info = tbg_info
+        self.tbg = tbg
 
-    def local_data_access(self, sync_purpose=True):
+    def local_data_access(self, order_id, sync_purpose=True):
         data_access = []
-        if self.name != Instruction.unpack_packet or not sync_purpose:
+        if self.name != Instruction.copy_packet or not sync_purpose:
             for chunk in self.src_buff:
                 data_access.append(
-                    DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read)
+                    DataAccess(
+                        self.rank,
+                        self.threadblock,
+                        self.id,
+                        order_id,
+                        chunk.index + (self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0),
+                        chunk.index + (self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size),
+                        chunk.type,
+                        DataAccessType.read,
+                        self.tbg,
+                    )
                 )
-        if self.name != Instruction.copy_packet or not sync_purpose:
+        if self.name != Instruction.transform_to_packet or not sync_purpose:
             for chunk in self.dst_buff:
                 data_access.append(
-                    DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write)
+                    DataAccess(
+                        self.rank,
+                        self.threadblock,
+                        self.id,
+                        order_id,
+                        chunk.index + (self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0),
+                        chunk.index + (self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size),
+                        chunk.type,
+                        DataAccessType.write,
+                        self.tbg,
+                    )
                 )
         return data_access
@@ -193,16 +209,20 @@ def to_dict(self):
         result["dst_buff"] = []
         for chunk in self.dst_buff:
             result["dst_buff"].append(chunk.to_dict())
-        if self.tbg_info is not None:
-            result["tbg_info"] = self.tbg_info.to_dict()
+        if self.tbg is not None:
+            result["tbg"] = self.tbg.to_dict(self.threadblock)
         return result
 
 
 class SemaphoreAcquireOperation(BaseOperation):
-    def __init__(self, semaphore_ids: List[int], data_sync: SyncType = SyncType.none):
-        super().__init__(Instruction.sem_acquire)
+    def __init__(self, rank: int, threadblock: int, semaphore_ids: List[int], data_sync: SyncType = SyncType.none):
+        super().__init__(rank, threadblock, Instruction.sem_acquire)
         self.semaphore_ids = semaphore_ids
         self.data_sync = data_sync
+        self.tb_sync = set()
+
+    def add_tb_sync(self, tb):
+        self.tb_sync.add(tb)
 
     def shift_ids(self, instance, num_instances, replication_function):
         for i in range(len(self.semaphore_ids)):
@@ -232,8 +252,8 @@ def to_dict(self):
 
 class SemaphoreReleaseOperation(BaseOperation):
-    def __init__(self, semaphore_ids: List[int], data_sync: SyncType = SyncType.none):
-        super().__init__(Instruction.sem_release)
+    def __init__(self, rank: int, threadblock: int, semaphore_ids: List[int], data_sync: SyncType = SyncType.none):
+        super().__init__(rank, threadblock, Instruction.sem_release)
         self.semaphore_ids = semaphore_ids
         self.data_sync = data_sync
@@ -267,15 +287,17 @@ def to_dict(self):
 class SignalOperation(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         channels_ids: List[int],
         channel_type: ChannelType,
         data_sync: SyncType = SyncType.none,
         relaxed: bool = False,
     ):
         if relaxed:
-            super().__init__(Instruction.relaxed_signal)
+            super().__init__(rank, threadblock, Instruction.relaxed_signal)
         else:
-            super().__init__(Instruction.signal)
+            super().__init__(rank, threadblock, Instruction.signal)
         self.channel_ids = set(channels_ids)
         self.channel_type = channel_type
         self.data_sync = data_sync
@@ -314,15 +336,17 @@ def to_dict(self):
 class WaitOperation(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         channels_ids: List[int],
         channel_type: ChannelType,
         data_sync: SyncType = SyncType.none,
         relaxed: bool = False,
     ):
         if relaxed:
-            super().__init__(Instruction.relaxed_wait)
+            super().__init__(rank, threadblock, Instruction.relaxed_wait)
         else:
-            super().__init__(Instruction.wait)
+            super().__init__(rank, threadblock, Instruction.wait)
         self.channel_ids = set(channels_ids)
         self.channel_type = channel_type
         self.data_sync = data_sync
@@ -361,7 +385,7 @@ def to_dict(self):
 class BarrierOperation(BaseOperation):
     __current_barriers = []
 
-    def __init__(self, rank: int, tb_list: List[int]):
+    def __init__(self, rank: int, threadblock: int, tb_list: List[int]):
         for _ in range(len(BarrierOperation.__current_barriers), rank + 1):
             BarrierOperation.__current_barriers.append({})
         barrier_info = BarrierOperation.BarrierInfo(tb_list)
@@ -372,7 +396,7 @@ def __init__(self, rank: int, tb_list: List[int]):
         else:
             self.barrier_id = BarrierOperation.__current_barriers[rank][barrier_info]
 
-        super().__init__(Instruction.barrier)
+        super().__init__(rank, threadblock, Instruction.barrier)
         self.barrier_info = barrier_info
 
     def shift_ids(self, instance, num_instances, replication_function):
@@ -406,8 +430,15 @@ def __hash__(self):
 
 class FlushOperation(BaseOperation):
-    def __init__(self, channels_ids: List[int], channel_type: ChannelType, data_sync: SyncType = SyncType.none):
-        super().__init__(Instruction.flush)
+    def __init__(
+        self,
+        rank: int,
+        threadblock: int,
+        channels_ids: List[int],
+        channel_type: ChannelType,
+        data_sync: SyncType = SyncType.none,
+    ):
+        super().__init__(rank, threadblock, Instruction.flush)
         self.channel_ids = set(channels_ids)
         self.channel_type = channel_type
         self.data_sync = data_sync
@@ -440,24 +471,36 @@ def to_dict(self):
 class GetOperation(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         src_buff: List[RemoteChunk],
         dst_buff: List[LocalChunk],
         channel_ids: List[int],
         channel_type: ChannelType,
-        tbg_info: ThreadBlockGroupInfo = None,
+        tbg: ThreadBlockGroup = None,
     ):
-        super().__init__(Instruction.get)
+        super().__init__(rank, threadblock, Instruction.get)
         self.src_buff = src_buff
         self.dst_buff = dst_buff
         self.channel_ids = channel_ids
         self.channel_type = channel_type
-        self.tbg_info = tbg_info
+        self.tbg = tbg
 
-    def local_data_access(self, sync_purpose=True):
+    def local_data_access(self, order_id, sync_purpose=True):
         data_access = []
         for chunk in self.dst_buff:
             data_access.append(
-                DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write)
+                DataAccess(
+                    self.rank,
+                    self.threadblock,
+                    self.id,
+                    order_id,
+                    chunk.index + (self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0),
+                    chunk.index + (self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size),
+                    chunk.type,
+                    DataAccessType.write,
+                    self.tbg,
+                )
             )
         return data_access
@@ -473,14 +516,14 @@ def __add__(self, other):
             isinstance(other, GetOperation)
             and self.src_buff[0].size == other.src_buff[0].size
             and self.channel_type == other.channel_type
-            and self.tbg_info == other.tbg_info
+            and self.tbg == other.tbg
         ):
             fused_operation = GetOperation(
+                rank=self.rank,
+                threadblock=self.threadblock,
                 src_buff=self.src_buff + other.src_buff,
                 dst_buff=self.dst_buff + other.dst_buff,
                 channel_ids=self.channel_ids + other.channel_ids,
                 channel_type=self.channel_type,
-                tbg_info=self.tbg_info,
+                tbg=self.tbg,
             )
 
         return fused_operation
@@ -495,40 +538,42 @@ def to_dict(self):
             result["dst_buff"].append(chunk.to_dict())
         result["channel_ids"] = self.channel_ids
         result["channel_type"] = self.channel_type.value
-        if self.tbg_info is not None:
-            result["tbg_info"] = self.tbg_info.to_dict()
+        if self.tbg is not None:
+            result["tbg"] = self.tbg.to_dict(self.threadblock)
         return result
 
 
 class PutOperation(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         src_buff: List[LocalChunk],
         dst_buff: List[RemoteChunk],
         channel_ids: List[int],
         channel_type: ChannelType,
-        tbg_info: ThreadBlockGroupInfo = None,
+        tbg: ThreadBlockGroup = None,
         from_packet: bool = False,
         to_packet: bool = False,
         with_signal: bool = False,
         with_signal_and_flush: bool = False,
     ):
         if from_packet and to_packet:
-            super().__init__(Instruction.read_put_packet)
+            super().__init__(rank, threadblock, Instruction.read_put_packet)
         elif to_packet:
-            super().__init__(Instruction.put_packet)
+            super().__init__(rank, threadblock, Instruction.put_packet)
         elif from_packet:
             raise RuntimeError(f"Put Operation from Packet is not Supported.")
         else:
             if with_signal:
                 if with_signal_and_flush:
-                    super().__init__(Instruction.put_with_signal_and_flush)
+                    super().__init__(rank, threadblock, Instruction.put_with_signal_and_flush)
                 else:
-                    super().__init__(Instruction.put_with_signal)
+                    super().__init__(rank, threadblock, Instruction.put_with_signal)
             elif with_signal_and_flush:
-                super().__init__(Instruction.put_with_signal_and_flush)
+                super().__init__(rank, threadblock, Instruction.put_with_signal_and_flush)
             else:
-                super().__init__(Instruction.put)
+                super().__init__(rank, threadblock, Instruction.put)
 
         self.src_buff = src_buff
         self.dst_buff = dst_buff
@@ -537,14 +582,24 @@ def __init__(
         self.to_packet = to_packet
         self.with_signal = with_signal
         self.with_signal_and_flush = with_signal_and_flush
-        self.tbg_info = tbg_info
+        self.tbg = tbg
 
-    def local_data_access(self, sync_purpose=True):
+    def local_data_access(self, order_id, sync_purpose=True):
         data_access = []
         if self.name != Instruction.read_put_packet or not sync_purpose:
             for chunk in self.src_buff:
                 data_access.append(
-                    DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read)
+                    DataAccess(
+                        self.rank,
+                        self.threadblock,
+                        self.id,
+                        order_id,
+                        chunk.index + (self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0),
+                        chunk.index + (self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size),
+                        chunk.type,
+                        DataAccessType.read,
+                        self.tbg,
+                    )
                 )
         return data_access
@@ -567,14 +622,14 @@ def __add__(self, other):
             and self.name == other.name
             and self.src_buff[0].size == other.src_buff[0].size
             and self.channel_type == other.channel_type
-            and self.tbg_info == other.tbg_info
+            and self.tbg == other.tbg
         ):
             fused_operation = PutOperation(
+                rank=self.rank,
+                threadblock=self.threadblock,
                 src_buff=self.src_buff + other.src_buff,
                 dst_buff=self.dst_buff + other.dst_buff,
                 channel_ids=self.channel_ids + other.channel_ids,
                 channel_type=self.channel_type,
-                tbg_info=self.tbg_info,
+                tbg=self.tbg,
                 to_packet=self.to_packet,
                 with_signal=self.with_signal,
                 with_signal_and_flush=self.with_signal_and_flush,
@@ -593,8 +648,8 @@ def to_dict(self):
         if self.channel_type == ChannelType.port:
             result["channel_ids"] = self.channel_ids
             result["channel_type"] = self.channel_type.value
-        if self.tbg_info is not None:
-            result["tbg_info"] = self.tbg_info.to_dict()
+        if self.tbg is not None:
+            result["tbg"] = self.tbg.to_dict(self.threadblock)
         return result
 
 
@@ -602,6 +657,8 @@ def to_dict(self):
 class ReduceOperation(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         local_src_buff: List[LocalChunk],
         local_dst_buff: List[LocalChunk],
         remote_src_buff: List[RemoteChunk] = None,
@@ -610,7 +667,7 @@ def __init__(
         put_channel_ids: List[int] = None,
         channel_type: ChannelType = ChannelType.none,
         reduce_operation: ReduceOperationType = ReduceOperationType.sum,
-        tbg_info: ThreadBlockGroupInfo = None,
+        tbg: ThreadBlockGroup = None,
         packet: bool = False,
     ):
         remote_src_buff = remote_src_buff if remote_src_buff is not None else []
@@ -620,18 +677,18 @@ def __init__(
 
         if len(remote_src_buff) == 0 and len(remote_dst_buff) == 0:
             if packet:
-                super().__init__(Instruction.reduce_packet)
+                super().__init__(rank, threadblock, Instruction.reduce_packet)
             else:
-                super().__init__(Instruction.reduce)
+                super().__init__(rank, threadblock, Instruction.reduce)
         elif len(remote_src_buff) == 0:
             if packet:
-                super().__init__(Instruction.reduce_send_packet)
+                super().__init__(rank, threadblock, Instruction.reduce_send_packet)
             else:
-                super().__init__(Instruction.reduce_send)
+                super().__init__(rank, threadblock, Instruction.reduce_send)
         elif len(remote_dst_buff) == 0 and not packet:
-            super().__init__(Instruction.read_reduce)
+            super().__init__(rank, threadblock, Instruction.read_reduce)
         elif not packet:
-            super().__init__(Instruction.read_reduce_send)
+            super().__init__(rank, threadblock, Instruction.read_reduce_send)
         else:
             raise RuntimeError(f"Reduce Operation invalid parameters.")
@@ -643,20 +700,40 @@ def __init__(
         self.put_channel_ids = put_channel_ids
         self.channel_type = channel_type
         self.reduce_operation = reduce_operation
-        self.tbg_info = tbg_info
+        self.tbg = tbg
         self.packet = packet
 
-    def local_data_access(self, sync_purpose=True):
+    def local_data_access(self, order_id, sync_purpose=True):
         data_access = []
         for i in range(len(self.local_src_buff)):
             chunk = self.local_src_buff[i]
             if not self.packet or i != 0 or not sync_purpose:
                 data_access.append(
-                    DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.read)
+                    DataAccess(
+                        self.rank,
+                        self.threadblock,
+                        self.id,
+                        order_id,
+                        chunk.index + (self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0),
+                        chunk.index + (self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size),
+                        chunk.type,
+                        DataAccessType.read,
+                        self.tbg,
+                    )
                 )
         for chunk in self.local_dst_buff:
             data_access.append(
-                DataAccess(self.id, chunk.index, chunk.index + chunk.size - 1, chunk.type, DataAccessType.write)
+                DataAccess(
+                    self.rank,
+                    self.threadblock,
+                    self.id,
+                    order_id,
+                    chunk.index + (self.tbg.start_offset(self.threadblock, chunk.size) if self.tbg is not None else 0),
+                    chunk.index + (self.tbg.end_offset(self.threadblock, chunk.size) if self.tbg is not None else chunk.size),
+                    chunk.type,
+                    DataAccessType.write,
+                    self.tbg,
+                )
            )
         return data_access
@@ -684,7 +761,7 @@ def __add__(self, other):
             and self.local_dst_buff == other.local_dst_buff
             and self.channel_type == other.channel_type
             and self.reduce_operation == other.reduce_operation
-            and self.tbg_info == other.tbg_info
+            and self.tbg == other.tbg
         ):
             fused_operation = ReduceOperation(
+                self.rank,
+                self.threadblock,
                 self.local_src_buff + other.local_src_buff[1:],
@@ -693,7 +770,7 @@ def __add__(self, other):
                 channel_ids=self.channel_ids + other.channel_ids,
                 channel_type=self.channel_type,
                 reduce_operation=self.reduce_operation,
-                tbg_info=self.tbg_info,
+                tbg=self.tbg,
                 packet=self.packet,
             )
         if (
@@ -707,7 +784,7 @@ def __add__(self, other):
             and other.name == Instruction.put
             and self.local_dst_buff[0] == other.src_buff[0]
             and other.channel_type == ChannelType.memory
-            and self.tbg_info == other.tbg_info
+            and self.tbg == other.tbg
         ):
             fused_operation = ReduceOperation(
+                self.rank,
+                self.threadblock,
                 self.local_src_buff,
@@ -718,7 +795,7 @@ def __add__(self, other):
                 put_channel_ids=self.put_channel_ids + other.channel_ids,
                 channel_type=self.channel_type,
                 reduce_operation=self.reduce_operation,
-                tbg_info=self.tbg_info,
+                tbg=self.tbg,
                 packet=self.packet,
             )
         if (
@@ -727,7 +804,7 @@ def __add__(self, other):
             and other.name == Instruction.put_packet
             and self.local_dst_buff[0] == other.src_buff[0]
             and other.channel_type == ChannelType.memory
-            and self.tbg_info == other.tbg_info
+            and self.tbg == other.tbg
         ):
             fused_operation = ReduceOperation(
+                self.rank,
+                self.threadblock,
                 self.local_src_buff,
@@ -738,7 +815,7 @@ def __add__(self, other):
                 put_channel_ids=self.put_channel_ids + other.channel_ids,
                 channel_type=other.channel_type,
                 reduce_operation=self.reduce_operation,
-                tbg_info=self.tbg_info,
+                tbg=self.tbg,
                 packet=self.packet,
             )
@@ -763,8 +840,8 @@ def to_dict(self):
         if self.channel_type != ChannelType.none:
             result["channel_type"] = self.channel_type.value
         result["reduce_op"] = self.reduce_operation.value
-        if self.tbg_info is not None:
-            result["tbg_info"] = self.tbg_info.to_dict()
+        if self.tbg is not None:
+            result["tbg"] = self.tbg.to_dict(self.threadblock)
         return result
 
 
@@ -772,6 +849,8 @@ def to_dict(self):
 class GroupLoadReduce(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         buffer_type: BufferType,
         buffer_offset: int,
         size: int,
@@ -780,7 +859,7 @@ def __init__(
         channel_type: ChannelType = ChannelType.switch,
         reduce_operation: ReduceOperationType = ReduceOperationType.sum,
     ):
-        super().__init__(Instruction.group_load_reduce)
+        super().__init__(rank, threadblock, Instruction.group_load_reduce)
         self.buffer_type = buffer_type
         self.buffer_offset = buffer_offset
         self.size = size
@@ -831,6 +910,8 @@ def to_dict(self):
 class GroupStore(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         src_chunk: Chunk,
         buffer_type: BufferType,
         buffer_offset: int,
@@ -838,7 +919,7 @@ def __init__(
         channel_ids: List[int],
         channel_type: ChannelType = ChannelType.switch,
     ):
-        super().__init__(Instruction.group_store)
+        super().__init__(rank, threadblock, Instruction.group_store)
         self.src_chunk = src_chunk
         self.buffer_type = buffer_type
         self.buffer_offset = buffer_offset
@@ -865,6 +946,8 @@ def to_dict(self):
 class GroupLoadReduceStore(BaseOperation):
     def __init__(
         self,
+        rank: int,
+        threadblock: int,
         buffer_type: BufferType,
         size: int,
         src_index: List[int],
@@ -873,7 +956,7 @@ def __init__(
         channel_type: ChannelType = ChannelType.switch,
         reduce_operation: ReduceOperationType = ReduceOperationType.sum,
     ):
-        super().__init__(Instruction.group_load_reduce_store)
+        super().__init__(rank, threadblock, Instruction.group_load_reduce_store)
         self.buffer_type = buffer_type
         self.size = size
         self.src_index = src_index
@@ -907,8 +990,8 @@ def to_dict(self):
 
 @dataclass
 class PipelineOperation(BaseOperation):
-    def __init__(self, unit_size: int, num_chunks: int, operations=None):
-        super().__init__(Instruction.pipeline)
+    def __init__(self, rank: int, threadblock: int, unit_size: int, num_chunks: int, operations=None):
+        super().__init__(rank, threadblock, Instruction.pipeline)
         self.unit_size = unit_size
         self.num_chunks = num_chunks
         self.operations = operations if operations is not None else []
diff --git a/python/mscclpp/language/internal/register.py b/python/mscclpp/language/internal/register.py
new file mode 100644
index 000000000..ed0737d9a
--- /dev/null
+++ b/python/mscclpp/language/internal/register.py
@@ -0,0 +1,22 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+class ChannelRegister:
+    channels = {}
+
+    @staticmethod
+    def add_channel(rank, tb, tb_channel_id, channel):
+        ChannelRegister.channels[(rank, tb, tb_channel_id)] = channel
+
+    @staticmethod
+    def get_channel(rank: int, threadblock: int, tb_channel_id: int):
+        return ChannelRegister.channels.get((rank, threadblock, tb_channel_id))
+
+
+class SemaphoreRegister:
+    semaphores = {}
+
+    @staticmethod
+    def add_semaphore(semaphore):
+        SemaphoreRegister.semaphores[(semaphore.rank, semaphore.id)] = semaphore
+
+    @staticmethod
+    def get_semaphore(rank: int, semaphore_id: int):
+        return SemaphoreRegister.semaphores.get((rank, semaphore_id))
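Why the registers exist: `add_operation` in op_dep_graph.py must map a thread-block-local channel id back to the channel object so the signal side and the wait side of a pair derive the same lookup key. With the per-peer `channel_peer_id` counter added in channel.py, both sides agree on `(signaling rank, waiting rank, peer id)`; a toy check, with plain dicts standing in for registered channel objects:

```python
# Channel as seen by the signaling rank 0 and by the waiting rank 1.
signal_channel = {"src_rank": 0, "dst_rank": 1, "channel_peer_id": 0}
wait_channel = {"src_rank": 1, "dst_rank": 0, "channel_peer_id": 0}

signal_key = (signal_channel["src_rank"], signal_channel["dst_rank"], signal_channel["channel_peer_id"])
# The wait side swaps src/dst so both sides compute the identical key.
wait_key = (wait_channel["dst_rank"], wait_channel["src_rank"], wait_channel["channel_peer_id"])
assert signal_key == wait_key == (0, 1, 0)
```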
diff --git a/python/mscclpp/language/internal/types.py b/python/mscclpp/language/internal/types.py
index 411a46959..19000fdd3 100644
--- a/python/mscclpp/language/internal/types.py
+++ b/python/mscclpp/language/internal/types.py
@@ -1,11 +1,12 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-from dataclasses import dataclass
+from mscclpp.language.thread_block_group import ThreadBlockGroup
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import List, Set
 from collections import defaultdict
-
+import uuid
 
 class SyncType(Enum):
     none = "none"
@@ -172,33 +173,77 @@ def __str__(self):
 
 @dataclass
 class DataAccess:
-    operation_id: int
-    start: int
-    end: int
+    rank: int
+    threadblock: int
+    operation_global_id: uuid.UUID
+    operation_order_id: int
+    start: float
+    end: float
     buffer_type: BufferType
     data_access_type: DataAccessType
+    tb_group: ThreadBlockGroup = None
 
     def __lt__(self, other):
         if self.start != other.start:
             return self.start < other.start
         return self.end < other.end
 
-    def __eq__(self, other):
-        return self.start == other.start and self.end == other.end
+    def __eq__(self, other, tolerance=1e-5):
+        return abs(self.start - other.start) < tolerance and abs(self.end - other.end) < tolerance
 
     def __hash__(self):
         return hash((self.start, self.end))
 
-    def overlaps(self, other) -> bool:
-        return self.start <= other.end and other.start <= self.end
+    def lower_overlaps(self, other, tolerance=1e-5) -> bool:
+        return self.start + tolerance < other.end
+
+    def overlaps(self, other, tolerance=1e-5) -> bool:
+        return (self.start + tolerance < other.end) and (other.start + tolerance < self.end)
 
-    def check_conflict(self, other) -> bool:
-        return (
+    def check_conflict(self, other) -> "DataAccessConflict":
+        if (
             self.overlaps(other)
-            and self.operation_id != other.operation_id
+            and self.operation_global_id != other.operation_global_id
             and (self.data_access_type != DataAccessType.read or other.data_access_type != DataAccessType.read)
-        )
+        ):
+            if self.threadblock == other.threadblock:
+                return DataAccessConflict(
+                    self.rank,
+                    {(other.threadblock, other.operation_order_id, True)},
+                    DataAccessConflictType.intra_threadblock,
+                )
+            else:
+                is_order_defined = (
+                    (self.tb_group is not None and other.tb_group is not None and self.tb_group.tbg_overlap(other.tb_group))
+                    or (self.tb_group is not None and other.tb_group is None and self.tb_group.tb_overlap(other.threadblock))
+                    or (self.tb_group is None and other.tb_group is not None and other.tb_group.tb_overlap(self.threadblock))
+                )
+                return DataAccessConflict(
+                    self.rank,
+                    {(self.threadblock, other.operation_order_id, True), (other.threadblock, other.operation_order_id, is_order_defined)},
+                    DataAccessConflictType.inter_threadblock,
+                )
+        else:
+            return DataAccessConflict(self.rank)
+
+
+class DataAccessConflictType(Enum):
+    inter_threadblock = "inter_tb"
+    intra_threadblock = "intra_tb"
+    none = "none"
+
+    def __add__(self, other):
+        if not isinstance(other, DataAccessConflictType):
+            return NotImplemented
+
+        map_to_num = {
+            DataAccessConflictType.none: 0,
+            DataAccessConflictType.intra_threadblock: 1,
+            DataAccessConflictType.inter_threadblock: 3,
+        }
+        map_to_dact = {
+            0: DataAccessConflictType.none,
+            1: DataAccessConflictType.intra_threadblock,
+            3: DataAccessConflictType.inter_threadblock,
+        }
+        return map_to_dact[map_to_num[self] | map_to_num[other]]
+
+    def __str__(self):
+        return self.value
+
+
+@dataclass
+class DataAccessConflict:
+    rank: int
+    threadblocks: Set[tuple] = field(default_factory=set)
+    conflict_type: DataAccessConflictType = DataAccessConflictType.none
+
+    def __add__(self, other):
+        if not isinstance(other, DataAccessConflict):
+            return NotImplemented
+        return DataAccessConflict(self.rank, self.threadblocks | other.threadblocks, self.conflict_type + other.conflict_type)
 
 class ReplicationPolicy(Enum):
     interleaved = "interleaved"
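The `DataAccessConflictType.__add__` mapping encodes a small join lattice: none (0) < intra (1) < inter (3), combined with bitwise or, so an inter-threadblock conflict absorbs everything else. A standalone restatement of the arithmetic:

```python
NONE, INTRA, INTER = 0, 1, 3  # the map_to_num values used above

assert NONE | INTRA == INTRA
assert INTRA | INTER == INTER  # 1 | 3 == 3: inter absorbs intra
assert NONE | INTER == INTER
assert INTRA | INTRA == INTRA
# intra (1) and inter (3) share the low bit, which is what makes `|` behave
# like a max over the three levels.
```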
diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py
index 007a8bcf9..019063a79 100644
--- a/python/mscclpp/language/program.py
+++ b/python/mscclpp/language/program.py
@@ -5,6 +5,9 @@
 from mscclpp.language.internal.globals import set_program
 from mscclpp.language.internal.types import BufferType, RemoteBuffer, ChannelType, ReplicationPolicy
 from mscclpp.language.internal.gpu import Gpu
+from mscclpp.language.internal.register import ChannelRegister, SemaphoreRegister
+from mscclpp.language.internal.op_dep_graph import OperationDependencyGraph
+from mscclpp.language.internal.buffer_access import BuffersAccess
 from typing import List
 import json
 
@@ -99,6 +102,8 @@ def __init__(
         self.min_message_size = min_message_size
         self.max_message_size = max_message_size
         assert protocol == "Simple" or protocol == "LL", f"Given protocol: {protocol}. Must be either Simple, LL"
+        self.op_dep_dag = OperationDependencyGraph()
+        self.buffers_access = BuffersAccess(num_ranks)
         self.buffers = collective.init_buffers()
         self.gpus: List[Gpu] = []
         for rank in range(self.num_ranks):
@@ -134,22 +139,34 @@ def add_channel(self, channel):
     def setup_channel(self, tb, channel):
         tb_channel_ids = []
         tb_channel_ids.append(self.gpus[channel.src_rank].setup_channel(tb, channel))
+        for tb_channel_id in tb_channel_ids:
+            ChannelRegister.add_channel(channel.src_rank, tb, tb_channel_id, channel)
         return tb_channel_ids
 
     def setup_remote_chunk(self, rank, tb, remote_chunk: RemoteBuffer, channel_access: ChannelType):
         return self.gpus[rank].add_remote_buffer(tb, remote_chunk, channel_access)
 
     def add_semaphore(self, semaphore):
+        SemaphoreRegister.add_semaphore(semaphore)
         self.gpus[semaphore.rank].add_semaphore(semaphore)
 
     def add_operation(self, rank, tb, operation):
         if self.loop_context != None:
             self.loop_context.add_operation(rank, tb, operation)
         else:
-            self.gpus[rank].add_operation(tb, operation)
+            self.op_dep_dag.add_operation(operation)
+
+    def add_tbg_operation(self, operations):
+        self.op_dep_dag.add_tbg_operation(operations)
 
     def post_process_operations(self):
-        for gpu in self.gpus:
+        self.op_dep_dag.add_semaphore_dependency()
+        list_op = self.op_dep_dag.get_execution_order()
+        list_op = self.buffers_access.process_operations(list_op)
+        for op in list_op:
+            self.gpus[op.rank].add_operation(op.threadblock, op)
+
+        """ for gpu in self.gpus:
             if self.instr_fusion:
                 gpu.optimize_operations()
             gpu.adding_data_sync()
@@ -159,7 +176,7 @@ def post_process_operations(self):
                 self.instances,
                 self.get_default_replication_policy_function(),
                 self.get_buffer_replication_policy_function(),
-            )
+            ) """
 
     def get_default_replication_policy_function(self):
         return lambda value, instance, num_instances: value * num_instances + instance
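The new `post_process_operations` replaces the old per-GPU pass with a three-stage pipeline; in outline (method names taken from the diff, the surrounding program object simplified):

```python
def post_process(program):
    # 1. Turn semaphore release/acquire pairs into explicit DAG edges.
    program.op_dep_dag.add_semaphore_dependency()
    # 2. Flatten the DAG into one globally consistent execution order.
    ops = program.op_dep_dag.get_execution_order()
    # 3. Scan that order for buffer hazards, inserting syncs/barriers.
    ops = program.buffers_access.process_operations(ops)
    for op in ops:
        program.gpus[op.rank].add_operation(op.threadblock, op)
```

Note that the previous fusion/replication pass is kept behind a string-literal comment rather than deleted, presumably to be re-enabled once it is adapted to the DAG-ordered stream.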
diff --git a/python/mscclpp/language/rank.py b/python/mscclpp/language/rank.py
index 02374c83a..4dcaa6c7a 100644
--- a/python/mscclpp/language/rank.py
+++ b/python/mscclpp/language/rank.py
@@ -110,20 +110,27 @@ def _copy(
                 "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
             )
 
+        operations = []
         for tb_id in tb_list:
             op = CopyOperation(
+                rank=self.rank,
+                threadblock=tb_id,
                 src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
                 dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
-                tbg_info=(
-                    ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
+                tbg=(
+                    tb_group
                     if tb_group is not None
                     else None
                 ),
                 from_packet=from_packet,
                 to_packet=to_packet,
             )
-
-            get_program().add_operation(self.rank, tb_id, op)
+            operations.append(op)
+
+        if tb_group is None:
+            get_program().add_operation(self.rank, tb_id, operations[0])
+        else:
+            get_program().add_tbg_operation(operations)
 
     def copy(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
         """Copy data from source chunk to destination chunk.
@@ -240,21 +247,28 @@ def reduce(
                 "Either 'tb' (thread block ID) or 'tb_group' (ThreadBlockGroup) must be provided, but both are None."
             )
 
+        operations = []
         for tb_id in tb_list:
             op = ReduceOperation(
-                [LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)]
+                rank=self.rank,
+                threadblock=tb_id,
+                local_src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)]
                 + [LocalChunk(chunk.buffer, chunk.index, chunk.size) for chunk in other_chunks],
-                [LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
+                local_dst_buff=[LocalChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size)],
                 reduce_operation=reduce_op,
-                tbg_info=(
-                    ThreadBlockGroupInfo(tb_group.get_internal_id(tb_id), tb_group.numtb())
+                tbg=(
+                    tb_group
                     if tb_group is not None
                     else None
                 ),
                 packet=packet,
             )
+            operations.append(op)
 
-            get_program().add_operation(self.rank, tb_id, op)
+        if tb_group is None:
+            get_program().add_operation(self.rank, tb_id, operations[0])
+        else:
+            get_program().add_tbg_operation(operations)
 
     def barrier(self, tb_list: List[int]):
         """Create a synchronization barrier between thread blocks.
@@ -278,8 +292,8 @@ def barrier(self, tb_list: List[int]):
-            op = SyncOperation()
+            op = SyncOperation(self.rank, tb_list[0])
             get_program().add_operation(self.rank, tb_list[0], op)
         else:
-            op = BarrierOperation(self.rank, tb_list)
             for tb in tb_list:
+                op = BarrierOperation(self.rank, tb, tb_list)
                 get_program().add_operation(self.rank, tb, op)
@@ -408,7 +422,7 @@ def acquire(self, tb: int, data_sync: SyncType = SyncType.both):
         Example:
             >>> sem.acquire(tb=0, data_sync=SyncType.before)
         """
-        op = SemaphoreAcquireOperation([self.id], data_sync)
+        op = SemaphoreAcquireOperation(self.rank, tb, [self.id], data_sync)
         get_program().add_operation(self.rank, tb, op)
 
     def release(self, tb: int, data_sync: SyncType = SyncType.both):
@@ -426,5 +440,5 @@ def release(self, tb: int, data_sync: SyncType = SyncType.both):
         Example:
             >>> sem.release(tb=0, data_sync=SyncType.after)
         """
-        op = SemaphoreReleaseOperation([self.id], data_sync)
+        op = SemaphoreReleaseOperation(self.rank, tb, [self.id], data_sync)
         get_program().add_operation(self.rank, tb, op)
diff --git a/python/mscclpp/language/thread_block_group.py b/python/mscclpp/language/thread_block_group.py
index 9ffe4e2dd..a0d63c498 100644
--- a/python/mscclpp/language/thread_block_group.py
+++ b/python/mscclpp/language/thread_block_group.py
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-from typing import List, Dict
+from typing import List, Dict, Set
 
 
 class ThreadBlockGroup:
@@ -20,10 +20,9 @@ def __init__(self, tb_list: List[int]):
             tb_list: List of thread block objects
         """
 
-        self.tb_list: List[int] = tb_list
+        self.tb_list: Set[int] = set(tb_list)
         self._tb_id: Dict[int, int] = {}
 
-        # Check for duplicates and build ID mapping
         seen = set()
-        for i, tb in enumerate(self.tb_list):
+        for i, tb in enumerate(tb_list):
             if tb in seen:
@@ -51,3 +50,25 @@ def get_internal_id(self, tb: int) -> int:
     def numtb(self) -> int:
         """Return the number of thread blocks in the group."""
         return len(self.tb_list)
+
+    def tbg_overlap(self, other):
+        for tb in self.tb_list:
+            if tb in other.tb_list:
+                return True
+        return False
+
+    def tb_overlap(self, tb_id):
+        return tb_id in self.tb_list
+
+    def to_dict(self, tb):
+        return {"tb_id": self.get_internal_id(tb), "tbg_size": self.numtb()}
+
+    def start_offset(self, tb, size):
+        tb_id = self.get_internal_id(tb)
+        return (size / self.numtb()) * tb_id
+
+    def end_offset(self, tb, size):
+        tb_id = self.get_internal_id(tb)
+        return (size / self.numtb()) * (tb_id + 1)
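A worked example of the new fractional partitioning, which is also why `DataAccess.start`/`end` became floats with tolerance-based comparisons in types.py: four thread blocks splitting a 10-element chunk get the ranges [0, 2.5), [2.5, 5), [5, 7.5), [7.5, 10). Standalone versions of the two methods:

```python
def start_offset(tb_id, numtb, size):
    # Lower bound of this thread block's share of a size-element chunk.
    return (size / numtb) * tb_id

def end_offset(tb_id, numtb, size):
    # Upper bound (exclusive) of this thread block's share.
    return (size / numtb) * (tb_id + 1)

for tb_id in range(4):
    print(tb_id, start_offset(tb_id, 4, 10), end_offset(tb_id, 4, 10))
# 0 0.0 2.5
# 1 2.5 5.0
# 2 5.0 7.5
# 3 7.5 10.0
```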