torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,1052 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import copy
|
9
|
+
import cProfile
|
10
|
+
import enum
|
11
|
+
import heapq
|
12
|
+
import io
|
13
|
+
import itertools
|
14
|
+
import json
|
15
|
+
import logging
|
16
|
+
import os
|
17
|
+
import pickle
|
18
|
+
import pstats
|
19
|
+
import subprocess
|
20
|
+
import tempfile
|
21
|
+
import time
|
22
|
+
import traceback
|
23
|
+
import warnings
|
24
|
+
from collections import defaultdict
|
25
|
+
from enum import auto
|
26
|
+
from functools import cache
|
27
|
+
from pathlib import Path
|
28
|
+
from typing import (
|
29
|
+
Any,
|
30
|
+
cast,
|
31
|
+
Generator,
|
32
|
+
Iterable,
|
33
|
+
List,
|
34
|
+
NamedTuple,
|
35
|
+
Optional,
|
36
|
+
Tuple,
|
37
|
+
Union,
|
38
|
+
)
|
39
|
+
|
40
|
+
import numpy as np
|
41
|
+
|
42
|
+
import torch
|
43
|
+
from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
|
44
|
+
ActorId,
|
45
|
+
)
|
46
|
+
from monarch.common import messages
|
47
|
+
from monarch.common.controller_api import LogMessage, MessageResult
|
48
|
+
from monarch.common.device_mesh import DeviceMesh
|
49
|
+
from monarch.common.function import ResolvableFunction, ResolvableFunctionFromPath
|
50
|
+
from monarch.common.invocation import DeviceException
|
51
|
+
from monarch.common.shape import iter_ranks, NDSlice
|
52
|
+
from monarch.simulator.command_history import CommandHistory, DTensorRef
|
53
|
+
from monarch.simulator.config import META_VAL
|
54
|
+
from monarch.simulator.ir import IRGraph
|
55
|
+
from monarch.simulator.mock_controller import MockController
|
56
|
+
from monarch.simulator.profiling import RuntimeEstimator, RuntimeProfiler
|
57
|
+
from monarch.simulator.task import Borrow, EventTask, Task
|
58
|
+
from monarch.simulator.tensor import FakeTensorTracker
|
59
|
+
from monarch.simulator.trace import (
|
60
|
+
dump_memory_trace,
|
61
|
+
dump_process_name,
|
62
|
+
dump_thread_event_trace,
|
63
|
+
MemoryViewer,
|
64
|
+
TraceEvent,
|
65
|
+
upload_trace,
|
66
|
+
)
|
67
|
+
from monarch.simulator.utils import (
|
68
|
+
clean_name,
|
69
|
+
compress_workers_range,
|
70
|
+
file_path_with_iter,
|
71
|
+
)
|
72
|
+
from monarch.simulator.worker import Worker, WorkerGroup
|
73
|
+
from torch.utils._pytree import tree_leaves
|
74
|
+
|
75
|
+
# Module-level logger for this file; the Simulator uses it for debug output
# in `send` and for the `_debug` dump in `_print_worker0`.
logger = logging.getLogger(__name__)
|
76
|
+
|
77
|
+
|
78
|
+
class SimulatorBackendMode(enum.Enum):
    """
    Selects what the simulator backend does with the commands it receives.

    ``simulation_enabled`` / ``command_history_enabled`` tell callers whether
    a mode simulates the commands and/or records them.
    """

    # Simulate the commands, dump the trace, and report the simulated
    # execution time and memory. This is the default mode.
    SIMULATE = auto()
    # Simulate and report the simulated execution time and memory, but do
    # not generate a trace.
    SIMULATE_WITH_REPORT_ONLY = auto()
    # Only record the commands; nothing is simulated.
    COMMAND_HISTORY = auto()
    # SIMULATE + COMMAND_HISTORY.
    EVERYTHING = auto()

    @property
    def simulation_enabled(self) -> bool:
        """Whether this mode runs the simulation."""
        simulating_modes = {
            SimulatorBackendMode.SIMULATE,
            SimulatorBackendMode.SIMULATE_WITH_REPORT_ONLY,
            SimulatorBackendMode.EVERYTHING,
        }
        return self in simulating_modes

    @property
    def command_history_enabled(self) -> bool:
        """Whether this mode records the command history."""
        recording_modes = {
            SimulatorBackendMode.COMMAND_HISTORY,
            SimulatorBackendMode.EVERYTHING,
        }
        return self in recording_modes
|
101
|
+
|
102
|
+
|
103
|
+
class SimulatorTraceMode(enum.Enum):
    """
    Selects which parts of the simulation end up in the trace.

    ``controller_enabled`` / ``stream_enabled`` tell callers whether the
    controller timeline and/or the worker streams are traced.
    """

    # Trace only the controller.
    CONTROLLER_TRACE_ONLY = auto()
    # Trace only the streams of all the workers.
    STREAM_ONLY = auto()
    # Trace both the controller and the streams of all the workers.
    EVERYTHING = auto()

    @property
    def stream_enabled(self) -> bool:
        """Whether worker streams are traced in this mode."""
        stream_modes = {
            SimulatorTraceMode.STREAM_ONLY,
            SimulatorTraceMode.EVERYTHING,
        }
        return self in stream_modes

    @property
    def controller_enabled(self) -> bool:
        """Whether the controller is traced in this mode."""
        controller_modes = {
            SimulatorTraceMode.CONTROLLER_TRACE_ONLY,
            SimulatorTraceMode.EVERYTHING,
        }
        return self in controller_modes
|
122
|
+
|
123
|
+
|
124
|
+
def get_fake_tensor(x):
    """
    Return ``x._fake`` when ``x`` is a ``torch.Tensor`` or ``DTensorRef``;
    any other value is passed through unchanged.
    """
    tensor_like = isinstance(x, (torch.Tensor, DTensorRef))
    return x._fake if tensor_like else x
|
128
|
+
|
129
|
+
|
130
|
+
def get_ids(tree):
    """
    Collect ``{leaf.ref: leaf._fake}`` for every tensor-like leaf of ``tree``.

    ``tree`` may be a single ``torch.Tensor`` / ``DTensorRef`` or any pytree
    accepted by ``tree_leaves``; leaves of other types are ignored. If the
    same ref appears twice, the last occurrence wins.
    """
    tensor_types = (torch.Tensor, DTensorRef)
    wrapped = [tree] if isinstance(tree, tensor_types) else tree
    return {
        leaf.ref: leaf._fake
        for leaf in tree_leaves(wrapped)
        if isinstance(leaf, tensor_types)
    }
