torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,1052 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import copy
9
+ import cProfile
10
+ import enum
11
+ import heapq
12
+ import io
13
+ import itertools
14
+ import json
15
+ import logging
16
+ import os
17
+ import pickle
18
+ import pstats
19
+ import subprocess
20
+ import tempfile
21
+ import time
22
+ import traceback
23
+ import warnings
24
+ from collections import defaultdict
25
+ from enum import auto
26
+ from functools import cache
27
+ from pathlib import Path
28
+ from typing import (
29
+ Any,
30
+ cast,
31
+ Generator,
32
+ Iterable,
33
+ List,
34
+ NamedTuple,
35
+ Optional,
36
+ Tuple,
37
+ Union,
38
+ )
39
+
40
+ import numpy as np
41
+
42
+ import torch
43
+ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
44
+ ActorId,
45
+ )
46
+ from monarch.common import messages
47
+ from monarch.common.controller_api import LogMessage, MessageResult
48
+ from monarch.common.device_mesh import DeviceMesh
49
+ from monarch.common.function import ResolvableFunction, ResolvableFunctionFromPath
50
+ from monarch.common.invocation import DeviceException
51
+ from monarch.common.shape import iter_ranks, NDSlice
52
+ from monarch.simulator.command_history import CommandHistory, DTensorRef
53
+ from monarch.simulator.config import META_VAL
54
+ from monarch.simulator.ir import IRGraph
55
+ from monarch.simulator.mock_controller import MockController
56
+ from monarch.simulator.profiling import RuntimeEstimator, RuntimeProfiler
57
+ from monarch.simulator.task import Borrow, EventTask, Task
58
+ from monarch.simulator.tensor import FakeTensorTracker
59
+ from monarch.simulator.trace import (
60
+ dump_memory_trace,
61
+ dump_process_name,
62
+ dump_thread_event_trace,
63
+ MemoryViewer,
64
+ TraceEvent,
65
+ upload_trace,
66
+ )
67
+ from monarch.simulator.utils import (
68
+ clean_name,
69
+ compress_workers_range,
70
+ file_path_with_iter,
71
+ )
72
+ from monarch.simulator.worker import Worker, WorkerGroup
73
+ from torch.utils._pytree import tree_leaves
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+
78
class SimulatorBackendMode(enum.Enum):
    """
    An enum to specify the mode of the simulator.
    """

    # Simulates the commands, dumps the trace, and reports the simulated
    # execution time and memory. It is the default mode.
    SIMULATE = auto()
    # Simulates the commands and reports the simulated execution time and memory
    # without generating a trace.
    SIMULATE_WITH_REPORT_ONLY = auto()
    # Only records the commands without actually simulating them.
    COMMAND_HISTORY = auto()
    # SIMULATE + COMMAND_HISTORY
    EVERYTHING = auto()

    @property
    def simulation_enabled(self) -> bool:
        """Return True when this mode runs the simulation."""
        # Members are looked up on the class rather than on ``self``:
        # member-to-member access (``self.SIMULATE``) is deprecated on some
        # Python versions and is easy to confuse with instance attributes.
        return self in (
            SimulatorBackendMode.SIMULATE,
            SimulatorBackendMode.SIMULATE_WITH_REPORT_ONLY,
            SimulatorBackendMode.EVERYTHING,
        )

    @property
    def command_history_enabled(self) -> bool:
        """Return True when this mode records the command history."""
        return self in (
            SimulatorBackendMode.COMMAND_HISTORY,
            SimulatorBackendMode.EVERYTHING,
        )
101
+
102
+
103
class SimulatorTraceMode(enum.Enum):
    """
    An enum to specify the mode of the simulated trace.
    """

    # Only traces the controller.
    CONTROLLER_TRACE_ONLY = auto()
    # Only traces the streams of all the workers.
    STREAM_ONLY = auto()
    # Traces the controller and all the streams of all the workers.
    EVERYTHING = auto()

    @property
    def stream_enabled(self) -> bool:
        """Return True when worker streams are traced in this mode."""
        # Class-qualified lookup avoids member-to-member attribute access.
        return self in (SimulatorTraceMode.STREAM_ONLY, SimulatorTraceMode.EVERYTHING)

    @property
    def controller_enabled(self) -> bool:
        """Return True when the controller is traced in this mode."""
        return self in (
            SimulatorTraceMode.CONTROLLER_TRACE_ONLY,
            SimulatorTraceMode.EVERYTHING,
        )
122
+
123
+
124
def get_fake_tensor(x):
    """Return ``x._fake`` for tensors and ``DTensorRef`` objects, else ``x`` itself."""
    is_tensor_like = isinstance(x, (torch.Tensor, DTensorRef))
    return x._fake if is_tensor_like else x
128
+
129
+
130
def get_ids(tree):
    """Map ``ref`` to ``_fake`` for every tensor-like leaf found in ``tree``.

    ``tree`` may be a single ``torch.Tensor``/``DTensorRef`` or any pytree;
    leaves of other types are ignored. Later leaves with the same ``ref``
    overwrite earlier ones.
    """
    tensor_like = (torch.Tensor, DTensorRef)
    # A bare tensor-like object is wrapped so it is flattened like a
    # one-element tree.
    root = [tree] if isinstance(tree, tensor_like) else tree
    return {
        leaf.ref: leaf._fake
        for leaf in tree_leaves(root)
        if isinstance(leaf, tensor_like)
    }
