torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/simulator/ir.py (new file, +770 lines; the only entry above with +770)
@@ -0,0 +1,770 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import csv
import json
from collections import defaultdict
from dataclasses import dataclass, field
from itertools import count
from typing import (
    Any,
    DefaultDict,
    Dict,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch


class Command(NamedTuple):
    """
    Represents a node in the control flow DAG that tracks command execution on workers.

    Each Command node captures an operation executed on a specific worker and stream,
    including its control dependencies and associated devices.

    Attributes:
        worker_rank (int): Worker that executed the command
        stream_name (str): Stream on which the command was executed
        command_id (int): Unique identifier for the command
        command_name (str): Name of command (CallFunction: aten:mm, SendTensor: 7, etc.)
        devices (List[int]): Device IDs associated with this command
        control_dependencies (List[int]): Command IDs this command depends on
        traceback (List[str]): Python traceback at command execution
        duration (int): Command execution duration in milliseconds
    """

    worker_rank: int
    stream_name: str
    command_id: int
    command_name: str
    devices: List[int]
    control_dependencies: List[int]
    traceback: List[str]
    duration: int = 0  # ms


class StorageCreationEvent(NamedTuple):
    command_id: int
    storage_id: int
    dtype: Optional[torch.dtype]
    dims: Optional[tuple]
    size: Optional[int]
    devices: List[int]
    stream_name: str


class StorageDeletionEvent(NamedTuple):
    command_id: int
    storage_id: int
    dtype: Optional[torch.dtype]
    dims: Optional[tuple]
    size: Optional[int]
    devices: List[int]
    stream_name: str


class TensorCreationEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[
        tuple
    ]  # TODO: make sure dims here reflect tensor's and not storages'
    devices: List[int]
    stream_name: str


class TensorAccessEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[tuple]
    devices: List[int]
    stream_name: str


class TensorMutationEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[tuple]
    devices: List[int]
    stream_name: str


class TensorDeletionEvent(NamedTuple):
    command_id: int
    DTensorRef: int
    storage_id: int
    dims: Optional[tuple]
    devices: List[int]
    stream_name: str


"""
Represents a node in the data flow DAG that tracks tensor and storage lifecycle events.

Each DataEvent captures a specific event in the lifecycle of tensors and storage objects,
including creation, access, mutation, and deletion operations across workers and devices.
"""
DataEvent = Union[
    StorageCreationEvent,
    StorageDeletionEvent,
    TensorCreationEvent,
    TensorAccessEvent,
    TensorMutationEvent,
    TensorDeletionEvent,
]
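

# Illustrative sketch (editor's example, not part of the original file): every
# DataEvent variant is a NamedTuple carrying a `command_id`, so a consumer can
# dispatch on the concrete class to fold the data DAG into per-storage
# lifetimes. Assumes events appear in increasing command_id order, as produced
# by IRGraph below.
def _example_storage_lifetimes(
    events: List[DataEvent],
) -> Dict[int, Tuple[int, int]]:
    """Map storage_id -> (creation command_id, deletion command_id)."""
    created: Dict[int, int] = {}
    lifetimes: Dict[int, Tuple[int, int]] = {}
    for event in events:
        if isinstance(event, StorageCreationEvent):
            created[event.storage_id] = event.command_id
        elif isinstance(event, StorageDeletionEvent) and event.storage_id in created:
            lifetimes[event.storage_id] = (
                created.pop(event.storage_id),
                event.command_id,
            )
    return lifetimes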


@dataclass
class BorrowInfo:
    borrow_id: Optional[int] = None
    devices: Set[int] = field(default_factory=set)
    src_stream_name: Optional[str] = None
    dst_stream_name: Optional[str] = None
    create_id: Optional[int] = None
    firstuse_id: Optional[int] = None
    lastuse_id: Optional[int] = None
    drop_id: Optional[int] = None


@dataclass
class SendTensorInfo:
    result_tensor_id: Optional[int] = None
    src_devices: Optional[List[int]] = None
    src_stream_name: Optional[str] = None
    dst_devices: Optional[List[int]] = None
    dst_stream_name: Optional[str] = None
    result_tensor_dims: Optional[Tuple[int, ...]] = None


@dataclass
class TensorInfo:
    storage_id: Optional[int] = None
    DTensorRefs: Set[int] = field(default_factory=set)
    dtype: Optional[torch.dtype] = None
    dims: Tuple[int, ...] = field(default_factory=tuple)
    size: Optional[int] = None
    devices: Set[int] = field(default_factory=set)
    stream_name: Optional[str] = None
    storage_create_id: Optional[int] = None
    tensor_create_ids: Set[int] = field(default_factory=set)
    access_ids: Set[int] = field(default_factory=set)
    mutation_ids: Set[int] = field(default_factory=set)
    lastuse_id: Optional[int] = None
    tensor_deletion_ids: Set[int] = field(default_factory=set)
    storage_deletion_id: Optional[int] = None
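

# Hypothetical helper (editor's sketch, not part of the original file).
# TensorInfo aggregates the full lifecycle of one storage, so the span of
# command ids during which the storage is live can be read off its fields.
# Assumes update_tensor/delete_tensor below have already populated the record.
def _example_live_range(info: TensorInfo) -> Optional[Tuple[int, int]]:
    """Return (first, last) command ids bounding a storage's lifetime, if known."""
    if info.storage_create_id is None:
        return None
    end = (
        info.storage_deletion_id
        if info.storage_deletion_id is not None
        else info.lastuse_id
    )
    if end is None:
        end = info.storage_create_id
    return (info.storage_create_id, end)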


class IRGraph:
    """
    Represents an intermediate representation (IR) graph for distributed tensor operations.

    The IRGraph tracks both control flow (commands executed on workers) and data flow
    (tensor/storage lifecycle events) in distributed tensor computations. It consists of:

    1. Control DAG: Tracks command execution across workers and streams
    2. Data DAG: Tracks tensor and storage lifecycle events (creation, access, mutation, deletion)

    The graph can be exported to Chrome Trace format for visualization, and additional CSV
    exports provide detailed information about borrows, tensor sends, and data dependencies.

    Attributes:
        control_dag (List[Command]): Command nodes representing operations executed on workers
        data_dag (List[DataEvent]): Data events tracking tensor/storage lifecycle
        _control: Internal manager for control flow information (borrows, sendtensor)
        _data: Internal manager for data flow information (tensors, storage)
    """

    def __init__(self) -> None:
        self.control_dag: List[Command] = []
        self.data_dag: List[DataEvent] = []
        self._control: IRGraph._ControlManager = self._ControlManager()
        self._data: IRGraph._DataManager = self._DataManager()

    def insert_node(
        self,
        worker_rank: int,
        stream_name: str,
        command_id: int,
        command_name: str,
        devices: List[int],
        control_dependencies: List[int],
        traceback: List[str],
    ) -> None:
        new_dag_node = Command(
            worker_rank=worker_rank,
            stream_name=stream_name,
            command_id=command_id,
            command_name=command_name,
            devices=devices,
            control_dependencies=control_dependencies,
            traceback=traceback,
        )
        self.control_dag.append(new_dag_node)

    def add_borrow(
        self,
        borrow_id: int,
        device: int,
        src_stream_name: str,
        dst_stream_name: str,
        create_id: int,
    ) -> None:
        self._control.borrows_info[borrow_id].borrow_id = borrow_id
        self._control.borrows_info[borrow_id].devices.add(device)
        self._control.borrows_info[borrow_id].src_stream_name = src_stream_name
        self._control.borrows_info[borrow_id].dst_stream_name = dst_stream_name
        self._control.borrows_info[borrow_id].create_id = create_id

    def update_tensor(
        self,
        temp_id: int,
        ref: int,
        dtype: torch.dtype,
        dims: Tuple[int, ...],
        worker_rank: int,
        stream_name: str,
        command_id: int,
        mutate: bool = False,
        borrow_src_tensor_ref: Optional[int] = None,
        tensor_size: Optional[int] = None,
    ) -> None:
        new_tensor_event = new_storage_event = False
        update_tensor_devices = update_storage_devices = False

        if temp_id not in self._data.id_to_storageid:
            if borrow_src_tensor_ref is None:
                new_storage_event = True
                storage_id = next(self._data.storageid_counter)
                self._data.id_to_storageid[temp_id] = storage_id
                self._data.data_dependency_info[storage_id].storage_id = storage_id
                self._data.data_dependency_info[storage_id].dtype = dtype
                self._data.data_dependency_info[storage_id].dims = dims
                self._data.data_dependency_info[storage_id].size = tensor_size
                self._data.data_dependency_info[storage_id].stream_name = stream_name
                self._data.data_dependency_info[
                    storage_id
                ].storage_create_id = command_id
            # borrow aliasing
            else:
                storage_id = self._data.tensorref_to_storageid[borrow_src_tensor_ref]
                self._data.id_to_storageid[temp_id] = storage_id
        else:
            storage_id = self._data.id_to_storageid[temp_id]
        if worker_rank not in self._data.data_dependency_info[storage_id].devices:
            update_storage_devices = True
            self._data.data_dependency_info[storage_id].devices.add(worker_rank)

        if ref not in self._data.tensorref_to_stream:
            new_tensor_event = True
            self._data.tensorref_to_storageid[ref] = storage_id
            self._data.tensorref_to_mesh[ref].add(worker_rank)
            self._data.tensorref_to_stream[ref] = stream_name
            self._data.storageid_to_tensorref[storage_id].add(ref)

            self._data.data_dependency_info[storage_id].DTensorRefs.add(ref)
            self._data.data_dependency_info[storage_id].tensor_create_ids.add(
                command_id
            )
        else:
            if worker_rank not in self._data.tensorref_to_mesh[ref]:
                update_tensor_devices = True
                self._data.tensorref_to_mesh[ref].add(worker_rank)

        self._data.data_dependency_info[storage_id].access_ids.add(command_id)
        self._data.data_dependency_info[
            storage_id
        ].lastuse_id = command_id  # commands are processed in increasing command_id
        if mutate:
            self._data.data_dependency_info[storage_id].mutation_ids.add(command_id)

        # Helper function to find or create events
        def find_or_create_event(event_type):
            # Look for an existing event with the same command_id and event_type.
            # Look backwards since events are processed in increasing command_id.
            for i in range(len(self.data_dag) - 1, -1, -1):
                event = self.data_dag[i]
                event_class_name = event.__class__.__name__
                if (
                    event.command_id == command_id
                    and event_class_name == event_type
                    and (not hasattr(event, "DTensorRef") or event.DTensorRef == ref)
                ):
                    # If worker_rank already exists, just return True
                    if worker_rank in event.devices:
                        return True

                    # Update devices list
                    updated_devices = event.devices + [worker_rank]
                    updated_event = event._replace(devices=updated_devices)
                    self.data_dag[i] = updated_event
                    return True
            return False

        if new_storage_event and not find_or_create_event("StorageCreationEvent"):
            self.data_dag.append(
                StorageCreationEvent(
                    command_id=command_id,
                    storage_id=storage_id,
                    dtype=dtype,
                    dims=dims,
                    size=tensor_size,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )
        if new_tensor_event and not find_or_create_event("TensorCreationEvent"):
            self.data_dag.append(
                TensorCreationEvent(
                    command_id=command_id,
                    DTensorRef=ref,
                    storage_id=storage_id,
                    dims=dims,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )
        if not find_or_create_event("TensorAccessEvent"):
            self.data_dag.append(
                TensorAccessEvent(
                    command_id=command_id,
                    DTensorRef=ref,
                    storage_id=storage_id,
                    dims=dims,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )
        if mutate and not find_or_create_event("TensorMutationEvent"):
            self.data_dag.append(
                TensorMutationEvent(
                    command_id=command_id,
                    DTensorRef=ref,
                    storage_id=storage_id,
                    dims=dims,
                    devices=[worker_rank],
                    stream_name=stream_name,
                )
            )

        if update_storage_devices:
            find_or_create_event("StorageCreationEvent")
        if update_tensor_devices:
            find_or_create_event("TensorCreationEvent")
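
    # Editor's note (not original text): `update_tensor` is invoked once per
    # (worker_rank, tensor ref) observation, so the same logical event may be
    # reported by many workers. `find_or_create_event` scans `self.data_dag`
    # backwards for an event with the same command_id/type/ref and merges the
    # new worker into its `devices` list; a fresh event is appended only when
    # no match is found. Borrowed tensors reuse the source tensor's storage id
    # (via `borrow_src_tensor_ref`), which is how aliasing is modeled.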

    def delete_tensor(
        self,
        ref: int,
        mesh_ranks: List[int],
        stream_name: str,
        command_id: int,
    ) -> None:
        storage_id = self._data.tensorref_to_storageid[ref]

        self._data.data_dependency_info[storage_id].tensor_deletion_ids.add(command_id)

        self.data_dag.append(
            TensorDeletionEvent(
                command_id=command_id,
                DTensorRef=ref,
                storage_id=storage_id,
                dims=self._data.data_dependency_info[storage_id].dims,
                devices=mesh_ranks,
                stream_name=stream_name,
            )
        )

        del self._data.tensorref_to_storageid[ref]
        self._data.storageid_to_tensorref[storage_id].remove(ref)

        if not self._data.storageid_to_tensorref[storage_id]:
            self.data_dag.append(
                StorageDeletionEvent(
                    command_id=command_id,
                    storage_id=storage_id,
                    dtype=self._data.data_dependency_info[storage_id].dtype,
                    dims=self._data.data_dependency_info[storage_id].dims,
                    size=self._data.data_dependency_info[storage_id].size,
                    devices=mesh_ranks,
                    stream_name=stream_name,
                )
            )

            self._data.data_dependency_info[
                storage_id
            ].storage_deletion_id = command_id

    def add_sendtensor(
        self,
        result_tensor_id: int,
        src_devices: List[int],
        src_stream_name: str,
        dst_devices: List[int],
        dst_stream_name: str,
        result_tensor_dims: Tuple[int, ...],
    ) -> None:
        self._control.sendtensor_info[
            result_tensor_id
        ].result_tensor_id = result_tensor_id
        self._control.sendtensor_info[result_tensor_id].src_devices = src_devices
        self._control.sendtensor_info[
            result_tensor_id
        ].src_stream_name = src_stream_name
        self._control.sendtensor_info[result_tensor_id].dst_devices = dst_devices
        self._control.sendtensor_info[
            result_tensor_id
        ].dst_stream_name = dst_stream_name
        self._control.sendtensor_info[
            result_tensor_id
        ].result_tensor_dims = result_tensor_dims

    def remove_dag_item_type(
        self, command_types: Union[str, List[str]], print_removed_nodes: bool = False
    ) -> int:
        """
        Removes nodes from the DAG that match the specified command type(s).

        Args:
            command_types: A string or list of strings representing command types to remove.
                Nodes whose command_name starts with any of these strings will be removed.
            print_removed_nodes: If True, print a summary line for each removed node.

        Returns:
            int: The number of nodes removed from the DAG.

        Example:
            # Remove all 'Borrow' related commands
            graph.remove_dag_item_type('Borrow')

            # Remove multiple command types
            graph.remove_dag_item_type(['Reduce', 'SendTensor'])
        """
        if isinstance(command_types, str):
            command_types = [command_types]

        removed_nodes = [
            node
            for node in self.control_dag
            if any(node.command_name.startswith(ct) for ct in command_types)
        ]
        self.control_dag = [
            node
            for node in self.control_dag
            if not any(node.command_name.startswith(ct) for ct in command_types)
        ]

        num_removed = len(removed_nodes)
        if num_removed > 0:
            print(f"Removed {num_removed} DAG items of type(s) {command_types}:")
            if print_removed_nodes:
                for node in removed_nodes:
                    print(
                        f"{type(node).__name__}, Worker: {node.worker_rank}, Command ID: {node.command_id}"
                    )
        else:
            print(f"No nodes removed of type(s) {command_types}.")
        return num_removed

    def export_dag_json(self, output_file: str) -> None:
        # Note: The default width unit is in us, so we need to use "larger"
        # standard durations to ensure the flow events are visible.
        default_event_width = 4000
        default_event_spacing = 1000
        stream_locs = defaultdict(int)
        trace_events = []

        borrows_start_stream = {}

        reduce_sendtensor_max_ts = defaultdict(int)
        reduce_sendtensor_events = defaultdict(list)

        for dag_item in self.control_dag:
            worker_rank = dag_item.worker_rank
            name = dag_item.command_name
            cat = dag_item.command_name.split(":")[0]
            event: Dict[str, Any] = {
                "name": name,
                "cat": cat,
                "pid": worker_rank,
                "args": {
                    "command_id": dag_item.command_id,
                    "command_type": cat,
                    "devices": dag_item.devices,
                    "control dependencies": dag_item.control_dependencies,
                },
            }

            if isinstance(dag_item, Command):
                stream_name = dag_item.stream_name
                event["ph"] = "X"
                event["tid"] = stream_name
                event["dur"] = default_event_width

                if event["cat"] in ["BorrowCreate", "BorrowLastUse"]:
                    event["ts"] = stream_locs[f"{worker_rank}_{stream_name}"]

                    borrow_id = int(event["name"].split(":")[-1])
                    borrows_start_stream[event["name"]] = stream_name

                    # Create edge
                    event_start = event.copy()

                    event_start["ph"] = "s"
                    event_start["ts"] = event["ts"] + default_event_width

                    if event["cat"] == "BorrowCreate":
                        event_start["name"] = (
                            f"BorrowCreate->BorrowFirstUse: {borrow_id}"
                        )
                        event_start["cat"] = "BorrowCreate->BorrowFirstUse"
                        event_start["id"] = (
                            f"{worker_rank}:{borrow_id}:create->firstuse"
                        )
                    elif event["cat"] == "BorrowLastUse":
                        event_start["name"] = f"BorrowLastUse->BorrowDrop: {borrow_id}"
                        event_start["cat"] = "BorrowLastUse->BorrowDrop"
                        event_start["id"] = f"{worker_rank}:{borrow_id}:lastuse->drop"
                    event_start["args"] = {"devices": dag_item.devices}
                    del event_start["dur"]

                    trace_events.append(event_start)

                if event["cat"] in ["BorrowFirstUse", "BorrowDrop"]:
                    event["ts"] = stream_locs[f"{worker_rank}_{stream_name}"]

                    borrow_id = int(event["name"].split(":")[-1])
                    start_stream_name = ""

                    if event["cat"] == "BorrowFirstUse":
                        start_stream_name = borrows_start_stream[
                            f"BorrowCreate: {borrow_id}"
                        ]
                    elif event["cat"] == "BorrowDrop":
                        start_stream_name = borrows_start_stream[
                            f"BorrowLastUse: {borrow_id}"
                        ]

                    # Create edge
                    event_end = event.copy()
                    event_end["ph"] = "f"
                    event_end["ts"] = max(
                        stream_locs[f"{worker_rank}_{start_stream_name}"],
                        stream_locs[f"{worker_rank}_{stream_name}"],
                    )

                    if event["cat"] == "BorrowFirstUse":
                        event_end["name"] = f"BorrowCreate->BorrowFirstUse: {borrow_id}"
                        event_end["cat"] = "BorrowCreate->BorrowFirstUse"
                        event_end["id"] = f"{worker_rank}:{borrow_id}:create->firstuse"
                    elif event["cat"] == "BorrowDrop":
                        event_end["name"] = f"BorrowLastUse->BorrowDrop: {borrow_id}"
                        event_end["cat"] = "BorrowLastUse->BorrowDrop"
                        event_end["id"] = f"{worker_rank}:{borrow_id}:lastuse->drop"
                    event_end["args"] = {"devices": dag_item.devices}
                    del event_end["dur"]

                    stream_locs[f"{worker_rank}_{stream_name}"] = max(
                        stream_locs[f"{worker_rank}_{start_stream_name}"],
                        stream_locs[f"{worker_rank}_{stream_name}"],
                    )
                    trace_events.append(event_end)

                if event["cat"] in ["Reduce", "SendTensor"]:
                    ts = max(
                        stream_locs[f"{worker_rank}_{stream_name}"],
                        reduce_sendtensor_max_ts[name],
                    )
                    event["ts"] = ts
                    stream_locs[f"{worker_rank}_{stream_name}"] = ts
                    reduce_sendtensor_events[name].append(
                        event
                    )  # save event for later in case we need to update
                    # update max timestamp if necessary
                    if ts > reduce_sendtensor_max_ts[name]:
                        reduce_sendtensor_max_ts[name] = ts
                        # update timestamps of all Reduce/SendTensor events with the same name
                        for e in reduce_sendtensor_events[name]:
                            if e["name"] == name and e["ts"] != ts:
                                e["ts"] = ts
                                stream_locs[f"{e['pid']}_{e['tid']}"] = (
                                    reduce_sendtensor_max_ts[name]
                                    + default_event_width
                                    + default_event_spacing
                                )
                    # Extra SendTensor metadata
                    if event["cat"] == "SendTensor":
                        send_devices_threshold = len(dag_item.devices) // 2
                        event["args"]["send devices"] = dag_item.devices[
                            :send_devices_threshold
                        ]
                        event["args"]["recv devices"] = dag_item.devices[
                            send_devices_threshold:
                        ]

                else:
                    event["ts"] = stream_locs[f"{worker_rank}_{stream_name}"]

                stream_locs[f"{worker_rank}_{stream_name}"] += (
                    default_event_width + default_event_spacing
                )
                event["args"]["traceback"] = dag_item.traceback
                trace_events.append(event)
            else:
                raise ValueError(f"Unknown DAG item type: {type(dag_item)}")

        with open(output_file, "w") as f:
            json.dump({"traceEvents": trace_events}, f)
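
    # For reference (editor's note): `export_dag_json` emits Chrome Trace
    # Event format, viewable in chrome://tracing or Perfetto. A typical
    # duration event produced above looks roughly like (field values are
    # illustrative):
    #
    #   {"name": "CallFunction: aten:mm", "cat": "CallFunction", "pid": 0,
    #    "tid": "stream_main", "ph": "X", "ts": 5000, "dur": 4000,
    #    "args": {"command_id": 3, "command_type": "CallFunction",
    #             "devices": [0], "control dependencies": [2], ...}}
    #
    # Borrow edges are emitted as paired flow events ("ph": "s" / "ph": "f")
    # that share an "id", which renders them as arrows between streams.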

    def _export_info_to_csv(
        self, info_dict: Dict[Any, Any], filename: str, info_type: str
    ) -> None:
        def _format_value_for_display(value):
            """Format a value for CSV display, handling collections."""
            if isinstance(value, (dict, list, set)):
                if not value:
                    return "None"
                return str(sorted(value))
            return str(value)

        if not info_dict:
            print(f"No {info_type} information to export.")
            return

        # Get the first value to determine if it's a NamedTuple or dataclass
        first_value = next(iter(info_dict.values()))
        is_namedtuple = isinstance(first_value, tuple)
        is_dataclass = hasattr(first_value, "__dataclass_fields__")

        if not (is_namedtuple or is_dataclass):
            raise ValueError(
                f"Expected NamedTuple or dataclass, got {type(first_value)}"
            )

        if is_namedtuple:
            # Use fixed order for NamedTuple headers
            keys = [
                "DataEvent",
                "command_id",
                "storage_id",
                "DTensorRef",
                "devices",
                "stream_name",
                "dims",
                "dtype",
                "size",
            ]
        else:  # is_dataclass
            keys = list(first_value.__dataclass_fields__.keys())

        def get_value(obj, key):
            if key == "DataEvent" and is_namedtuple:
                return obj.__class__.__name__[:-5]  # remove "Event" suffix
            try:
                return getattr(obj, key)
            except AttributeError:
                return ""

        widths = {key: len(key) for key in keys}

        for info in info_dict.values():
            for key in keys:
                value = get_value(info, key)
                if value is not None:
                    str_value = _format_value_for_display(value)
                    widths[key] = max(widths[key], len(str_value))

        with open(filename, "w", newline="") as f:
            writer = csv.writer(f, delimiter="\t")
            # Write header with aligned fields
            writer.writerow([key.ljust(widths[key]) for key in keys])
            for info in info_dict.values():
                row = []
                for key in keys:
                    value = get_value(info, key)
                    str_value = _format_value_for_display(value)
                    row.append(str_value.ljust(widths[key]))
                writer.writerow(row)

    def export_borrows_csv(self, filename: str) -> None:
        self._export_info_to_csv(self._control.borrows_info, filename, "borrows")

    def export_sendtensors_csv(self, filename: str) -> None:
        self._export_info_to_csv(self._control.sendtensor_info, filename, "SendTensor")

    def export_data_csv(self, filename: str) -> None:
        self._export_info_to_csv(self._data.data_dependency_info, filename, "tensor")

    def export_data_timeline_csv(self, filename: str) -> None:
        if not self.data_dag:
            print("No data dependency timeline information to export.")
            return

        # Convert list to dict with indices as keys to use _export_info_to_csv
        timeline_dict = dict(enumerate(self.data_dag))
        self._export_info_to_csv(timeline_dict, filename, "data dependency timeline")

    class _ControlManager:
        """
        Internal manager for control flow information in the IRGraph.

        Tracks metadata about borrows and tensor send operations across workers and streams.

        Attributes:
            borrows_info: Maps borrow IDs to their metadata (devices, streams, command IDs)
            sendtensor_info: Maps tensor IDs to send operation metadata (source/destination devices and streams)
        """

        def __init__(self):
            self.borrows_info: DefaultDict[int, BorrowInfo] = defaultdict(BorrowInfo)

            self.sendtensor_info: DefaultDict[int, SendTensorInfo] = defaultdict(
                SendTensorInfo
            )

    class _DataManager:
        """
        Internal manager for data flow information in the IRGraph.

        Tracks tensor and storage lifecycle events including creation, access, mutation, and deletion.
        Maintains mappings between tensor references, storage IDs, and their associated metadata.

        Attributes:
            data_dependency_info: Maps storage IDs to their complete lifecycle metadata
            tensorref_to_stream: Maps tensor references to their associated stream names
            tensorref_to_storageid: Maps tensor references to their underlying storage IDs
            tensorref_to_mesh: Maps tensor references to the set of mesh device IDs
            id_to_storageid: Maps Python object IDs to storage IDs
            storageid_to_tensorref: Maps storage IDs to their associated tensor references
            storageid_counter: Counter for generating unique storage IDs
        """

        def __init__(self):
            self.data_dependency_info: DefaultDict[int, TensorInfo] = defaultdict(
                TensorInfo
            )
            self.tensorref_to_stream: Dict[
                int, str
            ] = {}  # key = DTensorRef.ref (int); value = stream name (str)
            self.tensorref_to_storageid: Dict[
                int, int
            ] = {}  # key = DTensorRef.ref (int); value = storage id (int)
            self.tensorref_to_mesh: DefaultDict[int, Set[int]] = defaultdict(
                set
            )  # key = DTensorRef.ref (int); value = mesh device ids (Set[int])
            self.id_to_storageid: Dict[
                int, int
            ] = {}  # key = id(UntypedStorage) (int); value = storage id (int)
            self.storageid_to_tensorref: DefaultDict[int, Set[int]] = defaultdict(
                set
            )  # key = storage_id (int); value = set of DTensorRef.ref (Set[int])
            self.storageid_counter: Iterator[int] = count()
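

# Minimal usage sketch (editor's example, not part of the original file; all
# names and values below are invented for illustration).
def _example_build_and_export() -> None:
    graph = IRGraph()
    # Command 0 on worker 0 creates tensor ref 1 on stream "main".
    graph.insert_node(
        worker_rank=0,
        stream_name="main",
        command_id=0,
        command_name="CallFunction: aten:ones",
        devices=[0],
        control_dependencies=[],
        traceback=[],
    )
    graph.update_tensor(
        temp_id=12345,  # stand-in for id() of the underlying storage
        ref=1,
        dtype=torch.float32,
        dims=(2, 2),
        worker_rank=0,
        stream_name="main",
        command_id=0,
        tensor_size=16,
    )
    # Command 1 deletes it, which also deletes the storage (last ref).
    graph.insert_node(0, "main", 1, "DeleteRefs", [0], [0], [])
    graph.delete_tensor(ref=1, mesh_ranks=[0], stream_name="main", command_id=1)
    graph.export_dag_json("trace.json")
    graph.export_data_csv("data.csv")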