|
138
|
+
|
139
|
+
|
140
|
+
class Simulator:
|
141
|
+
"""
|
142
|
+
A class to simulate the execution of the commands from the controller.
|
143
|
+
It can be used to simulate on the fly with SimulatorBackend() or replay an
|
144
|
+
existing trace with Simulator.replay().
|
145
|
+
"""
|
146
|
+
|
147
|
+
def __init__(
    self,
    *,
    world_size: int = 0,
    profile: bool = False,
    replay_file: Optional[str] = None,
    trace_mode: SimulatorTraceMode = SimulatorTraceMode.EVERYTHING,
    upload_trace: bool = False,
    trace_path: str = "trace.json",
    group_workers: bool = False,
):
    """
    Set up the simulator state.

    Args:
        world_size: number of simulated workers. Overridden by the world size
            stored in ``replay_file`` when a replay file is given.
        profile: if True, create a ``cProfile.Profile`` that ``send`` enables
            around the simulation work.
        replay_file: optional command history to load (used by ``replay``).
        trace_mode: which parts (controller / streams) are traced.
        upload_trace: whether ``_report`` uploads the trace it writes.
        trace_path: trace destination; stored as an absolute path.
        group_workers: track all ranks as one ``WorkerGroup`` that
            ``iter_workers`` splits on demand, instead of one ``Worker`` per
            rank.

    Raises:
        ValueError: if the effective ``world_size`` is not positive.
    """
    history: Optional[CommandHistory] = None
    if replay_file:
        history = CommandHistory.load(replay_file)
        world_size = history.world_size
    self.command_history = history

    if world_size <= 0:
        raise ValueError(
            f"{world_size=} is not correct. Please specify a valid "
            "world_size or ensure the replay file contains the world_size."
        )

    self.runtime = RuntimeEstimator()
    # NOTE(review): sized by the local CUDA device count, not world_size.
    self.runtime_profiler = RuntimeProfiler(world_size=torch.cuda.device_count())
    self.events: List[TraceEvent] = []
    self.command_id = 0
    self.fake_tensor_tracker = FakeTensorTracker()

    tracker, estimator = self.fake_tensor_tracker, self.runtime
    if group_workers:
        # One group that initially owns every rank.
        self._worker_groups: List[WorkerGroup] = [
            WorkerGroup(np.arange(world_size), tracker, estimator)
        ]
        self._workers: List[Worker] = []
        # rank -> index of the owning group in self._worker_groups.
        self._worker_group_mapping = np.zeros(world_size, dtype=np.int32)
    else:
        self._worker_groups = []
        self._workers = [Worker(tracker, estimator) for _ in range(world_size)]
        self._worker_group_mapping = np.zeros(1, dtype=np.int32)

    self.worker_commands = defaultdict(list)
    self.now = 0
    self.profiler = cProfile.Profile() if profile else None
    self.simulation_time = 0.0
    self.trace_mode = trace_mode
    self.upload_trace = upload_trace
    self._debug = False
    self.trace_path = os.path.abspath(trace_path)
    self.current_traceback = []
|
200
|
+
|
201
|
+
@property
def workers(self) -> List[Worker]:
    """
    The simulated execution units: the worker groups when grouping is in
    use (non-empty ``_worker_groups``), otherwise the per-rank workers.
    """
    groups = self._worker_groups
    # The cast only helps the type checker accept WorkerGroup as Worker.
    return cast(List[Worker], groups) if groups else self._workers
|
208
|
+
|
209
|
+
def _print_worker0(self) -> None:
|
210
|
+
if not self._debug:
|
211
|
+
return
|
212
|
+
|
213
|
+
for idx, stream in self.workers[0].streams.items():
|
214
|
+
if stream.task_queue:
|
215
|
+
logger.info(
|
216
|
+
(
|
217
|
+
self.now,
|
218
|
+
idx,
|
219
|
+
stream.task_queue[0],
|
220
|
+
stream.task_queue[0].state,
|
221
|
+
stream.task_queue[0].dependencies,
|
222
|
+
stream.tensors,
|
223
|
+
)
|
224
|
+
)
|
225
|
+
|
226
|
+
def _run(self) -> None:
|
227
|
+
"""
|
228
|
+
This method simulates the execution of tasks on workers. It iteratively checks
|
229
|
+
the status of workers and executes tasks in three stages: maybe_set_ready,
|
230
|
+
maybe_execute, and maybe_finish. These stages are performed in separate loops
|
231
|
+
to simulate asynchronous execution. The method continues until no status change
|
232
|
+
occurs.
|
233
|
+
"""
|
234
|
+
|
235
|
+
task_changed_status = True
|
236
|
+
while task_changed_status:
|
237
|
+
self._print_worker0()
|
238
|
+
task_changed_status = False
|
239
|
+
for worker in self.workers:
|
240
|
+
task_changed_status = worker.maybe_set_ready() or task_changed_status
|
241
|
+
for worker in self.workers:
|
242
|
+
task_changed_status = worker.maybe_execute() or task_changed_status
|
243
|
+
for worker in self.workers:
|
244
|
+
task_changed_status = worker.maybe_finish() or task_changed_status
|
245
|
+
|
246
|
+
def _print_profiler(self):
|
247
|
+
if self.profiler is None:
|
248
|
+
return
|
249
|
+
s = io.StringIO()
|
250
|
+
ps = pstats.Stats(self.profiler, stream=s).sort_stats(pstats.SortKey.CUMULATIVE)
|
251
|
+
ps.print_stats()
|
252
|
+
print(s.getvalue())
|
253
|
+
print(
|
254
|
+
f"Simulation run time, excluding loading the file: {self.simulation_time}."
|
255
|
+
)
|
256
|
+
|
257
|
+
def _rank_to_worker(self, ranks: List[NDSlice]) -> Generator[Worker, None, None]:
    """Lazily yield ``self._workers[r]`` for every rank ``r`` of every slice."""
    yield from (self._workers[rank] for ndslice in ranks for rank in ndslice)
