torchmonarch_nightly-2025.6.27-cp313-cp313-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/simulator/command_history.py
@@ -0,0 +1,424 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import itertools
+import traceback
+import warnings
+from dataclasses import dataclass
+from typing import List, NamedTuple, Optional, Sequence
+
+import torch
+
+from monarch.common import messages
+from monarch.common.shape import NDSlice
+from monarch.simulator.ir import IRGraph
+from monarch.simulator.tensor import DTensorRef
+from monarch.simulator.utils import clean_name, file_path_with_iter
+
+from torch.utils._pytree import tree_map
+
+
+@dataclass
+class Command:
+    timestamp: int
+    # Either "send" or "recvready" now.
+    backend_command: str
+    # "send" arguments
+    ranks: Optional[List[NDSlice]] = None
+    msg: Optional[NamedTuple] = None
+    # "recvready" arguments
+    timeout: Optional[float] = None
+
+
+class CommandHistory:
+    """
+    A class to record commands sent to the SimulatorBackend. The class can
+    later be used for replaying the recorded commands.
+
+    Args:
+        maxlen (int): The maximum number of commands to record. Defaults to 10_000_000.
+    """
+
+    def __init__(
+        self,
+        world_size: int,
+        *,
+        maxlen: int = 10_000_000,
+        file_path: str = "command_history.pt",
+    ) -> None:
+        self.world_size = world_size
+        self.maxlen = maxlen
+        self.commands: List[Command] = []
+        self.warn_once: bool = False
+        self.file_path = file_path
+
+    def __del__(self):
+        DTensorRef.created.clear()
+
+    def record(
+        self,
+        now: int,
+        backend_command: str,
+        command_id: int,
+        traceback: Sequence[traceback.FrameSummary] = (),
+        ranks: Optional[List[NDSlice]] = None,
+        msg: Optional[NamedTuple] = None,
+        timeout: Optional[float] = None,
+        ir: Optional[IRGraph] = None,
+    ) -> Command:
+        command = self.convert_command(
+            now, backend_command, command_id, traceback, ranks, msg, timeout, ir
+        )
+        if len(self.commands) < self.maxlen:
+            self.commands.append(command)
+        elif not self.warn_once:
+            warnings.warn(
+                (
+                    f"CommandHistory's maxlen is {self.maxlen}, and we have already "
+                    "exceeded the limit. The remaining commands will not be recorded."
+                ),
+                stacklevel=2,
+            )
+            self.warn_once = True
+        return command
+
+    @staticmethod
+    def convert_command(
+        now: int,
+        backend_command: str,
+        command_id: int,
+        traceback: Sequence[traceback.FrameSummary] = (),
+        ranks: Optional[List[NDSlice]] = None,
+        msg: Optional[NamedTuple] = None,
+        timeout: Optional[float] = None,
+        ir: Optional[IRGraph] = None,
+    ) -> Command:
+        msg = CommandHistory._convert_command(msg)
+
+        if ir:
+            if isinstance(msg, messages.CommandGroup):
+                for i, command in enumerate(msg.commands):
+                    CommandHistory._maybe_insert_ir(
+                        ir, command_id + i + 1, traceback, ranks, command
+                    )  # i starts from 0, so command_id + i + 1
+            else:
+                CommandHistory._maybe_insert_ir(ir, command_id, traceback, ranks, msg)
+        return Command(
+            timestamp=now,
+            backend_command=backend_command,
+            ranks=ranks,
+            msg=msg,
+            timeout=timeout,
+        )
+
+    @staticmethod
+    def convert_msg(msg):
+        def _convert_arg(v):
+            if isinstance(v, torch.Tensor):
+                return DTensorRef.from_ref(v)
+            return v
+
+        name = type(msg).__name__
+        match name:
+            case "CallFunction":
+                args, kwargs, mutates, result = tree_map(
+                    _convert_arg, (msg.args, msg.kwargs, msg.mutates, msg.result)
+                )
+                msg = msg._replace(
+                    args=args, kwargs=kwargs, mutates=mutates, result=result
+                )
+            case "SendTensor":
+                msg = msg._replace(
+                    tensor=DTensorRef.from_ref(msg.tensor),
+                    result=DTensorRef.from_ref(msg.result),
+                )
+            case "Reduce":
+                msg = msg._replace(
+                    local_tensor=DTensorRef.from_ref(msg.local_tensor),
+                    result=DTensorRef.from_ref(msg.result),
+                )
+            case "BorrowCreate":
+                msg = msg._replace(
+                    result=DTensorRef.from_ref(msg.result),
+                    tensor=DTensorRef.from_ref(msg.tensor),
+                )
+
+        return msg
+
+    @staticmethod
+    def _convert_command(msg):
+        if isinstance(msg, messages.CommandGroup):
+            for idx, command in enumerate(msg.commands):
+                msg.commands[idx] = CommandHistory.convert_msg(command)
+            return msg
+        else:
+            return CommandHistory.convert_msg(msg)
+
+    # TODO: Add function to simplify repeated modifications to ir
+    @staticmethod
+    def _maybe_insert_ir(
+        ir: IRGraph,
+        command_id: int,
+        tb: Sequence[traceback.FrameSummary] = (),
+        ranks: Optional[List[NDSlice]] = None,
+        msg: Optional[NamedTuple] = None,
+    ) -> None:
+        # Process tensor results and update IR
+        def _process_tensor_results(
+            result,
+            worker_rank,
+            stream_name,
+            command_id,
+            mutate=False,
+            borrow_src_tensor_ref=None,
+        ):
+            if result is not None:
+                results_list = result if isinstance(result, list) else [result]
+                for tensor_ref in results_list:
+                    fake = tensor_ref._fake
+                    ir.update_tensor(
+                        tensor_ref._storage_id,
+                        tensor_ref.ref,
+                        fake.dtype,
+                        tuple(fake.shape),
+                        worker_rank,
+                        stream_name,
+                        command_id,
+                        mutate=mutate,
+                        borrow_src_tensor_ref=borrow_src_tensor_ref,
+                        tensor_size=tensor_ref._size,
+                    )
+
+        assert msg is not None
+        stream_name = src_stream_name = dst_stream_name = ""
+        flattened_ranks = list(itertools.chain.from_iterable(ranks or []))
+        command_type = ""
+        devices = []
+        control_dependencies = []
+        dag_item_type = type(msg).__name__
+        result = getattr(msg, "result", None)
+        for worker_rank in flattened_ranks:
+            match dag_item_type:
+                case "CallFunction":
+                    stream_name = getattr(msg, "stream", None).name
+                    command_type = (
+                        f"CallFunction: {clean_name(str(getattr(msg, 'function', '')))}"
+                    )
+                    devices = [worker_rank]
+                    msg_args = getattr(msg, "args", None)
+                    if msg_args is not None:
+                        for arg in msg_args:
+                            if isinstance(arg, DTensorRef):
+                                _process_tensor_results(
+                                    arg, worker_rank, stream_name, command_id
+                                )
+                    msg_mutates = getattr(msg, "mutates", None)
+                    if msg_mutates is not None:
+                        for mutate_src in msg_mutates:
+                            if isinstance(mutate_src, DTensorRef) or (
+                                isinstance(mutate_src, list)
+                                and all(isinstance(m, DTensorRef) for m in mutate_src)
+                            ):
+                                mutates_list = (
+                                    mutate_src
+                                    if isinstance(mutate_src, list)
+                                    else [mutate_src]
+                                )
+                                _process_tensor_results(
+                                    mutates_list,
+                                    worker_rank,
+                                    stream_name,
+                                    command_id,
+                                    mutate=True,
+                                )
+                    _process_tensor_results(
+                        result,
+                        worker_rank,
+                        stream_name,
+                        command_id,
+                    )
+
+                case "Reduce":
+                    stream_name = getattr(msg, "stream", None).name
+                    reduction = getattr(msg, "reduction", None)
+                    scatter = getattr(msg, "scatter", False)
+                    if reduction == "stack":
+                        if scatter:
+                            reduce_type = "all_to_all"
+                        else:
+                            reduce_type = "all_gather"
+                    else:
+                        if scatter:
+                            reduce_type = "all_reduce"
+                        else:
+                            reduce_type = "reduce_scatter"
+                    command_type = f"Reduce: {reduce_type}: {result.ref}"  # use result.ref as unique Reduce id
+                    devices = flattened_ranks
+                    _process_tensor_results(
+                        result, worker_rank, stream_name, command_id
+                    )
+                case "BorrowCreate":
+                    borrow_id = getattr(msg, "borrow", None)
+                    borrow_src_tensor_ref = getattr(msg, "tensor", None).ref
+                    stream_name = src_stream_name = getattr(
+                        msg, "from_stream", None
+                    ).name
+                    dst_stream_name = getattr(msg, "to_stream", None).name
+
+                    command_type = f"BorrowCreate: {borrow_id}"
+                    devices = [worker_rank]
+                    ir.add_borrow(
+                        borrow_id,
+                        worker_rank,
+                        src_stream_name,
+                        dst_stream_name,
+                        command_id,
+                    )
+                    _process_tensor_results(
+                        result,
+                        worker_rank,
+                        dst_stream_name,
+                        command_id,
+                        borrow_src_tensor_ref=borrow_src_tensor_ref,
+                    )
+                case "BorrowFirstUse":
+                    borrow_id = getattr(msg, "borrow", None)
+                    stream_name = ir._control.borrows_info[borrow_id].dst_stream_name
+                    command_type = f"BorrowFirstUse: {borrow_id}"
+                    devices = [worker_rank]
+                    control_dependencies = [
+                        ir._control.borrows_info[borrow_id].create_id
+                    ]
+                    ir._control.borrows_info[borrow_id].firstuse_id = command_id
+                case "BorrowLastUse":
+                    borrow_id = getattr(msg, "borrow", None)
+                    stream_name = src_stream_name = ir._control.borrows_info[
+                        borrow_id
+                    ].dst_stream_name
+                    dst_stream_name = ir._control.borrows_info[
+                        borrow_id
+                    ].src_stream_name
+                    command_type = f"BorrowLastUse: {borrow_id}"
+                    devices = [worker_rank]
+                    ir._control.borrows_info[borrow_id].lastuse_id = command_id
+                case "BorrowDrop":
+                    borrow_id = getattr(msg, "borrow", None)
+                    stream_name = ir._control.borrows_info[borrow_id].src_stream_name
+                    command_type = f"BorrowDrop: {borrow_id}"
+                    devices = [worker_rank]
+                    control_dependencies = [
+                        ir._control.borrows_info[borrow_id].lastuse_id
+                    ]
+                    ir._control.borrows_info[borrow_id].drop_id = command_id
+
+            if dag_item_type in [
+                "CallFunction",
+                "Reduce",
+                "BorrowCreate",
+                "BorrowFirstUse",
+                "BorrowLastUse",
+                "BorrowDrop",
+            ]:
+                ir.insert_node(
+                    worker_rank,
+                    stream_name,
+                    command_id,
+                    command_type,
+                    devices,
+                    control_dependencies,
+                    traceback.format_list(tb),
+                )
+
+        assert ranks is not None
+        if dag_item_type == "SendTensor" and len(ranks) == 2:
+            src_flattened_ranks = list(
+                itertools.chain.from_iterable([ranks[0]])
+            )  # for SendTensor, ranks[0] == source ranks
+            dst_flattened_ranks = list(
+                itertools.chain.from_iterable([ranks[1]])
+            )  # for SendTensor, ranks[1] == destination ranks
+
+            src_stream_name = getattr(msg, "from_stream", None).name
+            dst_stream_name = getattr(msg, "to_stream", None).name
+
+            # Create sets of (rank, stream) pairs for source and destination ranks
+            src_rank_stream_pairs = {
+                (rank, src_stream_name) for rank in src_flattened_ranks
+            }
+            dst_rank_stream_pairs = {
+                (rank, dst_stream_name) for rank in dst_flattened_ranks
+            }
+            rank_stream_pairs = (
+                src_rank_stream_pairs | dst_rank_stream_pairs
+            )  # find the union of the two sets
+            command_type = f"SendTensor: {result.ref if result else None}"
+            devices = flattened_ranks
+            control_dependencies = flattened_ranks
+            for rank, stream_name in rank_stream_pairs:
+                ir.insert_node(
+                    rank,
+                    stream_name,
+                    command_id,
+                    command_type,
+                    devices,
+                    control_dependencies,
+                    traceback.format_list(tb),
+                )
+            src_tensor = getattr(msg, "tensor", None)
+            if src_tensor is not None:
+                src_tensors_list = (
+                    src_tensor if isinstance(src_tensor, list) else [src_tensor]
+                )
+                for src_t in src_tensors_list:
+                    for rank, src_stream_name in src_rank_stream_pairs:
+                        _process_tensor_results(
+                            src_t, rank, src_stream_name, command_id
+                        )
+            if result is not None:
+                results_list = result if isinstance(result, list) else [result]
+                for res in results_list:
+                    ir.add_sendtensor(
+                        res.ref,
+                        src_flattened_ranks,
+                        src_stream_name,
+                        dst_flattened_ranks,
+                        dst_stream_name,
+                        tuple(res._fake.size()),
+                    )
+                    for rank, dst_stream_name in dst_rank_stream_pairs:
+                        _process_tensor_results(res, rank, dst_stream_name, command_id)
+
+        if dag_item_type == "DeleteRefs":
+            refs = getattr(msg, "refs", None)
+            for ref in refs:
+                stream_name = ir._data.tensorref_to_stream[ref]
+                # Do not call _insert_node() since we do not need DeleteRefs for the control DAG
+                ir.delete_tensor(
+                    ref,
+                    flattened_ranks,
+                    stream_name,
+                    command_id,
+                )
+
+    def step(self, iter_count: int, dump: bool = False) -> None:
+        if dump:
+            self.dump(file_path_with_iter(self.file_path, iter_count))
+
+        self.commands.clear()
+
+    def dump(self, file_path: str) -> None:
+        with open(file_path, "wb") as f:
+            torch.save({"world_size": self.world_size, "commands": self.commands}, f)
+
+    @classmethod
+    def load(cls, filename: str) -> "CommandHistory":
+        with open(filename, "rb") as f:
+            states = torch.load(f, weights_only=False)
+            self = cls(states["world_size"])
+            self.commands = states["commands"]
+
+        return self
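To make the recording API above concrete, here is a minimal usage sketch. It is not part of the package; the file name "cmds.pt" and the argument values are invented for illustration. It exercises the record/dump/load cycle exactly as defined in command_history.py:

    from monarch.simulator.command_history import CommandHistory

    hist = CommandHistory(world_size=8)

    # Record a simple "recvready" command. A "send" command would also pass
    # ranks= (List[NDSlice]) and msg= (a NamedTuple from monarch.common.messages);
    # tensors inside msg are converted to DTensorRef before being stored.
    hist.record(now=0, backend_command="recvready", command_id=0, timeout=30.0)

    # Persist and reload the history; dump() uses torch.save under the hood.
    hist.dump("cmds.pt")
    replayed = CommandHistory.load("cmds.pt")
    assert replayed.world_size == 8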
monarch/simulator/config.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import contextlib
+
+META_VAL = []
+
+
+@contextlib.contextmanager
+def set_meta(new_value):
+    # Sets the metadata for any tasks created under this context.
+    global META_VAL
+    META_VAL.append(new_value)
+    try:
+        yield
+    finally:
+        META_VAL.pop()
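set_meta is just a stack push/pop around META_VAL; a short sketch (values are illustrative only) of the nesting behavior:

    from monarch.simulator.config import META_VAL, set_meta

    with set_meta("outer"):
        # the active metadata is the top of the stack
        assert META_VAL == ["outer"]
        with set_meta("inner"):  # contexts nest; values stack
            assert META_VAL == ["outer", "inner"]
    assert META_VAL == []  # popped on exit, even if the body raises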
monarch/simulator/interface.py
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Union
+
+from monarch.common.client import Client as _Client
+from monarch.common.device_mesh import DeviceMesh
+from monarch.common.shape import NDSlice
+
+from monarch.simulator.ir import IRGraph
+from monarch.simulator.simulator import (
+    SimulatorBackendMode,
+    SimulatorController as _SimulatorController,
+    SimulatorInterface,
+    SimulatorTraceMode,
+)
+
+
+def Simulator(
+    hosts: int,
+    gpus: int,
+    *,
+    simulate_mode: Union[str, SimulatorBackendMode] = SimulatorBackendMode.SIMULATE,
+    trace_mode: Union[str, SimulatorTraceMode] = SimulatorTraceMode.STREAM_ONLY,
+    upload_trace: bool = False,
+    trace_path: str = "trace.json",
+    command_history_path: str = "command_history.pkl",
+    group_workers: bool = False,
+    build_ir: bool = False,
+) -> "SimulatorInterface":
+    if isinstance(simulate_mode, str):
+        simulate_mode = getattr(SimulatorBackendMode, simulate_mode.upper())
+    if isinstance(trace_mode, str):
+        trace_mode = getattr(SimulatorTraceMode, trace_mode.upper())
+
+    ir = IRGraph() if build_ir else None
+    ctrl = _SimulatorController(
+        hosts * gpus,
+        gpu_per_host=gpus,
+        simulate_mode=simulate_mode,
+        trace_mode=trace_mode,
+        upload_trace=upload_trace,
+        trace_path=trace_path,
+        command_history_path=command_history_path,
+        group_workers=group_workers,
+        ir=ir,
+    )
+    client = _Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)
+    dm = DeviceMesh(
+        client,
+        NDSlice(offset=0, sizes=[hosts, gpus], strides=[gpus, 1]),
+        ("host", "gpu"),
+    )
+
+    dm.exit = lambda: client.shutdown()
+    return SimulatorInterface(dm, ctrl, ir)
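Simulator is the user-facing constructor tying these pieces together. A hypothetical invocation (argument values are illustrative) showing the string-to-enum conversion and the resulting mesh shape:

    from monarch.simulator.interface import Simulator

    # "simulate" is upper-cased and looked up on SimulatorBackendMode;
    # the mesh covers hosts * gpus = 8 ranks with dims ("host", "gpu").
    sim = Simulator(hosts=2, gpus=4, simulate_mode="simulate", build_ir=True)
    # sim is a SimulatorInterface wrapping the DeviceMesh, the
    # SimulatorController, and (since build_ir=True) an IRGraph.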