torchmonarch-nightly 2025.6.4 (cp310-cp310-manylinux2014_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/simulator/ir.py
ADDED
@@ -0,0 +1,770 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import csv
import json
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import count
from typing import (
    Any,
    DefaultDict,
    Dict,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch


class Command(NamedTuple):
    """
    Represents a node in the control flow DAG that tracks command execution on workers.

    Each Command node captures an operation executed on a specific worker and stream,
    including its control dependencies and associated devices.

    Attributes:
        worker_rank (int): Worker that executed the command
        stream_name (str): Stream on which the command was executed
        command_id (int): Unique identifier for the command
        command_name (str): Name of command (CallFunction: aten:mm, SendTensor: 7, etc.)
        devices (List[int]): Device IDs associated with this command
        control_dependencies (List[int]): Command IDs this command depends on
        traceback (List[str]): Python traceback at command execution
        duration (int): Command execution duration in milliseconds
    """

    worker_rank: int
    stream_name: str
    command_id: int
    command_name: str
    devices: List[int]
    control_dependencies: List[int]
    traceback: List[str]
    duration: int = 0  # ms


class StorageCreationEvent(NamedTuple):
    command_id: int
    storage_id: int
    dtype: Optional[torch.dtype]
    dims: Optional[tuple]
    size: Optional[int]
    devices: List[int]
    stream_name: str


class StorageDeletionEvent(NamedTuple):
    command_id: int
    storage_id: int
    dtype: Optional[torch.dtype]
    dims: Optional[tuple]
    size: Optional[int]
    devices: List[int]
    stream_name: str


class TensorCreationEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[
        tuple
    ]  # TODO: make sure dims here reflect tensor's and not storages'
    devices: List[int]
    stream_name: str


class TensorAccessEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[tuple]
    devices: List[int]
    stream_name: str


class TensorMutationEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[tuple]
    devices: List[int]
    stream_name: str


class TensorDeletionEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[tuple]
    devices: List[int]
    stream_name: str


"""
Represents a node in the data flow DAG that tracks tensor and storage lifecycle events.

Each DataEvent captures a specific event in the lifecycle of tensors and storage objects,
including creation, access, mutation, and deletion operations across workers and devices.
"""
DataEvent = Union[
    StorageCreationEvent,
    StorageDeletionEvent,
    TensorCreationEvent,
    TensorAccessEvent,
    TensorMutationEvent,
    TensorDeletionEvent,
]


@dataclass
class BorrowInfo:
    borrow_id: Optional[int] = None
    devices: Set[int] = field(default_factory=set)
    src_stream_name: Optional[str] = None
    dst_stream_name: Optional[str] = None
    create_id: Optional[int] = None
    firstuse_id: Optional[int] = None
    lastuse_id: Optional[int] = None
    drop_id: Optional[int] = None


@dataclass
class SendTensorInfo:
    result_tensor_id: Optional[int] = None
    src_devices: Optional[List[int]] = None
    src_stream_name: Optional[str] = None
    dst_devices: Optional[List[int]] = None
    dst_stream_name: Optional[str] = None
    result_tensor_dims: Optional[Tuple[int, ...]] = None


@dataclass
class TensorInfo:
    storage_id: Optional[int] = None
    DTensorRefs: Set[int] = field(default_factory=set)
    dtype: Optional[torch.dtype] = None
    dims: Tuple[int, ...] = field(default_factory=tuple)
    size: Optional[int] = None
    devices: Set[int] = field(default_factory=set)
    stream_name: Optional[str] = None
    storage_create_id: Optional[int] = None
    tensor_create_ids: Set[int] = field(default_factory=set)
    access_ids: Set[int] = field(default_factory=set)
    mutation_ids: Set[int] = field(default_factory=set)
    lastuse_id: Optional[int] = None
    tensor_deletion_ids: Set[int] = field(default_factory=set)
    storage_deletion_id: Optional[int] = None


class IRGraph:
    """
    Represents an intermediate representation (IR) graph for distributed tensor operations.

    The IRGraph tracks both control flow (commands executed on workers) and data flow
    (tensor/storage lifecycle events) in distributed tensor computations. It consists of:

    1. Control DAG: Tracks command execution across workers and streams
    2. Data DAG: Tracks tensor and storage lifecycle events (creation, access, mutation, deletion)

    The graph can be exported to Chrome Trace format for visualization, and additional CSV
    exports provide detailed information about borrows, tensor sends, and data dependencies.

    Attributes:
        control_dag (List[Command]): Command nodes representing operations executed on workers
        data_dag (List[DataEvent]): Data events tracking tensor/storage lifecycle
        _control: Internal manager for control flow information (borrows, sendtensor)
        _data: Internal manager for data flow information (tensors, storage)
    """

    def __init__(self) -> None:
        self.control_dag: List[Command] = []
        self.data_dag: List[DataEvent] = []
        self._control: IRGraph._ControlManager = self._ControlManager()
        self._data: IRGraph._DataManager = self._DataManager()

    def insert_node(
        self,
        worker_rank: int,
        stream_name: str,
        command_id: int,
        command_name: str,
        devices: List[int],
        control_dependencies: List[int],
        traceback: List[str],
    ) -> None:
        new_dag_node = Command(
            worker_rank=worker_rank,
            stream_name=stream_name,
            command_id=command_id,
            command_name=command_name,
            devices=devices,
            control_dependencies=control_dependencies,
            traceback=traceback,
        )
        self.control_dag.append(new_dag_node)

    def add_borrow(
        self,
        borrow_id: int,
        device: int,
        src_stream_name: str,
        dst_stream_name: str,
        create_id: int,
    ) -> None:
        self._control.borrows_info[borrow_id].borrow_id = borrow_id
        self._control.borrows_info[borrow_id].devices.add(device)
        self._control.borrows_info[borrow_id].src_stream_name = src_stream_name
        self._control.borrows_info[borrow_id].dst_stream_name = dst_stream_name
        self._control.borrows_info[borrow_id].create_id = create_id

    def update_tensor(
        self,
        temp_id: int,
        ref: int,
        dtype: torch.dtype,
        dims: Tuple[int, ...],
        worker_rank: int,
        stream_name: str,
        command_id: int,
        mutate=False,
        borrow_src_tensor_ref: Optional[int] = None,
        tensor_size: Optional[int] = None,
    ) -> None:
        new_tensor_event = new_storage_event = False
        update_tensor_devices = update_storage_devices = False

        if temp_id not in self._data.id_to_storageid:
            if borrow_src_tensor_ref is None:
                new_storage_event = True
                storage_id = next(self._data.storageid_counter)
                self._data.id_to_storageid[temp_id] = storage_id
                self._data.data_dependency_info[storage_id].storage_id = storage_id
                self._data.data_dependency_info[storage_id].dtype = dtype
                self._data.data_dependency_info[storage_id].dims = dims
                self._data.data_dependency_info[storage_id].size = tensor_size
                self._data.data_dependency_info[storage_id].stream_name = stream_name
                self._data.data_dependency_info[
                    storage_id
                ].storage_create_id = command_id
            # borrow aliasing
            else:
                storage_id = self._data.tensorref_to_storageid[borrow_src_tensor_ref]
                self._data.id_to_storageid[temp_id] = storage_id
        else:
            storage_id = self._data.id_to_storageid[temp_id]
        if worker_rank not in self._data.data_dependency_info[storage_id].devices:
            update_storage_devices = True
            self._data.data_dependency_info[storage_id].devices.add(worker_rank)

        if ref not in self._data.tensorref_to_stream:
            new_tensor_event = True
            self._data.tensorref_to_storageid[ref] = storage_id
            self._data.tensorref_to_mesh[ref].add(worker_rank)
            self._data.tensorref_to_stream[ref] = stream_name
            self._data.storageid_to_tensorref[storage_id].add(ref)

            self._data.data_dependency_info[storage_id].DTensorRefs.add(ref)
            self._data.data_dependency_info[storage_id].tensor_create_ids.add(
                command_id
            )
        else:
            if worker_rank not in self._data.tensorref_to_mesh[ref]:
                update_tensor_devices = True
                self._data.tensorref_to_mesh[ref].add(worker_rank)

        self._data.data_dependency_info[storage_id].access_ids.add(command_id)
        self._data.data_dependency_info[
            storage_id
        ].lastuse_id = command_id  # commands are processed in increasing command_id
        if mutate:
            self._data.data_dependency_info[storage_id].mutation_ids.add(command_id)

        # Helper function to find or create events
        def find_or_create_event(event_type):
            # Look for existing event with same command_id and event_type
            # Look backwards since events are processed in increasing command_id
            for i in range(len(self.data_dag) - 1, -1, -1):
                event = self.data_dag[i]
                event_class_name = event.__class__.__name__
                if (
                    event.command_id == command_id
                    and event_class_name == event_type
                    and (not hasattr(event, "DTensorRef") or event.DTensorRef == ref)
                ):
                    # If worker_rank already exists, just return True
                    if worker_rank in event.devices:
                        return True

                    # Update devices list
                    updated_devices = event.devices + [worker_rank]
                    updated_event = event._replace(devices=updated_devices)
                    self.data_dag[i] = updated_event
                    return True
            return False

        if new_storage_event and not find_or_create_event("StorageCreationEvent"):
            self.data_dag.append(
                StorageCreationEvent(
                    command_id=command_id,
                    storage_id=storage_id,
                    dtype=dtype,
                    dims=dims,
                    size=tensor_size,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )
        if new_tensor_event and not find_or_create_event("TensorCreationEvent"):
            self.data_dag.append(
                TensorCreationEvent(
                    command_id=command_id,
                    DTensorRef=ref,
                    storage_id=storage_id,
                    dims=dims,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )
        if not find_or_create_event("TensorAccessEvent"):
            self.data_dag.append(
                TensorAccessEvent(
                    command_id=command_id,
                    DTensorRef=ref,
                    storage_id=storage_id,
                    dims=dims,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )
        if mutate and not find_or_create_event("TensorMutationEvent"):
            self.data_dag.append(
                TensorMutationEvent(
                    command_id=command_id,
                    DTensorRef=ref,
                    storage_id=storage_id,
                    dims=dims,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )

        if update_storage_devices:
            find_or_create_event("StorageCreationEvent")
        if update_tensor_devices:
            find_or_create_event("TensorCreationEvent")

    def delete_tensor(
        self,
        ref: int,
        mesh_ranks: List[int],
        stream_name: str,
        command_id: int,
    ) -> None:
        storage_id = self._data.tensorref_to_storageid[ref]

        self._data.data_dependency_info[storage_id].tensor_deletion_ids.add(command_id)

        self.data_dag.append(
            TensorDeletionEvent(
                command_id=command_id,
                DTensorRef=ref,
                storage_id=storage_id,
                dims=self._data.data_dependency_info[storage_id].dims,
                devices=mesh_ranks,
                stream_name=stream_name,
            )
        )

        del self._data.tensorref_to_storageid[ref]
        self._data.storageid_to_tensorref[storage_id].remove(ref)

        if not self._data.storageid_to_tensorref[storage_id]:
            self.data_dag.append(
                StorageDeletionEvent(
                    command_id=command_id,
                    storage_id=storage_id,
                    dtype=self._data.data_dependency_info[storage_id].dtype,
                    dims=self._data.data_dependency_info[storage_id].dims,
                    size=self._data.data_dependency_info[storage_id].size,
                    devices=mesh_ranks,
                    stream_name=stream_name,
                )
            )

            self._data.data_dependency_info[storage_id].storage_deletion_id = command_id

    def add_sendtensor(
        self,
        result_tensor_id: int,
        src_devices: List[int],
        src_stream_name: str,
        dst_devices: List[int],
        dst_stream_name: str,
        result_tensor_dims: Tuple[int, ...],
    ) -> None:
        self._control.sendtensor_info[
            result_tensor_id
        ].result_tensor_id = result_tensor_id
        self._control.sendtensor_info[result_tensor_id].src_devices = src_devices
        self._control.sendtensor_info[
            result_tensor_id
        ].src_stream_name = src_stream_name
        self._control.sendtensor_info[result_tensor_id].dst_devices = dst_devices
        self._control.sendtensor_info[
            result_tensor_id
        ].dst_stream_name = dst_stream_name
        self._control.sendtensor_info[
            result_tensor_id
        ].result_tensor_dims = result_tensor_dims
        return

    def remove_dag_item_type(
        self, command_types: Union[str, List[str]], print_removed_nodes: bool = False
    ) -> int:
        """
        Removes nodes from the DAG that match the specified command type(s).

        Args:
            command_types: A string or list of strings representing command types to remove.
                Nodes with command_name that starts with any of these strings will be removed.

        Returns:
            int: The number of nodes removed from the DAG.

        Example:
            # Remove all 'Borrow' related commands
            graph.remove_dag_item_type('Borrow')

            # Remove multiple command types
            graph.remove_dag_item_type(['Reduce', 'SendTensor'])
        """
        if isinstance(command_types, str):
            command_types = [command_types]

        removed_nodes = [
            node
            for node in self.control_dag
            if any(node.command_name.startswith(ct) for ct in command_types)
        ]
        self.control_dag = [
            node
            for node in self.control_dag
            if not any(node.command_name.startswith(ct) for ct in command_types)
        ]

        num_removed = len(removed_nodes)
        if num_removed > 0:
            print(f"Removed {num_removed} DAG items of type(s) {command_types}:")
            if print_removed_nodes:
                for node in removed_nodes:
                    print(
                        f"{type(node).__name__}, Worker: {node.worker_rank}, Command ID: {node.command_id}"
                    )
        else:
            print(f"No nodes removed of type(s) {command_types}.")
        return num_removed

    def export_dag_json(self, output_file: str) -> None:
        # Note: The default width unit is in us, so we need to use "larger" standard durations to ensure the flow events are visible.
        default_event_width = 4000
        default_event_spacing = 1000
        stream_locs = defaultdict(int)
        trace_events = []

        borrows_start_stream = {}

        reduce_sendtensor_max_ts = defaultdict(int)
        reduce_sendtensor_events = defaultdict(list)

        for dag_item in self.control_dag:
            worker_rank = dag_item.worker_rank
            name = dag_item.command_name
            cat = dag_item.command_name.split(":")[0]
            event: Dict[str, Any] = {
                "name": name,
                "cat": cat,
                "pid": worker_rank,
                "args": {
                    "command_id": dag_item.command_id,
                    "command_type": cat,
                    "devices": dag_item.devices,
                    "control dependencies": dag_item.control_dependencies,
                },
            }

            if isinstance(dag_item, Command):
                stream_name = dag_item.stream_name
                event["ph"] = "X"
                event["tid"] = stream_name
                event["dur"] = default_event_width

                if event["cat"] in ["BorrowCreate", "BorrowLastUse"]:
                    event["ts"] = stream_locs[f"{worker_rank}_{stream_name}"]

                    borrow_id = int(event["name"].split(":")[-1])
                    borrows_start_stream[event["name"]] = stream_name

                    # Create edge
                    event_start = event.copy()

                    event_start["ph"] = "s"
                    event_start["ts"] = event["ts"] + default_event_width

                    if event["cat"] == "BorrowCreate":
                        event_start["name"] = (
                            f"BorrowCreate->BorrowFirstUse: {borrow_id}"
                        )
                        event_start["cat"] = "BorrowCreate->BorrowFirstUse"
                        event_start["id"] = (
                            f"{worker_rank}:{borrow_id}:create->firstuse"
                        )
                    elif event["cat"] == "BorrowLastUse":
                        event_start["name"] = f"BorrowLastUse->BorrowDrop: {borrow_id}"
                        event_start["cat"] = "BorrowLastUse->BorrowDrop"
                        event_start["id"] = f"{worker_rank}:{borrow_id}:lastuse->drop"
                    event_start["args"] = {"devices": dag_item.devices}
                    del event_start["dur"]

                    trace_events.append(event_start)

                if event["cat"] in ["BorrowFirstUse", "BorrowDrop"]:
                    event["ts"] = stream_locs[f"{worker_rank}_{stream_name}"]

                    borrow_id = int(event["name"].split(":")[-1])
                    start_stream_name = ""

                    if event["cat"] == "BorrowFirstUse":
                        start_stream_name = borrows_start_stream[
                            f"BorrowCreate: {borrow_id}"
                        ]
                    elif event["cat"] == "BorrowDrop":
                        start_stream_name = borrows_start_stream[
                            f"BorrowLastUse: {borrow_id}"
                        ]

                    # Create edge
                    event_end = event.copy()
                    event_end["ph"] = "f"
                    event_end["ts"] = max(
                        stream_locs[f"{worker_rank}_{start_stream_name}"],
                        stream_locs[f"{worker_rank}_{stream_name}"],
                    )

                    if event["cat"] == "BorrowFirstUse":
                        event_end["name"] = f"BorrowCreate->BorrowFirstUse: {borrow_id}"
                        event_end["cat"] = "BorrowCreate->BorrowFirstUse"
                        event_end["id"] = f"{worker_rank}:{borrow_id}:create->firstuse"
                    elif event["cat"] == "BorrowDrop":
                        event_end["name"] = f"BorrowLastUse->BorrowDrop: {borrow_id}"
                        event_end["cat"] = "BorrowLastUse->BorrowDrop"
                        event_end["id"] = f"{worker_rank}:{borrow_id}:lastuse->drop"
                    event_end["args"] = {"devices": dag_item.devices}
                    del event_end["dur"]

                    stream_locs[f"{worker_rank}_{stream_name}"] = max(
                        stream_locs[f"{worker_rank}_{start_stream_name}"],
                        stream_locs[f"{worker_rank}_{stream_name}"],
                    )
                    trace_events.append(event_end)

                if event["cat"] in ["Reduce", "SendTensor"]:
                    ts = max(
                        stream_locs[f"{worker_rank}_{stream_name}"],
                        reduce_sendtensor_max_ts[name],
                    )
                    event["ts"] = ts
                    stream_locs[f"{worker_rank}_{stream_name}"] = ts
                    reduce_sendtensor_events[name].append(
                        event
                    )  # save event for later in case we need to update
                    # update max timestamp if necessary
                    if ts > reduce_sendtensor_max_ts[name]:
                        reduce_sendtensor_max_ts[name] = ts
                        # update timestamps of all Reduce/SendTensor events with the same name
                        for e in reduce_sendtensor_events[name]:
                            if e["name"] == name and e["ts"] != ts:
                                e["ts"] = ts
                                stream_locs[f"{e['pid']}_{e['tid']}"] = (
                                    reduce_sendtensor_max_ts[name]
                                    + default_event_width
                                    + default_event_spacing
                                )
                    # Extra SendTensor metadata
                    if event["cat"] == "SendTensor":
                        send_devices_threshold = len(dag_item.devices) // 2
                        event["args"]["send devices"] = dag_item.devices[
                            :send_devices_threshold
                        ]
                        event["args"]["recv devices"] = dag_item.devices[
                            send_devices_threshold:
                        ]

                else:
                    event["ts"] = stream_locs[f"{worker_rank}_{stream_name}"]

                stream_locs[f"{worker_rank}_{stream_name}"] += (
                    default_event_width + default_event_spacing
                )
                event["args"]["traceback"] = dag_item.traceback
                trace_events.append(event)
            else:
                raise ValueError(f"Unknown DAG item type: {type(dag_item)}")

        with open(output_file, "w") as f:
            json.dump({"traceEvents": trace_events}, f)

    def _export_info_to_csv(
        self, info_dict: Dict[Any, Any], filename: str, info_type: str
    ) -> None:
        def _format_value_for_display(value):
            """Format a value for CSV display, handling collections."""
            if isinstance(value, (dict, List, set)):
                if not value:
                    return "None"
                return str(sorted(value))
            return str(value)

        if not info_dict:
            print(f"No {info_type} information to export.")
            return

        # Get the first value to determine if it's a NamedTuple or dict
        first_value = next(iter(info_dict.values()))
        is_namedtuple = isinstance(first_value, tuple)
        is_dataclass = hasattr(first_value, "__dataclass_fields__")

        if not (is_namedtuple or is_dataclass):
            raise ValueError(
                f"Expected NamedTuple or dataclass, got {type(first_value)}"
            )

        if is_namedtuple:
            # Use fixed order for NamedTuple headers
            keys = [
                "DataEvent",
                "command_id",
                "storage_id",
                "DTensorRef",
                "devices",
                "stream_name",
                "dims",
                "dtype",
                "size",
            ]
        else:  # is_dataclass
            keys = list(first_value.__dataclass_fields__.keys())

        def get_value(obj, key):
            if key == "DataEvent" and is_namedtuple:
                return obj.__class__.__name__[:-5]  # remove "Event" suffix
            try:
                return getattr(obj, key)
            except AttributeError:
                return ""

        widths = {key: len(key) for key in keys}

        for info in info_dict.values():
            for key in keys:
                value = get_value(info, key)
                if value is not None:
                    str_value = _format_value_for_display(value)
                    widths[key] = max(widths[key], len(str_value))

        with open(filename, "w", newline="") as f:
            writer = csv.writer(f, delimiter="\t")
            # Write header with aligned fields
            writer.writerow([key.ljust(widths[key]) for key in keys])
            for info in info_dict.values():
                row = []
                for key in keys:
                    value = get_value(info, key)
                    str_value = _format_value_for_display(value)
                    row.append(str_value.ljust(widths[key]))
                writer.writerow(row)

    def export_borrows_csv(self, filename: str) -> None:
        self._export_info_to_csv(self._control.borrows_info, filename, "borrows")

    def export_sendtensors_csv(self, filename: str) -> None:
        self._export_info_to_csv(self._control.sendtensor_info, filename, "SendTensor")

    def export_data_csv(self, filename: str) -> None:
        self._export_info_to_csv(self._data.data_dependency_info, filename, "tensor")

    def export_data_timeline_csv(self, filename: str) -> None:
        if not self.data_dag:
            print("No data dependency timeline information to export.")
            return

        # Convert list to dict with indices as keys to use _export_info_to_csv
        timeline_dict = dict(enumerate(self.data_dag))
        self._export_info_to_csv(timeline_dict, filename, "data dependency timeline")

    class _ControlManager:
        """
        Internal manager for control flow information in the IRGraph.

        Tracks metadata about borrows and tensor send operations across workers and streams.

        Attributes:
            borrows_info: Maps borrow IDs to their metadata (devices, streams, command IDs)
            sendtensor_info: Maps tensor IDs to send operation metadata (source/destination devices and streams)
        """

        def __init__(self):
            self.borrows_info: DefaultDict[int, BorrowInfo] = defaultdict(BorrowInfo)

            self.sendtensor_info: DefaultDict[int, SendTensorInfo] = defaultdict(
                SendTensorInfo
            )

    class _DataManager:
        """
        Internal manager for data flow information in the IRGraph.

        Tracks tensor and storage lifecycle events including creation, access, mutation, and deletion.
        Maintains mappings between tensor references, storage IDs, and their associated metadata.

        Attributes:
            data_dependency_info: Maps storage IDs to their complete lifecycle metadata
            tensorref_to_stream: Maps tensor references to their associated stream names
            tensorref_to_storageid: Maps tensor references to their underlying storage IDs
            tensorref_to_mesh: Maps tensor references to the set of mesh device IDs
            id_to_storageid: Maps Python object IDs to storage IDs
            storageid_to_tensorref: Maps storage IDs to their associated tensor references
            storageid_counter: Counter for generating unique storage IDs
        """

        def __init__(self):
            self.data_dependency_info: DefaultDict[int, TensorInfo] = defaultdict(
                TensorInfo
            )
            self.tensorref_to_stream: Dict[
                int, str
            ] = {}  # key = DTensorRef.ref (int); value = stream name (str)
            self.tensorref_to_storageid: Dict[
                int, int
            ] = {}  # key = DTensorRef.ref (int); value = storage id (int)
            self.tensorref_to_mesh: DefaultDict[int, Set[int]] = defaultdict(
                set
            )  # key = DTensorRef.ref (int); value = mesh device ids (Set[int])
            self.id_to_storageid: Dict[
                int, int
            ] = {}  # key = id(UntypedStorage) (int); value = storage id (int)
            self.storageid_to_tensorref: DefaultDict[int, Set[int]] = defaultdict(
                set
            )  # key = storage_id (int); value = List[DTensorRef] (List[int])
            self.storageid_counter: Iterator[int] = count()