|
261
|
+
|
262
|
+
def _ndslice_to_worker_group(
    self, ranks: List[NDSlice]
) -> Generator[WorkerGroup, None, None]:
    """
    Yield worker groups that together cover exactly the ranks in ``ranks``.

    Invariant: ``self._worker_group_mapping[r]`` is the index in
    ``self._worker_groups`` of the group that owns rank ``r``.

    If every touched group participates in full, those groups are yielded
    as-is. Otherwise, each partially covered group is split: its
    non-participating ranks move into a new group appended at the END of
    ``self._worker_groups``. Existing groups therefore keep their indices,
    and only the moved ranks need their mapping updated. Fully covered groups
    are never split, since that would create empty groups.

    This is a generator: all bookkeeping, including splits, happens once
    iteration starts.
    """
    # TODO: While we already use numpy array, this can still be quite slow
    # because iterating ranks happens in Python. We should cache the results
    # since we don't have that many different ranks combinations.
    # An explicit integer dtype keeps empty slices from producing float
    # arrays that cannot be used as indices.
    rank_arrays = [np.asarray(list(ndslice), dtype=np.int64) for ndslice in ranks]
    if not rank_arrays:
        return
    workers = np.sort(np.concatenate(rank_arrays))

    # Participating-rank count per group. minlength keeps an entry for every
    # existing group, so groups beyond the largest participating id are
    # never dropped when the group list is rebuilt below.
    counts = np.bincount(
        self._worker_group_mapping[workers], minlength=len(self._worker_groups)
    )
    participating_ids = np.nonzero(counts)[0].tolist()

    if all(
        len(self._worker_groups[gid].workers) == counts[gid]
        for gid in participating_ids
    ):
        for gid in participating_ids:
            yield self._worker_groups[gid]
        return

    new_groups = list(self._worker_groups)
    participating = []
    for gid in participating_ids:
        group = self._worker_groups[gid]
        participating.append(group)
        idle = np.setdiff1d(group.workers, workers, assume_unique=True)
        if idle.size == 0:
            # The whole group participates; nothing to split off.
            continue
        new_groups.append(group.split(idle))
        self._worker_group_mapping[idle] = len(new_groups) - 1
    self._worker_groups = new_groups
    yield from participating
|
302
|
+
|
303
|
+
def iter_workers(self, ranks: List[NDSlice]) -> Generator[Worker, None, None]:
    """
    Yield the execution units addressed by ``ranks``: worker groups
    (split as needed) when grouping is in use, otherwise one ``Worker``
    per rank.
    """
    resolve = (
        self._ndslice_to_worker_group
        if self._worker_groups
        else self._rank_to_worker
    )
    yield from resolve(ranks)
|
308
|
+
|
309
|
+
def _report(self, trace_path: str = "", memory_view_path: str = ""):
    """
    Build the trace for the controller and every worker, replay each worker's
    memory events to track its running total, and return summary numbers.

    Args:
        trace_path: if non-empty, the trace is written there as JSON
            (``{"traceEvents": [...]}``). It is uploaded afterwards when
            ``self.upload_trace`` is set. With an empty path nothing is
            written or uploaded.
        memory_view_path: forwarded to ``MemoryViewer.dump``.

    Returns:
        ``(exec_time / 10**6, max_mem / 10**6)``. ``exec_time`` is the largest
        value returned by ``dump_thread_event_trace``. ``max_mem`` is the peak
        running memory total over all workers. (``step`` documents these as
        seconds and MB.)
    """
    trace = []

    exec_time = 0.0
    max_mem = 0.0

    # Perfetto treats tid and pid as part of the same namespace (unlike
    # chrome://tracing). If they collide, names get clobbered, so we assign
    # unique ids to each individual concept.
    id_iter = itertools.count(1)

    @cache
    def to_id(key):
        return next(id_iter)

    dump_process_name(trace, pid=0, name="Controller")
    exec_time = max(
        exec_time,
        dump_thread_event_trace(
            trace, self.events, pid=0, tid=0, name="Controller"
        ),
    )

    if isinstance(self.workers[0], WorkerGroup):
        # Present groups ordered by their lowest rank.
        workers = sorted(self.workers, key=lambda g: min(g.workers))
    else:
        workers = self.workers

    memory_viewer = MemoryViewer()
    for worker_id, worker in enumerate(workers):
        if not worker.events:
            continue
        pid = to_id(("worker", worker_id))
        name = f"Device {worker_id}"
        if isinstance(worker, WorkerGroup):
            name = f"{name} {compress_workers_range(worker.workers)}"
        dump_process_name(trace=trace, pid=pid, name=name)
        # TODO: find a better tid for worker trace
        # NOTE(review): this dumps the controller's ``self.events`` under every
        # device although the guard above checks ``worker.events``; confirm
        # which event list is intended here.
        exec_time = max(
            dump_thread_event_trace(
                trace, self.events, pid=pid, tid=32000, name=name
            ),
            exec_time,
        )

        for stream_id, stream in worker.streams.items():
            tid = to_id(("stream", worker_id, stream_id))
            exec_time = max(
                dump_thread_event_trace(
                    trace, stream.events, pid=pid, tid=tid, name=stream.name
                ),
                exec_time,
            )

        # Merge the per-stream memory event heaps by timestamp to track this
        # worker's running memory total.
        curr_mem = 0
        memory_viewer.next_device()
        # Shallow copies, so popping below leaves the streams' lists intact.
        mem_events = {
            stream_id: copy.copy(stream.memory.events)
            for stream_id, stream in worker.streams.items()
        }
        while True:
            min_ts = float("inf")
            min_stream_events = None
            min_stream_id = 0
            for stream_id, events in mem_events.items():
                if events and min_ts > events[0][0]:
                    min_ts = events[0][0]
                    min_stream_id, min_stream_events = stream_id, events

            if min_stream_events is None:
                break

            mem_ts, mem_addr, mem_delta, mem_traceback = heapq.heappop(
                min_stream_events
            )
            curr_mem += mem_delta
            max_mem = max(curr_mem, max_mem)
            dump_memory_trace(
                trace,
                pid=pid,
                memory=curr_mem,
                ts=mem_ts,
                name="memory",
            )
            memory_viewer.add_trace(mem_addr, mem_delta, min_stream_id, mem_traceback)

    if trace_path:
        with open(trace_path, "w") as f:
            json.dump({"traceEvents": trace}, f, indent=4)

    memory_viewer.dump(memory_view_path)

    # Only a trace file written above can be uploaded; with no trace_path
    # there is nothing to upload.
    if self.upload_trace and trace_path:
        upload_trace(os.path.abspath(trace_path))

    return exec_time / 10**6, max_mem / 10**6
