torchmonarch-nightly 2025.6.27 (cp313-cp313-manylinux2014_x86_64.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/client.py
ADDED
@@ -0,0 +1,690 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import atexit
import difflib
import itertools
import logging
import math
import time
import traceback
import weakref
from collections import defaultdict
from typing import (
    Callable,
    cast,
    Dict,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

from weakref import WeakKeyDictionary

import torch
import torch.distributed
from monarch._rust_bindings.monarch_extension import tensor_worker
from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
    LogLevel,
    WorldState,
)
from monarch.common import messages
from monarch.common.borrows import Borrow, StorageAliases
from monarch.common.controller_api import LogMessage, MessageResult, TController
from monarch.common.device_mesh import DeviceMesh

from monarch.common.future import Future
from monarch.common.invocation import DeviceException, RemoteException, Seq
from monarch.common.recording import flatten_messages, Recording

from monarch.common.reference import Ref, Referenceable
from monarch.common.shape import NDSlice
from monarch.common.stream import StreamRef
from monarch.common.tensor import Tensor
from monarch.common.tree import tree_map

from . import _coalescing


logger = logging.getLogger(__name__)

_CONTROLLER_STATUS_INTERVAL = 2


def TTL(timeout: Optional[float]) -> Callable[[], float]:
    if timeout is None:
        return lambda: math.inf
    expiry = time.time() + timeout
    return lambda: max(expiry - time.time(), 0)


class Client:
    def __init__(
        self,
        controller: TController,
        world_size: int,
        gpu_per_host: int,
    ):
        self.inner = controller
        self._world_size = world_size
        self._gpu_per_host = gpu_per_host
        self.next_ref = itertools.count()
        self.failures: Dict[int, Dict[int, RemoteException]] = defaultdict(dict)
        self._pending_del: Dict[DeviceMesh, List[int]] = defaultdict(list)
        self._shutdown = False
        self.controller_status_ttl = TTL(_CONTROLLER_STATUS_INTERVAL)
        self._aliases: WeakKeyDictionary[torch.UntypedStorage, StorageAliases] = (
            WeakKeyDictionary()
        )

        # stream._active = Stream("main2", _default=True)

        self._backend_network_init = False
        self._backend_network_init_point_to_point: Set[
            Tuple["StreamRef", "StreamRef"]
        ] = set()

        self.seq_gen = itertools.count()
        # seq of the most recent message that was sent to the controller
        self.last_assigned_seq = -1
        # seq of the last acked message from the controller; the ack message is
        # initiated by the _request_status() call. By comparing last_processed_seq
        # and last_assigned_seq, we can tell whether all messages have been
        # processed by all workers.
        self.last_processed_seq = -1

        # an error that we have received but know for certain has not
        # been propagated to a future. This will be reported on shutdown
        # to avoid hiding the error. This is best effort: we only keep
        # the error until the point that a future is dependent on
        # _any_ error, not necessarily the tracked one.
        self._pending_shutdown_error = None

        self.recorder = Recorder()

        self.pending_results: Dict[
            Seq,  # seq of an invocation
            Tuple[
                Optional["Future"],  # future to set
                List[List[traceback.FrameSummary]],  # local call stacks
            ],
        ] = {}
        atexit.register(self._atexit)
        self.created_communicators = set()

    def send(
        self,
        ranks: Union[NDSlice, List[NDSlice]],
        msg: NamedTuple,
    ) -> None:
        if not _coalescing.is_active(self):
            return self.send_nocoalesce(ranks, msg)
        if _coalescing.is_recording(self):
            match msg:
                case messages.BorrowFirstUse() if msg.borrow not in self.recorder.borrow_entries_created:
                    return self.send_nocoalesce(ranks, msg)
                case messages.BorrowLastUse() if msg.borrow not in self.recorder.borrow_entries_created:
                    raise ValueError(
                        "cannot explicitly drop a tensor inside a compiled block that was borrowed outside of it."
                    )
        self.recorder.add_message(ranks, msg)

    def send_nocoalesce(
        self,
        ranks: Union[NDSlice, List[NDSlice]],
        msg: NamedTuple,
    ) -> None:
        self.inner.send(ranks, msg)

    def reset_recorder(self) -> "Recorder":
        old, self.recorder = self.recorder, Recorder()
        return old

    def drop_borrow(self, borrow: "Borrow") -> None:
        if not _coalescing.is_active(self):
            return
        if borrow._id not in self.recorder.borrow_entries_created:
            tb = borrow.traceback_string
            raise RuntimeError(
                f"Borrow Traceback:\n{tb}Cannot drop a borrow while repeating a coalesced block because it would cause the borrow to drop multiple times. "
            )
        del self.recorder.borrow_entries_created[borrow._id]

    def new_borrow(self, borrow_entry: "Borrow") -> None:
        if not _coalescing.is_active(self):
            return
        self.recorder.borrow_entries_created[borrow_entry._id] = borrow_entry

    @property
    def all_ranks(self) -> NDSlice:
        return NDSlice(offset=0, sizes=[self._world_size], strides=[1])

    @property
    def gpu_per_host(self) -> int:
        return self._gpu_per_host

    # shut down everything, including client/system/controller/workers.
    # the shutdown procedure will wait for all messages to be processed
    # by the workers, then stop the system.
    def shutdown(
        self,
        destroy_pg: bool = True,
        error_reason: Optional[RemoteException | DeviceException | Exception] = None,
    ) -> None:
        if self.has_shutdown:
            return
        logger.info("shutting down the client gracefully")

        atexit.unregister(self._atexit)
        self._shutdown = True

        # request status for the last sent seq, and wait for the result to make
        # sure all seqs are processed.
        if self.last_assigned_seq > self.last_processed_seq:
            self._request_status()

        # send the Exit message to stop the workers, and wait a bit for the workers
        # to exit with the correct exit code before we stop the system.
        self.send(self.all_ranks, messages.Exit(destroy_pg, error_reason))
        time.sleep(2)

        # put an overall timeout on the shutdown waiting for now; better shutdown
        # for the multi-mesh setup will be implemented later.
        timeout = 60
        start_time = time.time()

        try:
            while (
                time.time() - start_time < timeout
                and self.last_assigned_seq > self.last_processed_seq
            ):
                # TODO(T216336422): retire client::drain_and_stop() as it doesn't
                # really drain all messages
                output = self.inner.next_message(1.0)
                if output is not None:
                    if isinstance(output, MessageResult):
                        # restart the timer as we got a new result back
                        start_time = time.time()
                        self._handle_pending_result(output)
                    elif isinstance(output, LogMessage):
                        self._log_message(output)

            # Drain any remaining messages in the client queue (if any)
            for output in self.inner.drain_and_stop():
                if isinstance(output, MessageResult):
                    self._handle_pending_result(output)
                elif isinstance(output, LogMessage):
                    self._log_message(output)
        except DeviceException:
            # exceptions during message draining should be ignored during shutdown,
            # as we are shutting down the system anyway
            logger.warning(
                "exception in message draining during shutdown, "
                "ignoring and continuing to stop the system"
            )
            pass

        # all messages are processed, we can now stop the system
        if time.time() - start_time >= timeout:
            logger.warning(
                "timeout waiting for all messages to be processed, "
                "stopping the mesh anyway"
            )
        else:
            logger.info("all messages are processed, stopping the mesh")
        self.inner.stop_mesh()

    @property
    def has_shutdown(self) -> bool:
        return self._shutdown

    def new_ref(self) -> int:
        r = next(self.next_ref)
        if _coalescing.is_active(self):
            self.recorder.first_ref = min(self.recorder.first_ref, r)
        return r

    def handle_deletes(
        self,
        ranks: Union[NDSlice, List[NDSlice]],
        refs: List[int],
        coalesce: bool = True,
    ):
        if coalesce:
            self.send(ranks, messages.DeleteRefs(refs))
        else:
            self.send_nocoalesce(ranks, messages.DeleteRefs(refs))
        self.inner.drop_refs([tensor_worker.Ref(id=ref) for ref in refs])

    def flush_deletes(self, coalesce: bool = True):
        for mesh, refs in self._pending_del.items():
            self.handle_deletes(mesh.processes, refs, coalesce)
        self._pending_del.clear()

    def delete_ref(self, device_mesh: DeviceMesh, ref: int) -> None:
        self._pending_del[device_mesh].append(ref)

    @property
    def aliases(self) -> WeakKeyDictionary[torch.UntypedStorage, StorageAliases]:
        return self._aliases

    def _request_status(self):
        self.send(
            self.all_ranks,
            messages.RequestStatus(self.last_assigned_seq, False),
        )

    def handle_next_message(self, timeout: Optional[float]) -> bool:
        output = self.inner.next_message(timeout)
        if output is not None:
            if isinstance(output, MessageResult):
                self._handle_pending_result(output)
            elif isinstance(output, LogMessage):
                self._log_message(output)
            return True
        return False

    def _log_message(self, msg: LogMessage) -> None:
        match msg.level:
            case LogLevel.INFO:
                logger.info(msg.message)
            case LogLevel.WARNING:
                logger.warning(msg.message)
            case LogLevel.ERROR:
                logger.error(msg.message)

    def _handle_pending_result(self, output: MessageResult) -> None:
        result = output.result
        seq = output.seq
        error = output.error

        self.last_processed_seq = max(self.last_processed_seq, seq)

        if error is not None:
            logging.info("Received error for seq %s: %s", seq, error)
            self._pending_shutdown_error = error
            # We should not have set result if we have an error.
            assert result is None
            if not isinstance(error, RemoteException):
                raise error

            # Populate controller tracebacks for the remote failure
            original_frame_seq = error.seq
            index = error.controller_frame_index
            assert index is not None
            # TODO: Populate tracebacks for dependent invocations
            if original_frame_seq == seq:
                # The current invocation is the one causing the remote failure.
                # We should not have populated the tracebacks yet.
                assert error.controller_frames is None
                _, tracebacks = self.pending_results[original_frame_seq]
                assert tracebacks is not None
                assert (
                    len(tracebacks) > index
                ), f"tracebacks contains {len(tracebacks)} frames, but index is {index}"
                error.controller_frames = tracebacks[index]

        fut, _ = self.pending_results[seq]
        if fut is not None:
            if error is None:
                fut._set_result(result)
            else:
                fut._set_result(error)
                self._pending_shutdown_error = None
        elif result is not None:
            logger.debug(f"{seq}: unused result {result}")
        elif error is not None:
            # errors get reported as results even if they
            # do not have futures attached.
            pass

        # We can safely delete the seq as tracebacks have been saved to the remote failure itself.
        del self.pending_results[seq]

    def split_comm(self, dims, device_mesh, stream_ref) -> None:
        """Create a split communicator group with the specified ranks, and
        associate it with a specific device mesh and stream.
        """
        # For simplicity, just send this message to all ranks and split from the
        # global communicator. As an optimization, the client could remember
        # which comms have already been created and issue a message to a smaller
        # set of ranks.
        if not self._backend_network_init:
            raise AssertionError(
                "split_comm called before backend network initialization"
            )

        msg = messages.SplitComm(tuple(sorted(dims)), device_mesh, stream_ref)
        if msg in self.created_communicators:
            return

        self.send_nocoalesce(self.all_ranks, msg)
        self.created_communicators.add(msg)

    def backend_network_init(self) -> None:
        if self._backend_network_init:
            return
        self._backend_network_init = True
        logger.info("Initializing backend network")
        self.send_nocoalesce(self.all_ranks, messages.BackendNetworkInit())

    def backend_network_point_to_point_init(
        self, from_stream_ref: "StreamRef", to_stream_ref: "StreamRef"
    ) -> None:
        key = (from_stream_ref, to_stream_ref)

        if key in self._backend_network_init_point_to_point:
            return
        self._backend_network_init_point_to_point.add(key)
        self.send_nocoalesce(
            self.all_ranks,
            messages.BackendNetworkPointToPointInit(from_stream_ref, to_stream_ref),
        )

    def new_node(
        self,
        defs: Sequence["Tensor"],
        uses: Sequence["Tensor"],
        future: Optional["Future"] = None,
        tracebacks: Optional[List[List[traceback.FrameSummary]]] = None,
    ) -> Seq:
        for t in uses:
            t._use()

        if tracebacks is None:
            tracebacks = [traceback.extract_stack()[:-2]]
        if _coalescing.is_recording(self):
            assert future is None, "this should have been checked in fetch_shard"
            return self.recorder.add(defs, uses, tracebacks[0])
        else:
            return self.new_node_nocoalesce(defs, uses, future, tracebacks)

    def new_node_nocoalesce(
        self,
        defs: Sequence["Tensor"],
        uses: Sequence["Tensor"],
        future: Optional["Future"],
        tracebacks: List[List[traceback.FrameSummary]],
    ) -> Seq:
        seq = self._next_seq()
        self.pending_results[seq] = (future, tracebacks)
        for d in defs:
            d._seq = seq
        self.inner.node(seq, defs, uses)
        return seq

    def _next_seq(self) -> Seq:
        self.last_assigned_seq = next(self.seq_gen)
        return self.last_assigned_seq

    def _atexit(self) -> None:
        logger.warning(
            "Client is not shutting down properly before atexit. "
            "This may be due to an exception or because device_mesh.exit() "
            "was not called."
        )
        # Calling self.shutdown may cause a deadlock if something is wrong with
        # the networking. Or should we make shutdown() not wait indefinitely?
        self._shutdown = True

        # send shutdown message to stop other processes.
        self.inner.stop_mesh()

    def no_coalescing(self, reason):
        if _coalescing.is_active(self):
            raise NotImplementedError(f"NYI: {reason} during a coalescing block")

    def mesh_state(self) -> WorldState:
        return self.inner.worker_world_state()

    def fetch(
        self,
        mesh: "DeviceMesh",
        stream: "StreamRef",
        shard,
        preprocess_message,
        args,
        kwargs,
        defs: Tuple["Tensor", ...],
        uses: Tuple["Tensor", ...],
    ) -> "Future":
        fut = Future(self)
        ident = self.new_node(defs, uses, fut)
        process = mesh._process(shard)
        self.send(
            process,
            messages.SendValue(
                ident,
                None,
                defs,
                preprocess_message,
                args,
                kwargs,
                stream,
            ),
        )
        # we have to ask for status updates
        # from workers to be sure they have finished
        # enough work to count this future as finished,
        # and all potential errors have been reported
        self._request_status()
        return fut


def tree_map_refs(first_ref: int, tree):
    def translate_id(ref: int) -> int:
        diff = ref - first_ref
        if diff >= 0:
            return -1 - diff
        return ref

    def translate_ref(obj):
        match obj:
            case Ref():
                return translate_id(obj.id)
            case Referenceable():
                return None if obj.ref is None else translate_id(obj.ref)
            case messages.DeleteRefs():
                # Python destructors may not run in a deterministic order across
                # traces of a recorded function, so we need to sort the refs to
                # ensure a fair comparison during validation.
                return messages.DeleteRefs(sorted([translate_id(r) for r in obj.refs]))
            case messages.BorrowCreate():
                result, borrow, *rest = [translate_ref(x) for x in obj]
                return messages.BorrowCreate(result, translate_id(borrow), *rest)
            case messages.BorrowDrop():
                return messages.BorrowDrop(translate_id(obj.borrow))
            case messages.BorrowFirstUse():
                return messages.BorrowFirstUse(translate_id(obj.borrow))
            case messages.BorrowLastUse():
                return messages.BorrowLastUse(translate_id(obj.borrow))
            case _:
                return obj

    return tree_map(
        translate_ref,
        tree,
        is_leaf=lambda x: isinstance(
            x,
            (
                Ref,
                Referenceable,
                messages.DeleteRefs,
                messages.BorrowCreate,
                messages.BorrowDrop,
                messages.BorrowFirstUse,
                messages.BorrowLastUse,
            ),
        ),
    )


class Recorder:
    def __init__(self):
        self.borrow_entries_created: Dict[int, Borrow] = {}
        self.messages: List[Tuple[Union[NDSlice, List[NDSlice]], NamedTuple]] = []
        # these tables track the externally captured tensors that we
        # use and mutate whenever this recording is run.
        self.uses = {}  # ordered set
        self.mutates = {}  # ordered set
        self.creates: List[weakref.ref] = []
        self.tracebacks = []
        self.first_ref: int = math.inf
        self.reference_recording: Optional["Recording"] = None
        # Map from formal tensor storage to its corresponding argument indices
        # in the recording input (there may be multiple aliases of the same
        # tensor in the recording input).
        self.formal_storages_to_indices: defaultdict[
            torch.UntypedStorage, List[int]
        ] = defaultdict(list)
        # Set of tensor storages for formals that are mutated during the recording.
        self.mutated_formal_storages: Set[torch.UntypedStorage] = set()

    def add_formal(self, formal: Tensor, argument_index: int) -> None:
        self.formal_storages_to_indices[formal._fake.untyped_storage()].append(
            argument_index
        )

    def add(
        self,
        defs: Sequence["Tensor"],
        uses: Sequence["Tensor"],
        traceback: List[traceback.FrameSummary],
    ):
        for u in uses:
            if u._seq is None:
                # a lack of a sequence number on a tensor means it was created
                # within the recording and does not have to be tracked as a use
                continue
            self.uses[u] = None
        for d in defs:
            # a lack of a sequence number means the tensor doesn't need to be
            # tracked as a mutate, unless that tensor is an alias of a formal tensor
            if d._seq is None:
                self.creates.append(weakref.ref(d))
                storage = d._fake.untyped_storage()
                if storage in self.formal_storages_to_indices:
                    self.mutated_formal_storages.add(storage)
            else:
                self.mutates[d] = None
        self.tracebacks.append(traceback)
        return len(self.tracebacks) - 1

    def _check(self):
        if self.borrow_entries_created:
            tbs = "------------\n".join(
                b.traceback_string for b in self.borrow_entries_created.values()
            )
            raise RuntimeError(
                f"Borrows created during a recorded coalesced block need to be dropped before the block ends. Tracebacks of where the borrows were created: {tbs}"
            )

    @property
    def flat_messages(self):
        return flatten_messages(self.messages)

    def run_once(self, client: "Client"):
        self._check()
        for rank, msgs in self.flat_messages.items():
            client.send_nocoalesce(
                NDSlice(offset=rank, sizes=[], strides=[]), messages.CommandGroup(msgs)
            )

    def abandon(self):
        # an error happened and we will not use this recording. Every tensor created
        # as part of this recording has never been defined, so we blank out the
        # .ref to disarm the deletions.
        for w in self.creates:
            v = w()
            if v is not None:
                v.ref = None

    def add_message(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple):
        if isinstance(msg, messages.RecordingFormal):
            self.add_formal(cast(Tensor, msg.result), msg.argument_index)

        # this is pretty expensive, but we can't hold tensor references without
        # extending their lifetime unnecessarily, so they must be converted to
        # references here. It also prevents a bug when a tensor is dropped
        # after a message is recorded and will no longer have a ref field.
        msg = tree_map(
            lambda x: (
                Ref(x.ref) if isinstance(x, Tensor) and x.ref is not None else x
            ),
            msg,
        )
        self.messages.append((ranks, msg))
        reference_recording = self.reference_recording
        if reference_recording is not None:
            last_index = len(self.messages) - 1
            reference_messages = reference_recording.buffered_messages
            mine = self.messages[last_index]
            theirs = (
                reference_messages[last_index]
                if len(reference_messages) > last_index
                else None
            )
            mine = tree_map_refs(self.first_ref, mine)
            theirs = tree_map_refs(reference_recording.first_ref, theirs)
            if mine != theirs:
                traceback_index = len(self.tracebacks) - 1

                tb_mine = traceback.format_list(self.tracebacks[traceback_index])
                while tb_mine and "in _record_and_define" not in tb_mine[0]:
                    tb_mine.pop(0)

                tb_theirs = traceback.format_list(
                    reference_recording.tracebacks[traceback_index]
                )
                while tb_theirs and "in _record_and_define" not in tb_theirs[0]:
                    tb_theirs.pop(0)

                the_diff = "\n".join(difflib.ndiff([str(theirs)], [str(mine)]))
                raise RuntimeError(
                    f"monarch.compiled failed to verify recording. Recording diverges at operation {last_index}.\n{the_diff}\n\nTraceback of original recording\n{''.join(tb_theirs)}\n\nTraceback of second recording\n{''.join(tb_mine)}\n"
                )

    def verify_against(self, reference: Recording):
        self.reference_recording = reference

    def define_recording(
        self,
        client: "Client",
        nresults: int,
        nformals: int,
    ) -> Recording:
        self._check()
        # any remaining references to tensors we defined in the recording are
        # not valid for future use outside the recording, so drop them
        # so that we report an error if they are used.
        for w in self.creates:
            v = w()
            if v is not None:
                v._drop_ref()
        # It should be safe to use a list instead of a set here, since
        # no entry in formal_storages_to_indices should have any overlap
        # with any other entry. So mutated_formal_indices should automatically
        # have unique elements.
        mutated_formal_indices = []
        for storage in self.mutated_formal_storages:
            mutated_formal_indices.extend(self.formal_storages_to_indices[storage])
        return Recording(
            client,
            list(self.uses.keys()),
            list(self.mutates.keys()),
            sorted(mutated_formal_indices),
            self.tracebacks,
            self.messages,
            nresults,
            nformals,
            self.first_ref,
        )
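Note on the hunk above: tree_map_refs canonicalizes reference ids so that a re-trace of a recorded function can be compared against the original trace even though the absolute ref counter differs between runs; every ref allocated at or after first_ref is rewritten to a stable negative id in allocation order. Below is a minimal, self-contained sketch of that translation rule (this restates the inner translate_id for illustration and is not part of the package):

    def translate_id(ref: int, first_ref: int) -> int:
        # refs created inside the recording (ref >= first_ref) map to -1, -2, ...
        # in allocation order; refs created before the recording pass through
        diff = ref - first_ref
        if diff >= 0:
            return -1 - diff
        return ref

    assert translate_id(10, first_ref=10) == -1   # first in-recording ref
    assert translate_id(12, first_ref=10) == -3   # third in-recording ref
    assert translate_id(7, first_ref=10) == 7     # pre-existing ref is unchanged
    # two traces whose counters start at different values canonicalize identically:
    assert translate_id(102, first_ref=100) == translate_id(12, first_ref=10)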
monarch/common/constants.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

SIM_MESH_CLIENT_TIMEOUT = 5
SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL = 5
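A closing note on the client.py hunk: the TTL helper returns a closure reporting the remaining seconds of a budget, clamped at zero, and the Client constructs one from _CONTROLLER_STATUS_INTERVAL (controller_status_ttl), presumably to rate-limit status requests to the controller. A runnable sketch of its behavior (the helper is restated here for illustration; the sleep duration is arbitrary):

    import math
    import time

    def TTL(timeout):
        if timeout is None:
            return lambda: math.inf
        expiry = time.time() + timeout
        return lambda: max(expiry - time.time(), 0)

    ttl = TTL(2.0)                   # countdown starts at creation
    time.sleep(0.5)
    assert 0 < ttl() < 2.0           # remaining seconds decrease over time
    assert TTL(None)() == math.inf   # no timeout: the budget never expires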