138
+
139
+
140
+ class Simulator:
141
+ """
142
+ A class to simulate the execution of the commands from the controller.
143
+ It can be used to simulate on the fly with SimulatorBackend() or replay an
144
+ existing trace with Simulator.replay().
145
+ """
146
+
147
def __init__(
    self,
    *,
    world_size: int = 0,
    profile: bool = False,
    replay_file: Optional[str] = None,
    trace_mode: SimulatorTraceMode = SimulatorTraceMode.EVERYTHING,
    upload_trace: bool = False,
    trace_path: str = "trace.json",
    group_workers: bool = False,
):
    """Set up the simulator state.

    Args:
        world_size: number of simulated workers; overridden by the value
            stored in ``replay_file`` when one is given.
        profile: when True, a ``cProfile.Profile`` is created for the run.
        replay_file: optional command-history file loaded via
            ``CommandHistory.load``.
        trace_mode: which parts of the execution are traced.
        upload_trace: stored and consulted when the report is produced.
        trace_path: trace output path; stored as an absolute path.
        group_workers: when True, a single ``WorkerGroup`` covering all
            ranks is created instead of one ``Worker`` per rank.

    Raises:
        ValueError: if the resulting ``world_size`` is not positive.
    """
    self.command_history: Optional[CommandHistory] = None
    if replay_file:
        loaded_history = CommandHistory.load(replay_file)
        self.command_history = loaded_history
        world_size = loaded_history.world_size

    if world_size <= 0:
        raise ValueError(
            f"{world_size=} is not correct. Please specify a valid "
            "world_size or ensure the replay file contains the world_size."
        )

    self.runtime = RuntimeEstimator()
    self.runtime_profiler = RuntimeProfiler(world_size=torch.cuda.device_count())
    self.events: List[TraceEvent] = []
    self.command_id = 0
    self.fake_tensor_tracker = FakeTensorTracker()

    # Exactly one of ``_worker_groups`` / ``_workers`` is populated; the
    # ``workers`` property picks whichever is non-empty.
    self._worker_groups: List[WorkerGroup] = []
    self._workers: List[Worker] = []
    if group_workers:
        everyone = WorkerGroup(
            np.arange(world_size), self.fake_tensor_tracker, self.runtime
        )
        self._worker_groups = [everyone]
        # Every rank starts out in group 0.
        self._worker_group_mapping = np.zeros(world_size, dtype=np.int32)
    else:
        self._worker_group_mapping = np.zeros(1, dtype=np.int32)
        self._workers = [
            Worker(self.fake_tensor_tracker, self.runtime)
            for _rank in range(world_size)
        ]

    self.worker_commands = defaultdict(list)
    self.now = 0
    self.profiler = cProfile.Profile() if profile else None
    self.simulation_time = 0.0
    self.trace_mode = trace_mode
    self.upload_trace = upload_trace
    self._debug = False
    self.trace_path = os.path.abspath(trace_path)
    self.current_traceback = []
200
+
201
@property
def workers(self) -> List[Worker]:
    """Return the worker groups when grouping is active, else the per-rank workers."""
    grouped = self._worker_groups
    if not grouped:
        return self._workers
    # why can't pyre figure out the upcasting?
    return cast(List[Worker], grouped)
208
+
209
def _print_worker0(self) -> None:
    """Log the head task of every non-empty stream queue of worker 0.

    Does nothing unless ``self._debug`` is set.
    """
    if self._debug:
        for idx, stream in self.workers[0].streams.items():
            if not stream.task_queue:
                continue
            snapshot = (
                self.now,
                idx,
                stream.task_queue[0],
                stream.task_queue[0].state,
                stream.task_queue[0].dependencies,
                stream.tensors,
            )
            logger.info(snapshot)
225
+
226
def _run(self) -> None:
    """
    Drive the simulated workers until nothing changes any more.

    Every pass runs three stages -- ``maybe_set_ready``, ``maybe_execute`` and
    ``maybe_finish`` -- and each stage visits all workers before the next stage
    starts, which simulates asynchronous execution. Passes repeat for as long
    as at least one call reported a status change.
    """
    stage_names = ("maybe_set_ready", "maybe_execute", "maybe_finish")
    while True:
        self._print_worker0()
        progressed = False
        for stage_name in stage_names:
            for worker in self.workers:
                # Every worker is always visited, even after progress has
                # already been observed in this pass.
                if getattr(worker, stage_name)():
                    progressed = True
        if not progressed:
            break
245
+
246
def _print_profiler(self):
    """Print cumulative-time profiler statistics and the simulation run time.

    Does nothing when profiling was not enabled (``self.profiler`` is None).
    """
    profiler = self.profiler
    if profiler is None:
        return
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats()
    print(buffer.getvalue())
    print(
        f"Simulation run time, excluding loading the file: {self.simulation_time}."
    )
256
+
257
def _rank_to_worker(self, ranks: List[NDSlice]) -> Generator[Worker, None, None]:
    """Yield the per-rank ``Worker`` for every rank index contained in ``ranks``."""
    for rank_index in itertools.chain.from_iterable(ranks):
        yield self._workers[rank_index]
