torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/client.py
@@ -0,0 +1,690 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import atexit
+ import difflib
+ import itertools
+ import logging
+ import math
+ import time
+ import traceback
+ import weakref
+ from collections import defaultdict
+ from typing import (
+     Callable,
+     cast,
+     Dict,
+     List,
+     NamedTuple,
+     Optional,
+     Sequence,
+     Set,
+     Tuple,
+     TYPE_CHECKING,
+     Union,
+ )
+
+ from weakref import WeakKeyDictionary
+
+ import torch
+ import torch.distributed
+ from monarch._rust_bindings.monarch_extension import tensor_worker
+ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+     LogLevel,
+     WorldState,
+ )
+ from monarch.common import messages
+ from monarch.common.borrows import Borrow, StorageAliases
+ from monarch.common.controller_api import LogMessage, MessageResult, TController
+ from monarch.common.device_mesh import DeviceMesh
+
+ from monarch.common.future import Future
+ from monarch.common.invocation import DeviceException, RemoteException, Seq
+ from monarch.common.recording import flatten_messages, Recording
+
+ from monarch.common.reference import Ref, Referenceable
+ from monarch.common.shape import NDSlice
+ from monarch.common.stream import StreamRef
+ from monarch.common.tensor import Tensor
+ from monarch.common.tree import tree_map
+
+ from . import _coalescing
+
+
+ logger = logging.getLogger(__name__)
+
+ _CONTROLLER_STATUS_INTERVAL = 2
+
+
+ def TTL(timeout: Optional[float]) -> Callable[[], float]:
+     if timeout is None:
+         return lambda: math.inf
+     expiry = time.time() + timeout
+     return lambda: max(expiry - time.time(), 0)
+
+
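A minimal usage sketch of the TTL helper above (illustrative, not part of the package; the import path assumes this hunk is monarch/common/client.py, per the file list):

    import time

    from monarch.common.client import TTL

    remaining = TTL(0.5)  # countdown closure with a 0.5 s budget
    print(remaining())    # ~0.5 right after creation
    time.sleep(0.6)
    print(remaining())    # 0 once the deadline passes; never negative

    unbounded = TTL(None)
    print(unbounded())    # inf: no timeout means it never expires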
+ class Client:
+     def __init__(
+         self,
+         controller: TController,
+         world_size: int,
+         gpu_per_host: int,
+     ):
+         self.inner = controller
+         self._world_size = world_size
+         self._gpu_per_host = gpu_per_host
+         self.next_ref = itertools.count()
+         self.failures: Dict[int, Dict[int, RemoteException]] = defaultdict(dict)
+         self._pending_del: Dict[DeviceMesh, List[int]] = defaultdict(list)
+         self._shutdown = False
+         self.controller_status_ttl = TTL(_CONTROLLER_STATUS_INTERVAL)
+         self._aliases: WeakKeyDictionary[torch.UntypedStorage, StorageAliases] = (
+             WeakKeyDictionary()
+         )
+
+         # stream._active = Stream("main2", _default=True)
+
+         self._backend_network_init = False
+         self._backend_network_init_point_to_point: Set[
+             Tuple["StreamRef", "StreamRef"]
+         ] = set()
+
+         self.seq_gen = itertools.count()
+         # seq of the most recent message that was sent to the controller
+         self.last_assigned_seq = -1
+         # seq of the last message acked by the controller; the ack is initiated
+         # by the _request_status() call. By comparing last_processed_seq and
+         # last_assigned_seq, we can tell whether all messages have been
+         # processed by all workers.
+         self.last_processed_seq = -1
+
+         # an error that we have received but know for certain has not
+         # been propagated to a future. This will be reported on shutdown
+         # to avoid hiding the error. This is best effort: we only keep
+         # the error until the point that a future depends on _any_ error,
+         # not necessarily the tracked one.
+         self._pending_shutdown_error = None
+
+         self.recorder = Recorder()
+
+         self.pending_results: Dict[
+             Seq,  # seq of an invocation
+             Tuple[
+                 Optional["Future"],  # future to set
+                 List[List[traceback.FrameSummary]],  # local call stacks
+             ],
+         ] = {}
+         atexit.register(self._atexit)
+         self.created_communicators = set()
+
+     def send(
+         self,
+         ranks: Union[NDSlice, List[NDSlice]],
+         msg: NamedTuple,
+     ) -> None:
+         if not _coalescing.is_active(self):
+             return self.send_nocoalesce(ranks, msg)
+         if _coalescing.is_recording(self):
+             match msg:
+                 case messages.BorrowFirstUse() if msg.borrow not in self.recorder.borrow_entries_created:
+                     return self.send_nocoalesce(ranks, msg)
+                 case messages.BorrowLastUse() if msg.borrow not in self.recorder.borrow_entries_created:
+                     raise ValueError(
+                         "cannot explicitly drop a tensor inside a compiled block that was borrowed outside of it."
+                     )
+         self.recorder.add_message(ranks, msg)
+
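A self-contained toy of the dispatch in send() above (hypothetical stand-ins; the real checks live in monarch.common._coalescing): messages pass straight through when no coalescing block is active, and are buffered by the recorder while a recording is active.

    from typing import Any, List

    class ToyClient:
        """Hypothetical stand-in for Client's coalescing-aware send()."""

        def __init__(self) -> None:
            self.recording = False         # stands in for _coalescing.is_recording
            self.sent: List[Any] = []      # messages delivered immediately
            self.buffered: List[Any] = []  # messages captured for later replay

        def send(self, msg: Any) -> None:
            if not self.recording:
                self.sent.append(msg)      # no active block: deliver now
            else:
                self.buffered.append(msg)  # recording: buffer for replay

    c = ToyClient()
    c.send("CallFunction")
    c.recording = True
    c.send("BorrowFirstUse")
    assert c.sent == ["CallFunction"]
    assert c.buffered == ["BorrowFirstUse"]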
+     def send_nocoalesce(
+         self,
+         ranks: Union[NDSlice, List[NDSlice]],
+         msg: NamedTuple,
+     ) -> None:
+         self.inner.send(ranks, msg)
+
+     def reset_recorder(self) -> "Recorder":
+         old, self.recorder = self.recorder, Recorder()
+         return old
+
+     def drop_borrow(self, borrow: "Borrow") -> None:
+         if not _coalescing.is_active(self):
+             return
+         if borrow._id not in self.recorder.borrow_entries_created:
+             tb = borrow.traceback_string
+             raise RuntimeError(
+                 f"Borrow Traceback:\n{tb}Cannot drop a borrow while repeating a coalesced block because it would cause the borrow to drop multiple times. "
+             )
+         del self.recorder.borrow_entries_created[borrow._id]
+
+     def new_borrow(self, borrow_entry: "Borrow") -> None:
+         if not _coalescing.is_active(self):
+             return
+         self.recorder.borrow_entries_created[borrow_entry._id] = borrow_entry
+
+     @property
+     def all_ranks(self) -> NDSlice:
+         return NDSlice(offset=0, sizes=[self._world_size], strides=[1])
+
+     @property
+     def gpu_per_host(self) -> int:
+         return self._gpu_per_host
+
+     # Shut down everything, including the client, system, controller, and workers.
+     # The shutdown procedure waits for all messages to be processed by the
+     # workers, then stops the system.
+     def shutdown(
+         self,
+         destroy_pg: bool = True,
+         error_reason: Optional[RemoteException | DeviceException | Exception] = None,
+     ) -> None:
+         if self.has_shutdown:
+             return
+         logger.info("shutting down the client gracefully")
+
+         atexit.unregister(self._atexit)
+         self._shutdown = True
+
+         # request status for the last sent seq, and wait for the result to make
+         # sure all seqs are processed.
+         if self.last_assigned_seq > self.last_processed_seq:
+             self._request_status()
+
+         # send an Exit message to stop the workers, then wait briefly for the
+         # workers to exit with the correct exit code before we stop the system.
+         self.send(self.all_ranks, messages.Exit(destroy_pg, error_reason))
+         time.sleep(2)
+
+         # put an overall timeout on the shutdown wait for now; a better shutdown
+         # for multi-mesh setups will be implemented later.
+         timeout = 60
+         start_time = time.time()
+
+         try:
+             while (
+                 time.time() - start_time < timeout
+                 and self.last_assigned_seq > self.last_processed_seq
+             ):
+                 # TODO(T216336422): retire client::drain_and_stop() as it doesn't
+                 # really drain all messages
+                 output = self.inner.next_message(1.0)
+                 if output is not None:
+                     if isinstance(output, MessageResult):
+                         # restart the timer as we got a new result back
+                         start_time = time.time()
+                         self._handle_pending_result(output)
+                     elif isinstance(output, LogMessage):
+                         self._log_message(output)
+
+             # Drain any remaining messages in the client queue (if any)
+             for output in self.inner.drain_and_stop():
+                 if isinstance(output, MessageResult):
+                     self._handle_pending_result(output)
+                 elif isinstance(output, LogMessage):
+                     self._log_message(output)
+         except DeviceException:
+             # an exception while draining messages should be ignored during
+             # shutdown, as we are shutting down the system anyway
+             logger.warning(
+                 "exception in message draining during shutdown, "
+                 "ignoring it and continuing to stop the system"
+             )
+
+         # all messages are processed (or the timeout was hit); stop the system
+         if time.time() - start_time >= timeout:
+             logger.warning(
+                 "timeout waiting for all messages to be processed, "
+                 "stopping the mesh anyway"
+             )
+         else:
+             logger.info("all messages are processed, stopping the mesh")
+         self.inner.stop_mesh()
+
+     @property
+     def has_shutdown(self) -> bool:
+         return self._shutdown
+
+     def new_ref(self) -> int:
+         r = next(self.next_ref)
+         if _coalescing.is_active(self):
+             self.recorder.first_ref = min(self.recorder.first_ref, r)
+         return r
+
+     def handle_deletes(
+         self,
+         ranks: Union[NDSlice, List[NDSlice]],
+         refs: List[int],
+         coalesce: bool = True,
+     ):
+         if coalesce:
+             self.send(ranks, messages.DeleteRefs(refs))
+         else:
+             self.send_nocoalesce(ranks, messages.DeleteRefs(refs))
+         self.inner.drop_refs([tensor_worker.Ref(id=ref) for ref in refs])
+
+     def flush_deletes(self, coalesce: bool = True):
+         for mesh, refs in self._pending_del.items():
+             self.handle_deletes(mesh.processes, refs, coalesce)
+         self._pending_del.clear()
+
+     def delete_ref(self, device_mesh: DeviceMesh, ref: int) -> None:
+         self._pending_del[device_mesh].append(ref)
+
+     @property
+     def aliases(self) -> WeakKeyDictionary[torch.UntypedStorage, StorageAliases]:
+         return self._aliases
+
+     def _request_status(self):
+         self.send(
+             self.all_ranks,
+             messages.RequestStatus(self.last_assigned_seq, False),
+         )
+
+     def handle_next_message(self, timeout: Optional[float]) -> bool:
+         output = self.inner.next_message(timeout)
+         if output is not None:
+             if isinstance(output, MessageResult):
+                 self._handle_pending_result(output)
+             elif isinstance(output, LogMessage):
+                 self._log_message(output)
+             return True
+         return False
+
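The polling contract of handle_next_message() lends itself to a simple driving loop. Here is a sketch with a hypothetical stub in place of a real TController:

    from typing import Optional

    class StubInner:
        """Hypothetical controller stub: yields queued messages, then None."""

        def __init__(self) -> None:
            self.queue = ["result-1", "result-2"]

        def next_message(self, timeout: Optional[float]):
            return self.queue.pop(0) if self.queue else None

    inner = StubInner()

    def handle_next_message(timeout: Optional[float]) -> bool:
        output = inner.next_message(timeout)
        if output is not None:
            print("processed", output)  # the real client dispatches on message type
            return True
        return False

    # Typical driving loop: keep polling while messages keep arriving.
    while handle_next_message(timeout=1.0):
        pass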
+     def _log_message(self, msg: LogMessage) -> None:
+         match msg.level:
+             case LogLevel.INFO:
+                 logger.info(msg.message)
+             case LogLevel.WARNING:
+                 logger.warning(msg.message)
+             case LogLevel.ERROR:
+                 logger.error(msg.message)
+
+     def _handle_pending_result(self, output: MessageResult) -> None:
+         result = output.result
+         seq = output.seq
+         error = output.error
+
+         self.last_processed_seq = max(self.last_processed_seq, seq)
+
+         if error is not None:
+             logger.info("Received error for seq %s: %s", seq, error)
+             self._pending_shutdown_error = error
+             # We should not have set a result if we have an error.
+             assert result is None
+             if not isinstance(error, RemoteException):
+                 raise error
+
+             # Populate controller tracebacks for the remote failure
+             original_frame_seq = error.seq
+             index = error.controller_frame_index
+             assert index is not None
+             # TODO: Populate tracebacks for dependent invocations
+             if original_frame_seq == seq:
+                 # The current invocation is the one causing the remote failure.
+                 # We should not have populated the tracebacks yet.
+                 assert error.controller_frames is None
+                 _, tracebacks = self.pending_results[original_frame_seq]
+                 assert tracebacks is not None
+                 assert (
+                     len(tracebacks) > index
+                 ), f"tracebacks contains {len(tracebacks)} frames, but index is {index}"
+                 error.controller_frames = tracebacks[index]
+
+         fut, _ = self.pending_results[seq]
+         if fut is not None:
+             if error is None:
+                 fut._set_result(result)
+             else:
+                 fut._set_result(error)
+                 self._pending_shutdown_error = None
+         elif result is not None:
+             logger.debug(f"{seq}: unused result {result}")
+         elif error is not None:
+             # errors get reported as results even if they
+             # do not have futures attached.
+             pass
+
+         # We can safely delete the seq, as the tracebacks have been saved to the
+         # remote failure itself.
+         del self.pending_results[seq]
+
+     def split_comm(self, dims, device_mesh, stream_ref) -> None:
+         """Create a split communicator group with the specified ranks, and
+         associate it with a specific device mesh and stream.
+         """
+         # For simplicity, just send this message to all ranks and split from the
+         # global communicator. As an optimization, the client could remember
+         # which comms have already been created and issue a message to a smaller
+         # set of ranks.
+         if not self._backend_network_init:
+             raise AssertionError(
+                 "split_comm called before backend network initialization"
+             )
+
+         msg = messages.SplitComm(tuple(sorted(dims)), device_mesh, stream_ref)
+         if msg in self.created_communicators:
+             return
+
+         self.send_nocoalesce(self.all_ranks, msg)
+         self.created_communicators.add(msg)
+
+     def backend_network_init(self) -> None:
+         if self._backend_network_init:
+             return
+         self._backend_network_init = True
+         logger.info("Initializing backend network")
+         self.send_nocoalesce(self.all_ranks, messages.BackendNetworkInit())
+
+     def backend_network_point_to_point_init(
+         self, from_stream_ref: "StreamRef", to_stream_ref: "StreamRef"
+     ) -> None:
+         key = (from_stream_ref, to_stream_ref)
+
+         if key in self._backend_network_init_point_to_point:
+             return
+         self._backend_network_init_point_to_point.add(key)
+         self.send_nocoalesce(
+             self.all_ranks,
+             messages.BackendNetworkPointToPointInit(from_stream_ref, to_stream_ref),
+         )
+
+     def new_node(
+         self,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+         future: Optional["Future"] = None,
+         tracebacks: Optional[List[List[traceback.FrameSummary]]] = None,
+     ) -> Seq:
+         for t in uses:
+             t._use()
+
+         if tracebacks is None:
+             tracebacks = [traceback.extract_stack()[:-2]]
+         if _coalescing.is_recording(self):
+             assert future is None, "this should have been checked in fetch shard"
+             return self.recorder.add(defs, uses, tracebacks[0])
+         else:
+             return self.new_node_nocoalesce(defs, uses, future, tracebacks)
+
+     def new_node_nocoalesce(
+         self,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+         future: Optional["Future"],
+         tracebacks: List[List[traceback.FrameSummary]],
+     ) -> Seq:
+         seq = self._next_seq()
+         self.pending_results[seq] = (future, tracebacks)
+         for d in defs:
+             d._seq = seq
+         self.inner.node(seq, defs, uses)
+         return seq
+
+     def _next_seq(self) -> Seq:
+         self.last_assigned_seq = next(self.seq_gen)
+         return self.last_assigned_seq
+
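A small self-contained model of the sequence-number accounting used throughout the class (a sketch, not the package's API): each sent invocation bumps last_assigned_seq, acks advance last_processed_seq, and the client is quiescent once the two meet.

    import itertools

    seq_gen = itertools.count()
    last_assigned_seq = -1
    last_processed_seq = -1

    def send_invocation() -> int:
        global last_assigned_seq
        last_assigned_seq = next(seq_gen)  # mirrors Client._next_seq()
        return last_assigned_seq

    def ack(seq: int) -> None:
        global last_processed_seq
        # mirrors the update at the top of _handle_pending_result()
        last_processed_seq = max(last_processed_seq, seq)

    send_invocation()
    send_invocation()
    assert last_assigned_seq > last_processed_seq   # work still outstanding
    ack(0)
    ack(1)
    assert last_assigned_seq == last_processed_seq  # all messages processed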
+     def _atexit(self) -> None:
+         logger.warning(
+             "Client is not shutting down properly before atexit. "
+             "This may be due to an exception or because device_mesh.exit() "
+             "was not called."
+         )
+         # Calling self.shutdown may cause a deadlock if something is wrong with
+         # the networking. Or should we make shutdown() not wait indefinitely?
+         self._shutdown = True
+
+         # send a shutdown message to stop the other processes.
+         self.inner.stop_mesh()
+
+     def no_coalescing(self, reason):
+         if _coalescing.is_active(self):
+             raise NotImplementedError(f"NYI: {reason} during a coalescing block")
+
+     def mesh_state(self) -> WorldState:
+         return self.inner.worker_world_state()
+
+     def fetch(
+         self,
+         mesh: "DeviceMesh",
+         stream: "StreamRef",
+         shard,
+         preprocess_message,
+         args,
+         kwargs,
+         defs: Tuple["Tensor", ...],
+         uses: Tuple["Tensor", ...],
+     ) -> "Future":
+         fut = Future(self)
+         ident = self.new_node(defs, uses, fut)
+         process = mesh._process(shard)
+         self.send(
+             process,
+             messages.SendValue(
+                 ident,
+                 None,
+                 defs,
+                 preprocess_message,
+                 args,
+                 kwargs,
+                 stream,
+             ),
+         )
+         # we have to ask for status updates from the workers to be sure they
+         # have finished enough work to count this future as finished, and that
+         # all potential errors have been reported
+         self._request_status()
+         return fut
+
+
+ def tree_map_refs(first_ref: int, tree):
+     def translate_id(ref: int) -> int:
+         diff = ref - first_ref
+         if diff >= 0:
+             return -1 - diff
+         return ref
+
+     def translate_ref(obj):
+         match obj:
+             case Ref():
+                 return translate_id(obj.id)
+             case Referenceable():
+                 return None if obj.ref is None else translate_id(obj.ref)
+             case messages.DeleteRefs():
+                 # Python destructors may not run in a deterministic order across
+                 # traces of a recorded function, so we need to sort the refs to
+                 # ensure a fair comparison during validation.
+                 return messages.DeleteRefs(sorted([translate_id(r) for r in obj.refs]))
+             case messages.BorrowCreate():
+                 result, borrow, *rest = [translate_ref(x) for x in obj]
+                 return messages.BorrowCreate(result, translate_id(borrow), *rest)
+             case messages.BorrowDrop():
+                 return messages.BorrowDrop(translate_id(obj.borrow))
+             case messages.BorrowFirstUse():
+                 return messages.BorrowFirstUse(translate_id(obj.borrow))
+             case messages.BorrowLastUse():
+                 return messages.BorrowLastUse(translate_id(obj.borrow))
+             case _:
+                 return obj
+
+     return tree_map(
+         translate_ref,
+         tree,
+         is_leaf=lambda x: isinstance(
+             x,
+             (
+                 Ref,
+                 Referenceable,
+                 messages.DeleteRefs,
+                 messages.BorrowCreate,
+                 messages.BorrowDrop,
+                 messages.BorrowFirstUse,
+                 messages.BorrowLastUse,
+             ),
+         ),
+     )
+
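To make the translation concrete, here is translate_id's arithmetic restated standalone: refs allocated at or after first_ref (i.e. created inside the recording) are mapped onto a stable negative canonical space, so two traces compare equal regardless of absolute ref numbering, while refs captured from outside pass through unchanged.

    def translate_id(ref: int, first_ref: int) -> int:
        diff = ref - first_ref
        if diff >= 0:         # created inside the recording
            return -1 - diff  # canonical ids: -1, -2, -3, ...
        return ref            # captured from outside: unchanged

    assert translate_id(10, first_ref=10) == -1  # first ref of the recording
    assert translate_id(12, first_ref=10) == -3  # third ref of the recording
    assert translate_id(7, first_ref=10) == 7    # pre-existing ref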
+ class Recorder:
+     def __init__(self):
+         self.borrow_entries_created: Dict[int, Borrow] = {}
+         self.messages: List[Tuple[Union[NDSlice, List[NDSlice]], NamedTuple]] = []
+         # these tables track the externally captured tensors that we
+         # use and mutate whenever this recording is run.
+         self.uses = {}  # ordered set
+         self.mutates = {}  # ordered set
+         self.creates: List[weakref.ref] = []
+         self.tracebacks = []
+         self.first_ref: float = math.inf
+         self.reference_recording: Optional["Recording"] = None
+         # Map from formal tensor storage to its corresponding argument indices
+         # in the recording input (there may be multiple aliases of the same
+         # tensor in the recording input).
+         self.formal_storages_to_indices: defaultdict[
+             torch.UntypedStorage, List[int]
+         ] = defaultdict(list)
+         # Set of tensor storages for formals that are mutated during the recording.
+         self.mutated_formal_storages: Set[torch.UntypedStorage] = set()
+
+     def add_formal(self, formal: Tensor, argument_index: int) -> None:
+         self.formal_storages_to_indices[formal._fake.untyped_storage()].append(
+             argument_index
+         )
+
+     def add(
+         self,
+         defs: Sequence["Tensor"],
+         uses: Sequence["Tensor"],
+         traceback: List[traceback.FrameSummary],
+     ):
+         for u in uses:
+             if u._seq is None:
+                 # a lack of sequence num on a tensor means it was created within
+                 # the recording and does not have to be tracked as a use
+                 continue
+             self.uses[u] = None
+         for d in defs:
+             # a lack of sequence num means the tensor doesn't need to be tracked
+             # as a mutation, unless that tensor is an alias of a formal tensor
+             if d._seq is None:
+                 self.creates.append(weakref.ref(d))
+                 storage = d._fake.untyped_storage()
+                 if storage in self.formal_storages_to_indices:
+                     self.mutated_formal_storages.add(storage)
+             else:
+                 self.mutates[d] = None
+         self.tracebacks.append(traceback)
+         return len(self.tracebacks) - 1
+
+     def _check(self):
+         if self.borrow_entries_created:
+             tbs = "------------\n".join(
+                 b.traceback_string for b in self.borrow_entries_created.values()
+             )
+             raise RuntimeError(
+                 f"Borrows created during a recorded coalesced block need to be dropped before the block ends. Tracebacks of where the borrows were created: {tbs}"
+             )
+
+     @property
+     def flat_messages(self):
+         return flatten_messages(self.messages)
+
+     def run_once(self, client: "Client"):
+         self._check()
+         for rank, msgs in self.flat_messages.items():
+             client.send_nocoalesce(
+                 NDSlice(offset=rank, sizes=[], strides=[]), messages.CommandGroup(msgs)
+             )
+
+     def abandon(self):
+         # an error happened and we will not use this recording. Every tensor
+         # created as part of this recording has never been defined, so we blank
+         # out the .ref to disarm the deletions.
+         for w in self.creates:
+             v = w()
+             if v is not None:
+                 v.ref = None
+
+     def add_message(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple):
+         if isinstance(msg, messages.RecordingFormal):
+             self.add_formal(cast(Tensor, msg.result), msg.argument_index)
+
+         # this is pretty expensive, but we can't hold tensor references without
+         # extending their lifetime unnecessarily, so they must be converted to
+         # refs here. This also prevents a bug where a tensor is dropped after a
+         # message is recorded and no longer has a ref field.
+         msg = tree_map(
+             lambda x: (
+                 Ref(x.ref) if isinstance(x, Tensor) and x.ref is not None else x
+             ),
+             msg,
+         )
+         self.messages.append((ranks, msg))
+         reference_recording = self.reference_recording
+         if reference_recording is not None:
+             last_index = len(self.messages) - 1
+             reference_messages = reference_recording.buffered_messages
+             mine = self.messages[last_index]
+             theirs = (
+                 reference_messages[last_index]
+                 if len(reference_messages) > last_index
+                 else None
+             )
+             mine = tree_map_refs(self.first_ref, mine)
+             theirs = tree_map_refs(reference_recording.first_ref, theirs)
+             if mine != theirs:
+                 traceback_index = len(self.tracebacks) - 1
+
+                 tb_mine = traceback.format_list(self.tracebacks[traceback_index])
+                 while tb_mine and "in _record_and_define" not in tb_mine[0]:
+                     tb_mine.pop(0)
+
+                 tb_theirs = traceback.format_list(
+                     reference_recording.tracebacks[traceback_index]
+                 )
+                 while tb_theirs and "in _record_and_define" not in tb_theirs[0]:
+                     tb_theirs.pop(0)
+
+                 the_diff = "\n".join(difflib.ndiff([str(theirs)], [str(mine)]))
+                 raise RuntimeError(
+                     f"monarch.compiled failed to verify recording. Recording diverges at operation {last_index}.\n{the_diff}\n\nTraceback of original recording\n{''.join(tb_theirs)}\n\nTraceback of second recording\n{''.join(tb_mine)}\n"
+                 )
+
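The divergence report above relies on difflib.ndiff from the standard library. A standalone illustration with made-up message strings shows the kind of output surfaced in the RuntimeError:

    import difflib

    theirs = "CallFunction(ident=-1, args=(Ref(-2),))"
    mine = "CallFunction(ident=-1, args=(Ref(-3),))"

    # ndiff marks the reference line with '-', the diverging line with '+',
    # and uses '?' guide lines to point at the changed characters.
    print("\n".join(difflib.ndiff([theirs], [mine])))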
+     def verify_against(self, reference: Recording):
+         self.reference_recording = reference
+
+     def define_recording(
+         self,
+         client: "Client",
+         nresults: int,
+         nformals: int,
+     ) -> Recording:
+         self._check()
+         # any remaining references to tensors we defined in the recording are
+         # not valid for future use outside the recording, so drop them
+         # such that we report an error if they are used.
+         for w in self.creates:
+             v = w()
+             if v is not None:
+                 v._drop_ref()
+         # It should be safe to use a list instead of a set here, since
+         # no entry in formal_storages_to_indices should have any overlap
+         # with any other entry. So mutated_formal_indices should automatically
+         # have unique elements.
+         mutated_formal_indices = []
+         for storage in self.mutated_formal_storages:
+             mutated_formal_indices.extend(self.formal_storages_to_indices[storage])
+         return Recording(
+             client,
+             list(self.uses.keys()),
+             list(self.mutates.keys()),
+             sorted(mutated_formal_indices),
+             self.tracebacks,
+             self.messages,
+             nresults,
+             nformals,
+             self.first_ref,
+         )
monarch/common/constants.py
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ SIM_MESH_CLIENT_TIMEOUT = 5
+ SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL = 5