|
407
|
+
|
408
|
+
def step(self, iter_count: int, dump_trace: bool = False) -> Tuple[float, float]:
    """
    Step to the next iteration of the simulation and report it.

    Returns the execution time in seconds and the peak memory usage in MB of
    this iteration, as computed by ``_report``.

    When ``dump_trace`` is True, the trace goes to
    ``file_path_with_iter(self.trace_path, iter_count)``. A memory view based
    on ``memory_view.pt`` (also passed through ``file_path_with_iter``) is
    always requested. It is placed in the trace's directory, or has no
    directory component when no trace is dumped.
    """
    trace_file = ""
    if dump_trace:
        trace_file = file_path_with_iter(self.trace_path, iter_count)
    memory_view_base = os.path.join(os.path.dirname(trace_file), "memory_view.pt")
    memory_view_file = file_path_with_iter(memory_view_base, iter_count)
    return self._report(trace_file, memory_view_file)
|
418
|
+
|
419
|
+
def exit(self, iter_count: int, dump_trace: bool = False) -> Tuple[float, float]:
    """Report the final iteration; equivalent to ``self.step(iter_count, dump_trace)``."""
    final_report = self.step(iter_count, dump_trace)
    return final_report
|
421
|
+
|
422
|
+
@classmethod
def replay(cls, replay_file: str, profile: bool = False) -> None:
    """
    Re-simulate the ``"send"`` commands recorded in ``replay_file``.

    Afterwards, report the results (``_report`` is called without a trace
    path, so no trace file is written) and print profiler output, which
    happens only when ``profile`` is True.
    """
    sim = cls(replay_file=replay_file, profile=profile)
    history = cast(CommandHistory, sim.command_history)
    sends = (cmd for cmd in history.commands if cmd.backend_command == "send")
    for cmd in sends:
        assert cmd.ranks is not None
        sim.send(cmd.timestamp, cmd.ranks, cmd.msg)
    sim._report()
    sim._print_profiler()
|
432
|
+
|
433
|
+
# Methods below simulate the methods of a real backend.
|
434
|
+
def send(self, now: int, ranks: List[NDSlice], msg) -> None:
    """
    Simulate the controller sending ``msg`` to the workers in ``ranks`` at
    controller time ``now``.

    When controller tracing is enabled, a controller ``TraceEvent`` is
    recorded and ``self.now`` is advanced to ``now``. The message is then
    dispatched to the method named after its type (e.g. ``CallFunction``)
    and the simulation is advanced with ``_run``. Message types without a
    handler only produce a warning.
    """
    # Lazy %-formatting: msg is only stringified when DEBUG logging is on.
    logger.debug("Sending %s at %s.", msg, now)
    self.current_traceback = traceback.extract_stack()[:-3]
    command_name = type(msg).__name__
    self.command_id += 1
    # These two commands typically take a long time to execute on the
    # controller side. Ignoring them will make the simulation trace easier
    # to read.
    if self.trace_mode.controller_enabled and command_name not in (
        "CreateDeviceMesh",
        "CreateStream",
    ):
        if command_name != "CallFunction":
            meta = [command_name] + META_VAL
        else:
            meta = [clean_name(msg.function.path)] + META_VAL
        self.events.append(
            TraceEvent(
                self.now,
                now - self.now,
                meta,
                self.command_id,
                self.current_traceback,
            )
        )

    if self.trace_mode.controller_enabled:
        self.now = now

    if not self.trace_mode.stream_enabled and command_name != "CommandGroup":
        return

    begin = time.monotonic()
    if self.profiler:
        self.profiler.enable()
    try:
        handler = getattr(self, command_name, None)
        if handler is None:
            # Instead of silently ignoring the unimplemented method, a warning
            # gives us the signal to review any newly implemented messages.
            warnings.warn(
                f"Simulator doesn't implement {type(msg).__name__} {msg}. "
                "This can cause incorrect simulation.",
                stacklevel=2,
            )
            return
        handler(ranks, msg)
        self._run()
    finally:
        # Balance enable() on every exit path (warning, handler error), so
        # the profiler never stays enabled outside simulation work.
        # NOTE(review): nested sends (CommandGroup) still count the inner
        # commands' time twice in simulation_time.
        if self.profiler:
            self.profiler.disable()
        self.simulation_time += time.monotonic() - begin
|
487
|
+
|
488
|
+
def recvready(self):
    """Receiving messages is not supported by the simulator backend."""
    raise NotImplementedError()
|
490
|
+
|
491
|
+
def propagate(self, msg: messages.SendValue) -> Any:
    """
    Evaluate the function carried by a ``SendValue`` message.

    The message is repackaged as a ``CallFunction`` with ``ident=0``, no
    result, no mutations, no stream/mesh and no remote process groups. It is
    handed to ``self.runtime_profiler.profile_cmd`` with ``[0]`` as the second
    argument, and ``[0][0]`` of that call's return value is returned.
    """
    function = msg.function
    assert isinstance(function, ResolvableFunction)
    as_call = messages.CallFunction(
        ident=0,
        result=None,
        mutates=(),
        function=function,
        args=msg.args,
        kwargs=msg.kwargs,
        stream=None,  # pyre-ignore[6]
        device_mesh=None,  # pyre-ignore[6]
        remote_process_groups=[],
    )
    profiled = self.runtime_profiler.profile_cmd(as_call, [0])
    first_entry = profiled[0]
    return first_entry[0]
|
506
|
+
|
507
|
+
def Exit(self, ranks: List[NDSlice], msg: messages.Exit):
    """``Exit`` requires no simulation work; it is accepted and ignored."""
    return None