261
+
262
def _ndslice_to_worker_group(
    self, ranks: List[NDSlice]
) -> Generator[WorkerGroup, None, None]:
    """Yield the worker groups that together cover exactly ``ranks``.

    If the requested ranks are a union of whole existing groups, those groups
    are yielded unchanged. Otherwise every partially covered group is split:
    the non-participating ranks move into a new group, and
    ``self._worker_groups`` / ``self._worker_group_mapping`` are rebuilt before
    the participating groups are yielded.

    Assumes ``group.workers`` is an integer array of rank ids and that
    ``group.split(ids)`` removes ``ids`` from ``group`` and returns a new
    group holding them -- TODO confirm against ``WorkerGroup``.
    """
    # TODO: While we already use numpy array, this can still be quite slow
    # because iterating ranks happens in Python. We should cache the results
    # since we don't have that many different ranks combinations.
    if not ranks:
        return

    # ``np.unique`` sorts and also drops ranks listed more than once, so the
    # per-group counts below count every worker a single time.
    workers = np.unique(
        np.concatenate([np.fromiter(iter(r), dtype=np.int64) for r in ranks])
    )
    # ``minlength`` keeps one counter per existing group. Without it the
    # result stops at the highest participating group id, and the rebuild
    # below would silently drop every group after that one.
    counts = np.bincount(
        self._worker_group_mapping[workers], minlength=len(self._worker_groups)
    ).tolist()

    all_matches = all(
        len(self._worker_groups[group_id].workers) == count
        for group_id, count in enumerate(counts)
        if count > 0
    )
    if all_matches:
        for group_id, count in enumerate(counts):
            if count > 0:
                yield self._worker_groups[group_id]
        return

    new_groups = []
    participate_groups = []
    for group_id, count in enumerate(counts):
        group = self._worker_groups[group_id]
        # Groups inserted by earlier splits shift the position of this one;
        # refresh the mapping so it keeps pointing at the right slot.
        members = np.asarray(group.workers, dtype=np.int64)
        self._worker_group_mapping[members] = len(new_groups)
        new_groups.append(group)
        if count == 0:
            continue
        participate_groups.append(group)
        if count == len(members):
            # Fully participating: nothing to split off (avoids creating an
            # empty group).
            continue

        not_participate_set = np.setdiff1d(members, workers, assume_unique=True)
        not_participate_group = group.split(not_participate_set)
        self._worker_group_mapping[not_participate_set] = len(new_groups)
        new_groups.append(not_participate_group)

    self._worker_groups = new_groups
    yield from participate_groups
302
+
303
def iter_workers(self, ranks: List[NDSlice]) -> Generator[Worker, None, None]:
    """Yield the simulated workers (or worker groups) addressed by ``ranks``."""
    resolver = (
        self._ndslice_to_worker_group
        if self._worker_groups
        else self._rank_to_worker
    )
    yield from resolver(ranks)
308
+
309
def _report(self, trace_path: str = "", memory_view_path: str = ""):
    """Build the trace and memory view and return ``(exec_time, max_mem)``.

    Args:
        trace_path: when non-empty, the trace is written there as JSON.
        memory_view_path: passed to ``MemoryViewer.dump``.

    Returns:
        A tuple of the largest value returned by ``dump_thread_event_trace``
        and the peak accumulated memory delta, both divided by ``10**6``.
    """
    trace = []
    max_mem = 0.0

    # Perfetto treats tid and pid as part of the same namespace
    # (unlike chrome://trace). If they collide then names will
    # get clobbered, so we assign unique ids to each individual
    # concept.
    id_iter = itertools.count(1)

    @cache
    def to_id(key):
        # ``cache`` makes the id stable per key; the counter only advances
        # the first time a key is seen.
        return next(id_iter)

    dump_process_name(trace, pid=0, name="Controller")
    exec_time = max(
        0.0,
        dump_thread_event_trace(trace, self.events, pid=0, tid=0, name="Controller"),
    )

    if isinstance(self.workers[0], WorkerGroup):
        workers = sorted(self.workers, key=lambda g: min(g.workers))
    else:
        workers = self.workers

    memory_viewer = MemoryViewer()
    for worker_id, worker in enumerate(workers):
        if not worker.events:
            continue
        pid = to_id(("worker", worker_id))
        name = f"Device {worker_id}"
        if isinstance(worker, WorkerGroup):
            name = f"{name} {compress_workers_range(worker.workers)}"
        dump_process_name(trace=trace, pid=pid, name=name)
        # TODO: find a better tid for worker trace
        # NOTE(review): this dumps the controller's ``self.events`` although
        # the guard above checks ``worker.events`` -- confirm which is intended.
        exec_time = max(
            dump_thread_event_trace(
                trace, self.events, pid=pid, tid=32000, name=name
            ),
            exec_time,
        )

        for stream_id, stream in worker.streams.items():
            tid = to_id(("stream", worker_id, stream_id))
            exec_time = max(
                dump_thread_event_trace(
                    trace, stream.events, pid=pid, tid=tid, name=stream.name
                ),
                exec_time,
            )

        # Get the memory order: merge the per-stream memory events by
        # timestamp. Shallow copies are popped so the streams' own event
        # lists stay intact.
        curr_mem = 0
        memory_viewer.next_device()
        mem_events = {
            stream_id: copy.copy(stream.memory.events)
            for stream_id, stream in worker.streams.items()
        }
        while True:
            min_ts = float("inf")
            min_stream_events = None
            min_stream_id = 0
            for stream_id, events in mem_events.items():
                if events and min_ts > events[0][0]:
                    min_ts = events[0][0]
                    min_stream_id, min_stream_events = stream_id, events

            if min_stream_events is None:
                break

            mem_ts, mem_addr, mem_delta, mem_traceback = heapq.heappop(
                min_stream_events
            )
            curr_mem += mem_delta
            max_mem = max(curr_mem, max_mem)
            dump_memory_trace(
                trace,
                pid=pid,
                memory=curr_mem,
                ts=mem_ts,
                name="memory",
            )
            memory_viewer.add_trace(
                mem_addr, mem_delta, min_stream_id, mem_traceback
            )

    if trace_path:
        with open(trace_path, "w") as f:
            json.dump({"traceEvents": trace}, f, indent=4)

    memory_viewer.dump(memory_view_path)

    if self.upload_trace:
        if trace_path:
            upload_trace(os.path.abspath(trace_path))
        else:
            # Previously this path referenced the file handle of a trace
            # that was never opened and failed with an unbound-variable error.
            logger.warning(
                "upload_trace is enabled but no trace file was written; "
                "skipping the upload."
            )

    return exec_time / 10**6, max_mem / 10**6
407
+
408
def step(self, iter_count: int, dump_trace: bool = False) -> Tuple[float, float]:
    """
    Finish the simulation of one iteration and report it. Returns what
    ``_report`` returns for this iteration: the execution time and the peak
    memory usage.
    """
    trace_file = ""
    if dump_trace:
        trace_file = file_path_with_iter(self.trace_path, iter_count)
    # The memory view is placed next to the trace file; with no trace file
    # the directory part is the empty string.
    memory_file = file_path_with_iter(
        os.path.join(os.path.dirname(trace_file), "memory_view.pt"), iter_count
    )
    return self._report(trace_file, memory_file)
418
+
419
def exit(self, iter_count: int, dump_trace: bool = False) -> Tuple[float, float]:
    """Report the final iteration; identical to calling ``step``."""
    final_report = self.step(iter_count, dump_trace)
    return final_report
421
+
422
@classmethod
def replay(cls, replay_file: str, profile: bool = False) -> None:
    """Replay the ``send`` commands recorded in ``replay_file``, then report.

    Commands with any other ``backend_command`` are skipped. The report is
    produced without writing a trace file, and profiler statistics are
    printed when ``profile`` is set.
    """
    simulator = cls(replay_file=replay_file, profile=profile)
    history = cast(CommandHistory, simulator.command_history)
    for command in history.commands:
        if command.backend_command == "send":
            assert command.ranks is not None
            simulator.send(command.timestamp, command.ranks, command.msg)
    simulator._report()
    simulator._print_profiler()
432
+
433
+ # Methods below simulate the methods of a real backend.
434
def send(self, now: int, ranks: List[NDSlice], msg) -> None:
    """Simulate sending ``msg`` to ``ranks`` at controller time ``now``.

    Records a controller trace event (when controller tracing is enabled),
    then dispatches to the method named after the message type and runs the
    workers. Messages without a matching method only trigger a warning.

    Args:
        now: controller timestamp of the command.
        ranks: the ranks the message is addressed to.
        msg: the message; its type name selects the handler method.
    """
    logger.debug("Sending %s at %s.", msg, now)
    self.current_traceback = traceback.extract_stack()[:-3]
    command_name = type(msg).__name__
    self.command_id += 1

    controller_enabled = self.trace_mode.controller_enabled
    # These two commands typically take a long time to execute on the
    # controller side. Ignoring them will make the simulation trace easier
    # to read.
    if controller_enabled and command_name not in (
        "CreateDeviceMesh",
        "CreateStream",
    ):
        if command_name != "CallFunction":
            meta = [command_name] + META_VAL
        else:
            meta = [clean_name(msg.function.path)] + META_VAL
        self.events.append(
            TraceEvent(
                self.now,
                now - self.now,
                meta,
                self.command_id,
                self.current_traceback,
            )
        )

    if controller_enabled:
        self.now = now

    # A CommandGroup is always unpacked so that its members get a chance to
    # be recorded, even when stream tracing is off.
    if not self.trace_mode.stream_enabled and command_name != "CommandGroup":
        return

    begin = time.monotonic()
    if self.profiler:
        self.profiler.enable()
    try:
        attr = getattr(self, command_name, None)
        if attr is None:
            # Instead of silently ignoring the unimplemented method, a warning
            # gives us the signal to review any newly implemented messages.
            warnings.warn(
                f"Simulator doesn't implement {type(msg).__name__} {msg}. "
                "This can cause incorrect simulation.",
                stacklevel=2,
            )
            return

        attr(ranks, msg)
        self._run()
    finally:
        # Always balance ``enable`` and account for the elapsed time, also on
        # the early return above and when a handler raises.
        if self.profiler:
            self.profiler.disable()
        self.simulation_time += time.monotonic() - begin
487
+
488
def recvready(self):
    """Not supported by the simulator; always raises ``NotImplementedError``."""
    raise NotImplementedError
490
+
491
def propagate(self, msg: messages.SendValue) -> Any:
    """Profile ``msg`` as a ``CallFunction`` and return the first result.

    The ``SendValue`` is rewrapped into a ``CallFunction`` with the same
    function, args and kwargs, which is handed to the runtime profiler
    together with ``[0]``; element ``[0][0]`` of what the profiler returns is
    the return value.
    """
    assert isinstance(msg.function, ResolvableFunction)
    equivalent_call = messages.CallFunction(
        ident=0,
        result=None,
        mutates=(),
        function=msg.function,
        args=msg.args,
        kwargs=msg.kwargs,
        stream=None,  # pyre-ignore[6]
        device_mesh=None,  # pyre-ignore[6]
        remote_process_groups=[],
    )
    profiled = self.runtime_profiler.profile_cmd(equivalent_call, [0])
    first_entry = profiled[0]
    return first_entry[0]
506
+
507
def Exit(self, ranks: List[NDSlice], msg: messages.Exit):
    """Handle ``Exit``; the simulator has nothing to do for it."""
    return None
509
+
510
def CallFunction(self, ranks: List[NDSlice], msg: messages.CallFunction):
    """Queue one ``Task`` per addressed worker for a remote function call.

    Inputs are the tensors found in ``msg.args``; outputs are the tensors in
    ``msg.result`` plus those in ``msg.mutates``.
    """
    # NOTE(review): tensors passed through ``msg.kwargs`` are not collected
    # as inputs here -- confirm whether that is intended.
    input_fakes = get_ids(msg.args)
    output_fakes = get_ids(msg.result)
    if msg.mutates:
        output_fakes.update(get_ids(msg.mutates))

    tracker = self.fake_tensor_tracker
    tracker.add(input_fakes)
    tracker.add(output_fakes)

    stream_ref = msg.stream.ref
    for worker in self.iter_workers(ranks):
        label = clean_name(str(msg.function))
        # Each worker receives its own Task object with its own lists.
        task = Task(
            inputs=list(input_fakes),
            outputs=list(output_fakes),
            command_id=self.command_id,
            start_time=self.now,
            runtime=self.runtime.get_runtime(msg),
            meta=[label],
            traceback=self.current_traceback,
        )
        worker.add_task(task, self.now, stream=stream_ref)