|
509
|
+
|
510
|
+
def CallFunction(self, ranks: List[NDSlice], msg: messages.CallFunction):
    """
    Simulate a remote function call.

    Registers the fake tensors of the positional arguments, the results and
    the mutated values, then queues one ``Task`` on ``msg.stream`` for every
    addressed worker (or worker group).
    """
    # NOTE(review): tensors passed through msg.kwargs are not registered as
    # inputs here; confirm that this is intended.
    inputs = get_ids(msg.args)
    outputs = get_ids(msg.result)
    if msg.mutates:
        # Mutated tensors are written by the call, so they count as outputs.
        outputs.update(get_ids(msg.mutates))
    self.fake_tensor_tracker.add(inputs)
    self.fake_tensor_tracker.add(outputs)
    stream = msg.stream.ref
    # The task label is the same for every worker, so compute it once.
    name = clean_name(str(msg.function))
    for worker in self.iter_workers(ranks):
        worker.add_task(
            Task(
                inputs=list(inputs.keys()),
                outputs=list(outputs.keys()),
                command_id=self.command_id,
                start_time=self.now,
                runtime=self.runtime.get_runtime(msg),
                meta=[name],
                traceback=self.current_traceback,
            ),
            self.now,
            stream=stream,
        )
|
533
|
+
|
534
|
+
def SendTensor(self, ranks: List[NDSlice], msg: messages.SendTensor):
    """
    Simulate moving ``msg.tensor`` from ``msg.from_ranks`` to ``msg.to_ranks``.

    If source and destination ranks are equal, each worker gets one local
    ``SendTensor`` task. Otherwise each sender gets a ``SendTensor`` task
    (inputs only) and each receiver a ``RecvTensor`` task (outputs only). The
    i-th sender and the i-th receiver share one ``collectives`` list.

    Raises:
        NotImplementedError: if the source and destination streams differ.
    """
    # NOTE: The memory usage calculation for SendTensor may not be accurate when
    # the source and destination ranks are the same. In such cases, memory usage
    # should increase if the result tensor is modified. However, this depends on
    # the specific implementation by the worker.

    inputs = get_ids(msg.tensor)
    outputs = get_ids(msg.result)
    self.fake_tensor_tracker.add(inputs)
    self.fake_tensor_tracker.add(outputs)
    if msg.from_stream is not msg.to_stream:
        raise NotImplementedError(
            "simulator using to_mesh between different streams"
        )
    stream = msg.from_stream.ref

    def make_task(task_inputs, task_outputs, label, **extra):
        # Fresh lists for every task; ``extra`` carries ``collectives`` for
        # the cross-rank case.
        return Task(
            inputs=list(task_inputs),
            outputs=list(task_outputs),
            command_id=self.command_id,
            start_time=self.now,
            runtime=self.runtime.get_runtime(msg),
            meta=[label],
            traceback=self.current_traceback,
            **extra,
        )

    if msg.from_ranks == msg.to_ranks:
        for worker in self.iter_workers([msg.from_ranks]):
            worker.add_task(
                make_task(inputs, outputs, "SendTensor"), self.now, stream=stream
            )
        return

    shared_lists = []
    for sender in self.iter_workers([msg.from_ranks]):
        shared = []
        shared_lists.append(shared)
        sender.add_task(
            make_task(inputs, (), "SendTensor", collectives=shared),
            self.now,
            stream=stream,
        )

    for receiver, shared in zip(
        self.iter_workers([msg.to_ranks]), shared_lists, strict=True
    ):
        receiver.add_task(
            make_task((), outputs, "RecvTensor", collectives=shared),
            self.now,
            stream=stream,
        )
|
601
|
+
|
602
|
+
def CommandGroup(self, ranks: List[NDSlice], msg: messages.CommandGroup):
    """Send each command of the group to ``ranks`` at the current time."""
    for sub_command in msg.commands:
        self.send(self.now, ranks, sub_command)
|
605
|
+
|
606
|
+
def CreateStream(self, ranks: List[NDSlice], msg: messages.CreateStream):
    """Create the stream described by ``msg.result`` on every addressed worker."""
    result = msg.result
    for target in self.iter_workers(ranks):
        assert result.ref is not None
        target.create_stream(result.ref, result.name, default=msg.default)
|
610
|
+
|
611
|
+
def Reduce(self, ranks: List[NDSlice], msg: messages.Reduce):
    """
    Simulate a collective over ``ranks``.

    Every addressed worker gets a ``Task`` labelled with the collective's
    name. All of these tasks share one ``collectives`` list.

    Label chosen from ``msg.reduction`` and ``msg.scatter``:
        reduction == "stack", scatter      -> "all_to_all"
        reduction == "stack", no scatter   -> "all_gather"
        other reduction,      scatter      -> "reduce_scatter"
        other reduction,      no scatter   -> "all_reduce"
    """
    inputs = get_ids(msg.local_tensor)
    outputs = get_ids(msg.result)
    self.fake_tensor_tracker.add(inputs)
    self.fake_tensor_tracker.add(outputs)

    # TODO: controller doesn't implement reduce and scatter yet so it is
    # not possible to get such a request.
    if msg.reduction == "stack":
        meta_str = "all_to_all" if msg.scatter else "all_gather"
    else:
        # Reducing and scattering the result is a reduce-scatter; reducing
        # without scattering leaves the full result on every rank (all-reduce).
        meta_str = "reduce_scatter" if msg.scatter else "all_reduce"

    meta = [meta_str]
    stream = msg.stream.ref
    # One list shared by every participant's task.
    collectives = []
    for worker in self.iter_workers(ranks):
        worker.add_task(
            Task(
                inputs=list(inputs.keys()),
                outputs=list(outputs.keys()),
                start_time=self.now,
                runtime=self.runtime.get_runtime(msg),
                meta=meta,
                command_id=self.command_id,
                collectives=collectives,
                traceback=self.current_traceback,
            ),
            self.now,
            stream=stream,
        )