533
+
534
def SendTensor(self, ranks: List[NDSlice], msg: messages.SendTensor):
    """Queue the tasks that model moving a tensor between ranks.

    When source and destination ranks are equal, a single ``SendTensor`` task
    per worker consumes the input and produces the result. Otherwise each
    source worker gets a ``SendTensor`` task and each destination worker a
    ``RecvTensor`` task; the n-th sender and the n-th receiver share one
    ``collectives`` list.

    Raises:
        NotImplementedError: if the source and destination streams differ.
    """
    # NOTE: The memory usage calculation for SendTensor may not be accurate when
    # the source and destination ranks are the same. In such cases, memory usage
    # should increase if the result tensor is modified. However, this depends on
    # the specific implementation by the worker.

    src_fakes = get_ids(msg.tensor)
    dst_fakes = get_ids(msg.result)
    self.fake_tensor_tracker.add(src_fakes)
    self.fake_tensor_tracker.add(dst_fakes)
    if msg.from_stream is not msg.to_stream:
        raise NotImplementedError(
            "simulator using to_mesh between different streams"
        )
    stream_ref = msg.from_stream.ref

    def enqueue(worker, label, inputs, outputs, **extra):
        # Build a fresh Task for ``worker`` and queue it on the shared stream.
        worker.add_task(
            Task(
                inputs=inputs,
                outputs=outputs,
                command_id=self.command_id,
                start_time=self.now,
                runtime=self.runtime.get_runtime(msg),
                meta=[label],
                traceback=self.current_traceback,
                **extra,
            ),
            self.now,
            stream=stream_ref,
        )

    if msg.from_ranks == msg.to_ranks:
        for worker in self.iter_workers([msg.from_ranks]):
            enqueue(worker, "SendTensor", list(src_fakes), list(dst_fakes))
        return

    # One shared list per sender; the matching receiver gets the same object.
    shared_lists = []
    for worker in self.iter_workers([msg.from_ranks]):
        shared = []
        shared_lists.append(shared)
        enqueue(worker, "SendTensor", list(src_fakes), [], collectives=shared)

    receivers = self.iter_workers([msg.to_ranks])
    for worker, shared in zip(receivers, shared_lists, strict=True):
        enqueue(worker, "RecvTensor", [], list(dst_fakes), collectives=shared)
601
+
602
def CommandGroup(self, ranks: List[NDSlice], msg: messages.CommandGroup):
    """Send every command of the group, in order, at the current time."""
    dispatch = self.send
    for sub_command in msg.commands:
        dispatch(self.now, ranks, sub_command)
605
+
606
def CreateStream(self, ranks: List[NDSlice], msg: messages.CreateStream):
    """Create the stream described by ``msg.result`` on every addressed worker."""
    for worker in self.iter_workers(ranks):
        new_stream = msg.result
        assert new_stream.ref is not None
        worker.create_stream(new_stream.ref, new_stream.name, default=msg.default)
610
+
611
def Reduce(self, ranks: List[NDSlice], msg: messages.Reduce):
    """Queue one collective ``Task`` per addressed worker for a reduce.

    The trace label is derived from ``msg.reduction`` and ``msg.scatter``:

    * ``"stack"`` with scatter is ``all_to_all``, without is ``all_gather``;
    * any other reduction with scatter is ``reduce_scatter``, without is
      ``all_reduce``.

    All tasks created for one message share a single ``collectives`` list.
    """
    inputs = get_ids(msg.local_tensor)
    outputs = get_ids(msg.result)
    self.fake_tensor_tracker.add(inputs)
    self.fake_tensor_tracker.add(outputs)

    # TODO: controller doesn't implement reduce and scatter yet so it is
    # not possible to get such a request.
    is_stack = msg.reduction == "stack"
    if msg.scatter:
        meta_str = "all_to_all" if is_stack else "reduce_scatter"
    else:
        meta_str = "all_gather" if is_stack else "all_reduce"

    meta = [meta_str]
    stream = msg.stream.ref
    collectives = []
    for worker in self.iter_workers(ranks):
        worker.add_task(
            Task(
                inputs=list(inputs.keys()),
                outputs=list(outputs.keys()),
                start_time=self.now,
                runtime=self.runtime.get_runtime(msg),
                meta=meta,
                command_id=self.command_id,
                collectives=collectives,
                traceback=self.current_traceback,
            ),
            self.now,
            stream=stream,
        )
648
+
649
def BorrowCreate(self, ranks: List[NDSlice], msg: messages.BorrowCreate):
    """Register a borrow of ``msg.tensor`` from one stream to another.

    For every addressed worker an event is recorded on the source stream and
    an ``EventTask`` that makes the destination stream wait for it is handed
    to ``worker.borrow``.
    """
    src_fakes = get_ids(msg.tensor)
    dst_fakes = get_ids(msg.result)
    self.fake_tensor_tracker.add(src_fakes)
    self.fake_tensor_tracker.add(dst_fakes, is_borrowed=True)

    from_stream = msg.from_stream.ref
    to_stream = msg.to_stream.ref
    assert from_stream is not None
    assert to_stream is not None

    borrow = Borrow(
        ident=msg.borrow,
        tensor_src_id=cast(int, cast(DTensorRef, msg.tensor).ref),
        tensor_dst_id=cast(int, cast(DTensorRef, msg.result).ref),
        from_stream=from_stream,
        to_stream=to_stream,
    )
    for worker in self.iter_workers(ranks):
        recorded = worker.streams[from_stream].record_event()
        # Note: there is no perfect way to set the start_time when the
        # controller timing is disabled -- the wait event's start time
        # may be very early like 0. This is because only the GPU events
        # are tracked and there are no other GPU events except for
        # communications and wait events on the communication stream.
        # However, if we let the event's start_time to be based on the
        # main stream's timing, we may lose other information.
        issued_at = self.now
        source_name = worker.streams[from_stream].name
        target_name = worker.streams[to_stream].name
        wait_event = EventTask(
            recorded_task=recorded,
            event_stream=from_stream,
            event_stream_name=source_name,
            wait_stream=to_stream,
            wait_stream_name=target_name,
            command_id=self.command_id,
            start_time=issued_at,
            borrow=borrow,
            runtime=self.runtime.get_runtime("wait_event"),
            traceback=self.current_traceback,
        )
        worker.borrow(wait_event, borrow)
688
+
689
def BorrowFirstUse(self, ranks: List[NDSlice], msg: messages.BorrowFirstUse):
    """Tell every addressed worker that the borrow is used for the first time."""
    borrow_id = msg.borrow
    for worker in self.iter_workers(ranks):
        worker.borrow_first_use(borrow_id, self.now)
692
+
693
def BorrowLastUse(self, ranks: List[NDSlice], msg: messages.BorrowLastUse):
    """Record the last use of a borrow on every addressed worker.

    The stream roles are swapped relative to the wait event stored for the
    borrow: an event is recorded on that event's wait stream, and the new
    ``EventTask`` makes its event stream wait for it.
    """
    for worker in self.iter_workers(ranks):
        original_wait = worker.wait_events[msg.borrow]
        borrower_stream = original_wait.wait_stream
        recorded = worker.streams[borrower_stream].record_event()
        owner_stream = original_wait.event_stream
        last_use_event = EventTask(
            recorded_task=recorded,
            event_stream=borrower_stream,
            event_stream_name=worker.streams[borrower_stream].name,
            wait_stream=owner_stream,
            wait_stream_name=worker.streams[owner_stream].name,
            command_id=self.command_id,
            start_time=self.now,
            runtime=self.runtime.get_runtime("wait_event"),
            traceback=self.current_traceback,
        )
        worker.borrow_last_use(last_use_event, msg.borrow)
709
+
710
def BorrowDrop(self, ranks: List[NDSlice], msg: messages.BorrowDrop):
    """Drop the borrow on every addressed worker at the current time."""
    borrow_id = msg.borrow
    for worker in self.iter_workers(ranks):
        worker.borrow_drop(borrow_id, self.now)
713
+
714
def DeleteRefs(self, ranks: List[NDSlice], msg: messages.DeleteRefs):
    """Delete ``msg.refs`` on every addressed worker at the current time."""
    doomed_refs = msg.refs
    for worker in self.iter_workers(ranks):
        worker.delete_refs(doomed_refs, self.now)
717
+
718
def BackendNetworkInit(
    self, ranks: List[NDSlice], msg: messages.BackendNetworkInit
):
    """Handle ``BackendNetworkInit``; the simulator has nothing to do for it."""
    return None
722
+
723
def CreatePipe(self, ranks: List[NDSlice], msg: messages.CreatePipe):
    """Handle ``CreatePipe`` as a no-op."""
    # We don't have to track Pipe creation (yet).
    return None
726
+
727
def PipeRecv(self, ranks: List[NDSlice], msg: messages.PipeRecv):
    """Register the tensors received from a pipe as CPU tensors.

    Raises:
        NotImplementedError: if any result tensor is not on the CPU device.
    """
    received = get_ids(msg.result)
    cpu = torch.device("cpu")
    self.fake_tensor_tracker.add(received)
    if any(fake.device != cpu for fake in received.values()):
        raise NotImplementedError("PipeRecv only support CPU device now.")

    for worker in self.iter_workers(ranks):
        for tensor_id in received:
            worker.add_cpu_tensor(tensor_id, self.now)
738
+
739
+ # Not doing anything for the following messages (yet).
740
def SendValue(self, ranks: List[NDSlice], msg: messages.SendValue):
    """No-op handler: nothing is done for this message (yet)."""
    return None
742
+
743
def CreateDeviceMesh(self, ranks: List[NDSlice], msg: messages.CreateDeviceMesh):
    """No-op handler: nothing is done for this message (yet)."""
    return None
745
+
746
def RequestStatus(self, ranks: List[NDSlice], msg: messages.RequestStatus):
    """No-op handler: nothing is done for this message (yet)."""
    return None
748
+
749
def SplitComm(self, ranks: List[NDSlice], msg: messages.SplitComm):
    """No-op handler: nothing is done for this message (yet)."""
    return None
751
+
752
def BackendNetworkPointToPointInit(
    self, ranks: List[NDSlice], msg: messages.BackendNetworkPointToPointInit
):
    """No-op handler: nothing is done for this message (yet)."""
    return None