|
648
|
+
|
649
|
+
def BorrowCreate(self, ranks: List[NDSlice], msg: messages.BorrowCreate):
    """
    Simulate borrowing ``msg.tensor`` from ``msg.from_stream`` onto
    ``msg.to_stream``, with ``msg.result`` as the borrowed tensor.

    On every addressed worker, ``record_event()`` is called on the source
    stream. An ``EventTask`` that waits on the destination stream for that
    recorded task is then passed to ``worker.borrow`` together with the
    ``Borrow`` record.
    """
    inputs = get_ids(msg.tensor)
    outputs = get_ids(msg.result)
    tracker = self.fake_tensor_tracker
    tracker.add(inputs)
    tracker.add(outputs, is_borrowed=True)
    src_stream = msg.from_stream.ref
    dst_stream = msg.to_stream.ref
    assert src_stream is not None
    assert dst_stream is not None
    borrow = Borrow(
        ident=msg.borrow,
        tensor_src_id=cast(int, cast(DTensorRef, msg.tensor).ref),
        tensor_dst_id=cast(int, cast(DTensorRef, msg.result).ref),
        from_stream=src_stream,
        to_stream=dst_stream,
    )
    for worker in self.iter_workers(ranks):
        streams = worker.streams
        recorded = streams[src_stream].record_event()
        # Note: there is no perfect way to set the start_time when the
        # controller timing is disabled -- the wait event's start time
        # may be very early like 0. This is because only the GPU events
        # are tracked and there are no other GPU events except for
        # communications and wait events on the communication stream.
        # However, if we let the event's start_time to be based on the
        # main stream's timing, we may lose other information.
        wait_event = EventTask(
            recorded_task=recorded,
            event_stream=src_stream,
            event_stream_name=streams[src_stream].name,
            wait_stream=dst_stream,
            wait_stream_name=streams[dst_stream].name,
            command_id=self.command_id,
            start_time=self.now,
            borrow=borrow,
            runtime=self.runtime.get_runtime("wait_event"),
            traceback=self.current_traceback,
        )
        worker.borrow(wait_event, borrow)
|
688
|
+
|
689
|
+
def BorrowFirstUse(self, ranks: List[NDSlice], msg: messages.BorrowFirstUse):
    """Report the first use of borrow ``msg.borrow`` to every worker in ``ranks``."""
    borrow_id = msg.borrow
    for target in self.iter_workers(ranks):
        target.borrow_first_use(borrow_id, self.now)
|
692
|
+
|
693
|
+
def BorrowLastUse(self, ranks: List[NDSlice], msg: messages.BorrowLastUse):
    """Simulate a ``BorrowLastUse`` message.

    For each worker, looks up the wait event stored for ``msg.borrow`` and
    builds a mirror-image ``EventTask``: the event is recorded on that wait
    event's ``wait_stream``, and the ``wait_stream`` of the new task is the
    original ``event_stream``.
    """
    for worker in self.iter_workers(ranks):
        streams = worker.streams
        creation_event = worker.wait_events[msg.borrow]
        borrowed_on = creation_event.wait_stream
        returned_to = creation_event.event_stream
        recorded_task = streams[borrowed_on].record_event()
        worker.borrow_last_use(
            EventTask(
                recorded_task=recorded_task,
                event_stream=borrowed_on,
                event_stream_name=streams[borrowed_on].name,
                wait_stream=returned_to,
                wait_stream_name=streams[returned_to].name,
                command_id=self.command_id,
                start_time=self.now,
                runtime=self.runtime.get_runtime("wait_event"),
                traceback=self.current_traceback,
            ),
            msg.borrow,
        )
|
709
|
+
|
710
|
+
def BorrowDrop(self, ranks: List[NDSlice], msg: messages.BorrowDrop):
    """Tell every worker in ``ranks`` that borrow ``msg.borrow`` was dropped."""
    borrow_id = msg.borrow
    for target in self.iter_workers(ranks):
        target.borrow_drop(borrow_id, self.now)
|
713
|
+
|
714
|
+
def DeleteRefs(self, ranks: List[NDSlice], msg: messages.DeleteRefs):
    """Forward the deleted refs in ``msg.refs`` to every worker in ``ranks``."""
    doomed = msg.refs
    for target in self.iter_workers(ranks):
        target.delete_refs(doomed, self.now)
|
717
|
+
|
718
|
+
def BackendNetworkInit(
    self, ranks: List[NDSlice], msg: messages.BackendNetworkInit
):
    """Ignored: nothing is tracked for backend network initialization."""
    return None
|
722
|
+
|
723
|
+
def CreatePipe(self, ranks: List[NDSlice], msg: messages.CreatePipe):
    """Ignored: pipe creation does not need to be tracked (yet)."""
    return None
|
726
|
+
|
727
|
+
def PipeRecv(self, ranks: List[NDSlice], msg: messages.PipeRecv):
    """Simulate receiving tensors from a pipe.

    The received tensors are registered with the fake tensor tracker. Only
    CPU tensors are supported, and each received tensor is added as a CPU
    tensor on every worker in ``ranks``.

    Raises:
        NotImplementedError: if any received tensor is not on the CPU.
    """
    received = get_ids(msg.result)
    cpu = torch.device("cpu")
    self.fake_tensor_tracker.add(received)
    if any(fake.device != cpu for fake in received.values()):
        raise NotImplementedError("PipeRecv only support CPU device now.")

    for worker in self.iter_workers(ranks):
        for tensor_id in received.keys():
            worker.add_cpu_tensor(tensor_id, self.now)
|
738
|
+
|
739
|
+
# The handlers below intentionally do nothing for their messages (yet).
def SendValue(self, ranks: List[NDSlice], msg: messages.SendValue):
    """Ignored: ``SendValue`` is not simulated by this handler."""
    return None
|
742
|
+
|
743
|
+
def CreateDeviceMesh(self, ranks: List[NDSlice], msg: messages.CreateDeviceMesh):
    """Ignored: device-mesh creation is not simulated."""
    return None
|
745
|
+
|
746
|
+
def RequestStatus(self, ranks: List[NDSlice], msg: messages.RequestStatus):
    """Ignored: status requests are not simulated."""
    return None
|
748
|
+
|
749
|
+
def SplitComm(self, ranks: List[NDSlice], msg: messages.SplitComm):
    """Ignored: communicator splits are not simulated."""
    return None
|
751
|
+
|
752
|
+
def BackendNetworkPointToPointInit(
    self, ranks: List[NDSlice], msg: messages.BackendNetworkPointToPointInit
):
    """Ignored: point-to-point network initialization is not simulated."""
    return None