756
+
757
+
758
class SimulatorController(MockController):
    """
    A backend that simulates the behavior of the ProcessBackend. It can also be
    used to only record the commands sent to it, and then replay them later using
    the `Simulator` class.

    Args:
        world_size (int): The number of workers in the simulation.
        gpu_per_host (int): The number of GPUs per machine.
        simulate_mode: Stored as ``self.mode``. Its ``command_history_enabled``
            and ``simulation_enabled`` flags decide whether a
            ``CommandHistory`` and/or a ``Simulator`` is created.
        trace_mode, upload_trace, trace_path, group_workers: Forwarded
            unchanged to the ``Simulator`` constructor.
        command_history_path: Made absolute and passed to ``CommandHistory``
            as its ``file_path``.
        ir: Optional IR graph; stored and passed along with every recorded or
            converted command.
    """

    def __init__(
        self,
        world_size: int,
        gpu_per_host: int,
        *,
        simulate_mode: SimulatorBackendMode = SimulatorBackendMode.SIMULATE,
        trace_mode: SimulatorTraceMode = SimulatorTraceMode.EVERYTHING,
        upload_trace: bool = False,
        trace_path: str = "trace.json",
        command_history_path: str = "command_history.pkl",
        group_workers: bool = False,
        ir: Optional[IRGraph] = None,
    ):
        # DTensorRef.created is shared state; drop anything left behind by a
        # previous controller before this one starts.
        if len(DTensorRef.created) > 0:
            DTensorRef.created.clear()
            warnings.warn(
                "clearing old DTensorRef information. TODO: support multiple simulator backends in the same process.",
                stacklevel=1,
            )
        super().__init__(world_size, verbose=False)

        self._gpu_per_host = gpu_per_host
        self.timestamp_base = time.monotonic_ns()
        self.worker_commands = defaultdict(list)
        self.simulator: Optional[Simulator] = None
        self.command_history: Optional[CommandHistory] = None
        self.iter = 0
        self.mode = simulate_mode
        self.exception = False
        self.ir = ir

        if self.mode.command_history_enabled:
            history_file = os.path.abspath(command_history_path)
            self.command_history = CommandHistory(world_size, file_path=history_file)

        if self.mode.simulation_enabled:
            self.simulator = Simulator(
                world_size=world_size,
                trace_mode=trace_mode,
                upload_trace=upload_trace,
                trace_path=trace_path,
                group_workers=group_workers,
            )

    @property
    def gpu_per_host(self) -> int:
        """The ``gpu_per_host`` value given at construction time."""
        return self._gpu_per_host

    def cleanup_simulation(self):
        """Clear the shared ``DTensorRef.created`` registry."""
        DTensorRef.created.clear()

    def __del__(self):
        """Run ``cleanup_simulation`` when the controller is finalized."""
        self.cleanup_simulation()

    def step(self) -> Tuple[float, float]:
        """
        Step to the next iteration simulation and return the execution time in second
        and peak memory usage in MB of this iteration. If the simulation mode is
        COMMAND_HISTORY, then the return time and memory will be 0.0 as the backend
        only records the commands.
        """
        if self.command_history:
            self.command_history.step(self.iter)

        # Defaults used when no simulator is attached.
        exec_time = max_mem = 0.0
        if self.simulator:
            should_dump = self.mode != SimulatorBackendMode.SIMULATE_WITH_REPORT_ONLY
            exec_time, max_mem = self.simulator.step(self.iter, dump_trace=should_dump)

        self.iter += 1
        return exec_time, max_mem

    def _send(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple) -> None:
        """
        Record/convert ``msg``, feed it to the simulator and forward it.

        The command is recorded through ``self.command_history`` when one
        exists, otherwise it is only converted. A ``SendValue`` whose function
        is the ``_propagate`` helper is answered directly from the simulator
        and not forwarded; ``CommandGroup`` messages are not forwarded to the
        parent ``send`` either.
        """
        now = time.monotonic_ns() - self.timestamp_base

        if isinstance(ranks, NDSlice):
            ranks = [ranks]

        # Same argument list either way; only the callee differs.
        if self.command_history:
            build_command = self.command_history.record
        else:
            build_command = CommandHistory.convert_command
        command = build_command(
            now,
            "send",
            self.simulator.command_id if self.simulator else 0,
            self.simulator.current_traceback if self.simulator else (),
            ranks,
            msg,
            None,
            self.ir,
        )

        if self.simulator:
            self.simulator.send(now, cast(List[NDSlice], command.ranks), command.msg)

        msg_kind = type(msg).__name__
        if msg_kind == "SendValue":
            send_value = cast(messages.SendValue, msg)
            is_propagate = (
                isinstance(send_value.function, ResolvableFunctionFromPath)
                and send_value.function.path
                == "monarch.cached_remote_function._propagate"
            )
            if is_propagate:
                assert self.simulator is not None
                assert send_value.destination is None
                ret = self.simulator.propagate(send_value)
                # One completion per rank covered by the message.
                for _ in iter_ranks(ranks):
                    self.history.future_completed(send_value.ident, ret)
                return

        if msg_kind != "CommandGroup":
            return super().send(ranks, msg)
        return None

    def send(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple) -> None:
        """
        Send ``msg`` unless a previous send already failed.

        Any exception raised by ``_send`` is not propagated: it marks the
        controller as failed (``self.exception``) and is appended to
        ``self.responses`` wrapped in a ``DeviceException``.
        """
        if self.exception:
            return

        try:
            self._send(ranks, msg)
        except Exception as e:
            self.exception = True
            # TODO: Should we also call simulator.exit() and cleanup?
            failure = DeviceException(
                e,
                traceback.extract_tb(e.__traceback__),
                ActorId.from_string("unknown[0].unknown[0]"),
                message="Simulator has an internal error.",
            )
            self.responses.append(
                MessageResult(
                    seq=0,  # will not be used
                    result=None,
                    error=failure,
                )
            )

    def next_message(
        self, timeout: Optional[float]
    ) -> Optional[MessageResult | LogMessage]:
        """
        Record the ``next_message`` call (when command history is kept) and
        delegate to the parent implementation.
        """
        now = time.monotonic_ns() - self.timestamp_base

        if self.command_history:
            self.command_history.record(
                now,
                "next_message",
                self.simulator.command_id if self.simulator else 0,
                self.simulator.current_traceback if self.simulator else (),
                None,
                None,
                timeout,
                self.ir,
            )

        return super().next_message(timeout)

    def Exit(self, ranks: Union[NDSlice, List[NDSlice]], msg: messages.Exit):
        """
        Dump the command history (if any), shut the simulator down (if any),
        clear the simulation state and delegate to the parent ``Exit``.
        """
        if self.command_history:
            self.command_history.dump(self.command_history.file_path)
        if self.simulator:
            should_dump = self.mode != SimulatorBackendMode.SIMULATE_WITH_REPORT_ONLY
            self.simulator.exit(self.iter, dump_trace=should_dump)
        self.cleanup_simulation()

        return super().Exit(ranks, msg)
950
+
951
+
952
class SimulatorInterface:
    """
    API for interacting with the simulator.

    ``sim.mesh`` retrieves the simulator mesh.

    Args:
        mesh: The device mesh exposed as ``self.mesh``.
        ctrl: The controller whose ``simulator`` attribute and ``step`` method
            are used by ``upload`` and ``display``.
        ir: Optional IR graph written out by ``export_ir``.
    """

    def __init__(
        self, mesh: "DeviceMesh", ctrl: "SimulatorController", ir: Optional["IRGraph"]
    ):
        self.mesh = mesh
        self._ctrl = ctrl
        self._ir = ir

    def _require_simulator(self):
        """Return the controller's simulator, failing clearly when it is absent."""
        sim = self._ctrl.simulator
        if sim is None:
            raise RuntimeError(
                "No simulator is attached to the controller; "
                "this operation requires simulation to be enabled."
            )
        return sim

    def upload(self):
        """
        Run one controller step with ``upload_trace`` temporarily forced on.

        The simulator's previous ``upload_trace`` value is restored even when
        the step raises.

        Raises:
            RuntimeError: if the controller has no simulator.
        """
        sim = self._require_simulator()
        old, sim.upload_trace = sim.upload_trace, True
        try:
            self._ctrl.step()
        finally:
            sim.upload_trace = old

    def _display_html(self, html_code):
        """Open ``html_code`` in a new browser window from a notebook."""
        import base64

        from IPython.display import display, Javascript

        # Encode the HTML code in base64 to be passed to JavaScript, then
        # decode from base64 inside JavaScript. This is a hack to get this to
        # work properly in Bento.
        b64_html = base64.b64encode(html_code.encode("utf-8")).decode("utf-8")

        # JavaScript to open a new window and write the HTML
        js_code = f"""
        var newWindow = window.open("", "_blank");
        newWindow.document.write(atob("{b64_html}"));
        newWindow.document.close();
        window.open("").close()
        """

        # Display the JavaScript
        display(Javascript(js_code))

    def _run_trace2html(self, json_filename, html_filename):
        """
        Convert the JSON trace to HTML with the ``trace2html`` script.

        ``trace2html`` is looked up on PATH first and then in a checkout under
        the home directory.

        Raises:
            RuntimeError: if no ``trace2html`` executable could be found.
            subprocess.CalledProcessError: if the script exits with an error
                (``check=True``).
        """
        # Call the trace2html script to convert JSON to HTML
        for trace2html in [
            "trace2html",
            Path.home() / "fbsource/third-party/catapult/tracing/bin/trace2html",
        ]:
            try:
                subprocess.run(
                    [trace2html, json_filename, "--output", html_filename], check=True
                )
                return
            except FileNotFoundError:
                # This candidate does not exist; try the next location.
                pass
        raise RuntimeError(
            "trace2html not found. `git clone https://chromium.googlesource.com/catapult` and add catapult/tracing/bin to PATH"
        )

    def _display_trace(self, json_filename, pkl_filename):
        """
        Show the memory plot and the HTML trace built from the given files.

        The intermediate HTML file is a temporary file that is removed again
        before returning, whether or not rendering succeeded.
        """
        # Reserve a temporary path for the HTML output; delete=False because
        # trace2html writes to the path by name after the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as html_file:
            html_filename = html_file.name

        try:
            self._run_trace2html(json_filename, html_filename)

            with open(pkl_filename, "rb") as pfile:
                # The pickle is produced by the simulator report itself.
                # @lint-ignore PYTHONPICKLEISBAD
                memory_data = pickle.load(pfile)
            import torch.cuda._memory_viz as viz

            self._display_html(viz.trace_plot(memory_data))

            # Read the HTML content from the temporary HTML file
            with open(html_filename, "r", encoding="utf-8") as file:
                html_code = file.read()
        finally:
            # The content is in memory (or rendering failed); either way the
            # temporary file is no longer needed.
            Path(html_filename).unlink(missing_ok=True)
        self._display_html(html_code)

    def display(self):
        """
        From a jupyter notebook, open the trace report as a new window in your browser.
        Watch for popup blockers.

        Raises:
            RuntimeError: if the controller has no simulator.
        """
        sim = self._require_simulator()
        # Only the paths are needed: the report is written to them by name.
        with tempfile.NamedTemporaryFile(
            suffix=".json", delete=False
        ) as json_file, tempfile.NamedTemporaryFile(
            suffix=".pkl", delete=False
        ) as memory_pkl:
            json_path, pkl_path = json_file.name, memory_pkl.name

        try:
            sim._report(trace_path=json_path, memory_view_path=pkl_path)
            self._display_trace(json_path, pkl_path)
        finally:
            # Do not leave the report files behind in the temp directory.
            for leftover in (json_path, pkl_path):
                Path(leftover).unlink(missing_ok=True)

    def export_ir(self, ir_path: str) -> None:
        """
        Exports the simulator internal representation (IR) to a file.

        Args:
            ir_path (str): The path to the file where the IR will be exported.

        Raises:
            AssertionError: if no IR was supplied to this interface.
        """
        # Raised explicitly so the check also holds when asserts are stripped.
        if self._ir is None:
            raise AssertionError("Simulator IR does not exist!")
        with open(ir_path, "wb") as f:
            torch.save(self._ir, f)