|
756
|
+
|
757
|
+
|
758
|
+
class SimulatorController(MockController):
    """
    A backend that simulates the behavior of the ProcessBackend. It can also be
    used to only record the commands sent to it, and then replay them later using
    the `Simulator` class.

    Args:
        world_size (int): The number of workers in the simulation.
        gpu_per_host (int): The number of GPUs per machine.
        simulate_mode (SimulatorBackendMode): Controls whether a command history
            is recorded (``command_history_enabled``) and whether a ``Simulator``
            is created (``simulation_enabled``).
        trace_mode (SimulatorTraceMode): Forwarded to ``Simulator``.
        upload_trace (bool): Forwarded to ``Simulator``.
        trace_path (str): Forwarded to ``Simulator``.
        command_history_path (str): File the recorded command history is dumped
            to on ``Exit``; converted to an absolute path.
        group_workers (bool): Forwarded to ``Simulator``.
        ir (Optional[IRGraph]): Passed along with every recorded or converted
            command.
    """

    def __init__(
        self,
        world_size: int,
        gpu_per_host: int,
        *,
        simulate_mode: SimulatorBackendMode = SimulatorBackendMode.SIMULATE,
        trace_mode: SimulatorTraceMode = SimulatorTraceMode.EVERYTHING,
        upload_trace: bool = False,
        trace_path: str = "trace.json",
        command_history_path: str = "command_history.pkl",
        group_workers: bool = False,
        ir: Optional[IRGraph] = None,
    ):
        # DTensorRef.created lives on the class, so refs left behind by an
        # earlier simulator in this process are cleared here.
        if DTensorRef.created:
            DTensorRef.created.clear()
            warnings.warn(
                "clearing old DTensorRef information. TODO: support multiple simulator backends in the same process.",
                stacklevel=1,
            )
        super().__init__(world_size, verbose=False)

        self._gpu_per_host = gpu_per_host
        # Timestamps handed to the history/simulator are nanoseconds relative
        # to controller construction.
        self.timestamp_base = time.monotonic_ns()
        self.worker_commands = defaultdict(list)
        self.simulator: Optional[Simulator] = None
        self.command_history: Optional[CommandHistory] = None
        self.iter = 0
        self.mode = simulate_mode
        # Latched by send() on the first internal error; afterwards every
        # message is dropped.
        self.exception = False
        self.ir = ir

        if self.mode.command_history_enabled:
            self.command_history = CommandHistory(
                world_size, file_path=os.path.abspath(command_history_path)
            )

        if self.mode.simulation_enabled:
            self.simulator = Simulator(
                world_size=world_size,
                trace_mode=trace_mode,
                upload_trace=upload_trace,
                trace_path=trace_path,
                group_workers=group_workers,
            )

    @property
    def gpu_per_host(self) -> int:
        """Number of GPUs per host this controller was configured with."""
        return self._gpu_per_host

    def cleanup_simulation(self):
        """Clear the class-level ``DTensorRef.created`` registry."""
        DTensorRef.created.clear()

    def __del__(self):
        self.cleanup_simulation()

    def _sim_dump_trace(self) -> bool:
        """Return whether the simulator should dump its trace.

        Traces are dumped in every mode except ``SIMULATE_WITH_REPORT_ONLY``.
        """
        return self.mode != SimulatorBackendMode.SIMULATE_WITH_REPORT_ONLY

    def _sim_command_context(self):
        """Return ``(command_id, traceback)`` to attach to a recorded command.

        Falls back to ``(0, ())`` when no simulator was created.
        """
        if self.simulator is None:
            return 0, ()
        return self.simulator.command_id, self.simulator.current_traceback

    def step(self) -> Tuple[float, float]:
        """
        Step to the next iteration simulation and return the execution time in second
        and peak memory usage in MB of this iteration. If the simulation mode is
        COMMAND_HISTORY, then the return time and memory will be 0.0 as the backend
        only records the commands.
        """
        if self.command_history is not None:
            self.command_history.step(self.iter)

        if self.simulator is not None:
            exec_time, max_mem = self.simulator.step(
                self.iter, dump_trace=self._sim_dump_trace()
            )
        else:
            exec_time = max_mem = 0.0

        self.iter += 1

        return exec_time, max_mem

    def _send(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple) -> None:
        """Record/convert ``msg``, feed it to the simulator, then dispatch it.

        ``_propagate`` ``SendValue`` messages are answered directly by the
        simulator and are not forwarded to ``MockController.send``.
        ``CommandGroup`` messages are also not forwarded.
        """
        now = time.monotonic_ns() - self.timestamp_base

        if isinstance(ranks, NDSlice):
            ranks = [ranks]

        command_id, current_traceback = self._sim_command_context()
        # With a history, record() stores the command and returns it;
        # otherwise use the CommandHistory converter directly.
        if self.command_history is not None:
            make_command = self.command_history.record
        else:
            make_command = CommandHistory.convert_command
        command = make_command(
            now,
            "send",
            command_id,
            current_traceback,
            ranks,
            msg,
            None,
            self.ir,
        )

        if self.simulator is not None:
            self.simulator.send(now, cast(List[NDSlice], command.ranks), command.msg)

        if type(msg).__name__ == "SendValue":
            msg = cast(messages.SendValue, msg)
            if (
                isinstance(msg.function, ResolvableFunctionFromPath)
                and msg.function.path == "monarch.cached_remote_function._propagate"
            ):
                assert self.simulator is not None
                assert msg.destination is None
                ret = self.simulator.propagate(msg)
                # Complete the future once per rank with the simulator's answer.
                for _ in iter_ranks(ranks):
                    self.history.future_completed(msg.ident, ret)
                return

        if type(msg).__name__ not in ("CommandGroup",):
            return super().send(ranks, msg)

    def send(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple) -> None:
        """Send ``msg`` through the simulator.

        On the first internal error, ``self.exception`` is latched and a
        ``MessageResult`` carrying a ``DeviceException`` is queued. All later
        messages are then ignored.
        """
        if self.exception:
            return

        try:
            self._send(ranks, msg)
        except Exception as e:
            self.exception = True
            # TODO: Should we also call simulator.exit() and cleanup?
            self.responses.append(
                MessageResult(
                    seq=0,  # will not be used
                    result=None,
                    error=DeviceException(
                        e,
                        traceback.extract_tb(e.__traceback__),
                        ActorId.from_string("unknown[0].unknown[0]"),
                        message="Simulator has an internal error.",
                    ),
                )
            )

    def next_message(
        self, timeout: Optional[float]
    ) -> Optional[MessageResult | LogMessage]:
        """Record the ``next_message`` call in the history (if any), then delegate."""
        now = time.monotonic_ns() - self.timestamp_base

        if self.command_history is not None:
            command_id, current_traceback = self._sim_command_context()
            self.command_history.record(
                now,
                "next_message",
                command_id,
                current_traceback,
                None,
                None,
                timeout,
                self.ir,
            )

        return super().next_message(timeout)

    def Exit(self, ranks: Union[NDSlice, List[NDSlice]], msg: messages.Exit):
        """Dump the command history, finish the simulation, and clean up.

        ``cleanup_simulation()`` runs even if dumping the history or
        finishing the simulation raises.
        """
        try:
            if self.command_history is not None:
                self.command_history.dump(self.command_history.file_path)
            if self.simulator is not None:
                self.simulator.exit(self.iter, dump_trace=self._sim_dump_trace())
        finally:
            self.cleanup_simulation()

        return super().Exit(ranks, msg)
|
950
|
+
|
951
|
+
|
952
|
+
class SimulatorInterface:
|
953
|
+
"""
|
954
|
+
API for interacting with the simulator.
|
955
|
+
sim.mesh retrieves the simulator mesh.
|
956
|
+
"""
|
957
|
+
|
958
|
+
def __init__(
    self, mesh: "DeviceMesh", ctrl: "SimulatorController", ir: Optional["IRGraph"]
):
    """Bind the interface to a simulator mesh, its controller, and an optional IR graph."""
    self.mesh, self._ctrl, self._ir = mesh, ctrl, ir
|
964
|
+
|
965
|
+
def upload(self):
    """Advance the controller by one ``step()`` with trace uploading forced on.

    The simulator's previous ``upload_trace`` value is restored afterwards,
    even if ``step()`` raises.
    """
    simulator = self._ctrl.simulator
    previous = simulator.upload_trace
    simulator.upload_trace = True
    try:
        self._ctrl.step()
    finally:
        simulator.upload_trace = previous
|
972
|
+
|
973
|
+
def _display_html(self, html_code: str) -> None:
    """Open ``html_code`` in a new browser window from a Jupyter/IPython frontend.

    The HTML is base64-encoded and written into a window opened by a small
    JavaScript snippet, which is rendered with ``IPython.display``.
    """
    import base64

    from IPython.display import display, Javascript

    # Encode the HTML code in base64 to be passed to JavaScript, then
    # decode from base64 inside JavaScript. This is a hack to get this to
    # work properly in Bento.
    b64_html = base64.b64encode(html_code.encode("utf-8")).decode("utf-8")

    # JavaScript to open a new window and write the HTML
    # NOTE(review): window.open() returns null when a popup blocker intervenes,
    # in which case `newWindow.document` throws in the browser -- confirm whether
    # a guard is wanted. The purpose of the trailing `window.open("").close()` is
    # not evident from this code; presumably part of the Bento workaround.
    js_code = f"""
    var newWindow = window.open("", "_blank");
    newWindow.document.write(atob("{b64_html}"));
    newWindow.document.close();
    window.open("").close()
    """

    # Display the JavaScript
    display(Javascript(js_code))
|
993
|
+
|
994
|
+
def _run_trace2html(self, json_filename, html_filename):
    """Convert a trace JSON file to HTML with the catapult ``trace2html`` tool.

    Tries ``trace2html`` on ``PATH`` first, then a checkout under
    ``~/fbsource/third-party/catapult``. A candidate that does not exist is
    skipped; a failing conversion raises ``subprocess.CalledProcessError``.

    Raises:
        RuntimeError: if no ``trace2html`` executable could be found.
    """
    candidates = (
        "trace2html",
        Path.home() / "fbsource/third-party/catapult/tracing/bin/trace2html",
    )
    for executable in candidates:
        command = [executable, json_filename, "--output", html_filename]
        try:
            subprocess.run(command, check=True)
        except FileNotFoundError:
            continue
        return
    raise RuntimeError(
        "trace2html not found. `git clone https://chromium.googlesource.com/catapult` and add catapult/tracing/bin to PATH"
    )
|
1010
|
+
|
1011
|
+
def _display_trace(self, json_filename, pkl_filename):
    """Show the memory-timeline plot and the HTML trace view in new windows.

    Converts ``json_filename`` to HTML with ``trace2html``. It then loads the
    pickled memory snapshot in ``pkl_filename``, renders it with
    ``torch.cuda._memory_viz.trace_plot``, and displays both pages. The
    intermediate HTML file is removed afterwards; the original implementation
    leaked it because it was created with ``delete=False``.
    """
    # Only reserve a unique path here; trace2html writes the file itself.
    with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as html_file:
        html_filename = html_file.name

    try:
        self._run_trace2html(json_filename, html_filename)

        with open(pkl_filename, "rb") as pfile:
            # The pickle is produced locally by the simulator's report.
            # @lint-ignore PYTHONPICKLEISBAD
            memory_data = pickle.load(pfile)
        import torch.cuda._memory_viz as viz

        self._display_html(viz.trace_plot(memory_data))

        # Read the HTML content from the temporary HTML file
        with open(html_filename, "r") as file:
            html_code = file.read()
        self._display_html(html_code)
    finally:
        # Best-effort cleanup: the HTML has already been embedded (or the
        # conversion failed), so the file is no longer needed.
        try:
            os.remove(html_filename)
        except OSError:
            pass
|
1029
|
+
|
1030
|
+
def display(self):
    """
    From a jupyter notebook, open the trace report as a new window in your browser.
    Watch for popup blockers.

    The simulator report is written to temporary ``.json`` (trace) and
    ``.pkl`` (memory view) files, which are removed once they have been
    displayed; previously they were left behind because of ``delete=False``.
    """
    sim = self._ctrl.simulator
    # Reserve unique paths and close the handles before the report writes to
    # them by name.
    with tempfile.NamedTemporaryFile(
        suffix=".json", delete=False
    ) as json_file, tempfile.NamedTemporaryFile(
        suffix=".pkl", delete=False
    ) as memory_pkl:
        json_filename = json_file.name
        pkl_filename = memory_pkl.name

    try:
        sim._report(trace_path=json_filename, memory_view_path=pkl_filename)
        self._display_trace(json_filename, pkl_filename)
    finally:
        for path in (json_filename, pkl_filename):
            try:
                os.remove(path)
            except OSError:
                pass
|
1043
|
+
|
1044
|
+
def export_ir(self, ir_path: str) -> None:
    """
    Serialize the simulator's internal representation (IR) with ``torch.save``.

    Args:
        ir_path (str): Destination file for the exported IR.
    """
    ir_graph = self._ir
    assert ir_graph is not None, "Simulator IR does not exist!"
    with open(ir_path, "wb") as out_file:
        torch.save(ir_graph, out_file)
|