torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,646 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import atexit
9
+ import difflib
10
+ import itertools
11
+ import logging
12
+ import math
13
+ import time
14
+ import traceback
15
+ import weakref
16
+ from collections import defaultdict
17
+ from typing import (
18
+ Callable,
19
+ cast,
20
+ Dict,
21
+ List,
22
+ NamedTuple,
23
+ Optional,
24
+ Sequence,
25
+ Set,
26
+ Tuple,
27
+ TYPE_CHECKING,
28
+ Union,
29
+ )
30
+
31
+ from weakref import WeakKeyDictionary
32
+
33
+ import torch
34
+ import torch.distributed
35
+ from monarch._rust_bindings.monarch_extension import tensor_worker
36
+ from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension
37
+ LogLevel,
38
+ WorldState,
39
+ )
40
+ from monarch.common import messages
41
+ from monarch.common.borrows import Borrow, StorageAliases
42
+ from monarch.common.controller_api import LogMessage, MessageResult, TController
43
+ from monarch.common.device_mesh import DeviceMesh
44
+ from monarch.common.invocation import DeviceException, RemoteException, Seq
45
+ from monarch.common.recording import flatten_messages, Recording
46
+
47
+ from monarch.common.reference import Ref, Referenceable
48
+ from monarch.common.shape import NDSlice
49
+ from monarch.common.stream import StreamRef
50
+ from monarch.common.tensor import Tensor
51
+ from monarch.common.tree import tree_map
52
+
53
+ from . import _coalescing
54
+
55
+ if TYPE_CHECKING:
56
+ from monarch.common.future import Future
57
+
58
+
59
logger = logging.getLogger(__name__)

# Interval, in seconds, used as the TTL for the client's controller-status
# countdown (consumed via TTL() in Client.__init__, which measures with
# time.time()).
_CONTROLLER_STATUS_INTERVAL = 2
62
+
63
+
64
def TTL(timeout: Optional[float]) -> Callable[[], float]:
    """Build a countdown clock for ``timeout`` seconds.

    Returns a zero-argument callable reporting the seconds remaining until
    the deadline, clamped at 0. A ``None`` timeout means "never expires":
    the callable always reports ``math.inf``.
    """
    if timeout is None:
        return lambda: math.inf
    deadline = time.time() + timeout

    def remaining() -> float:
        return max(deadline - time.time(), 0)

    return remaining
69
+
70
+
71
+ class Client:
72
+ def __init__(
73
+ self,
74
+ controller: TController,
75
+ world_size: int,
76
+ gpu_per_host: int,
77
+ ):
78
+ self.inner = controller
79
+ self._world_size = world_size
80
+ self._gpu_per_host = gpu_per_host
81
+ self.next_ref = itertools.count()
82
+ self.failures: Dict[int, Dict[int, RemoteException]] = defaultdict(dict)
83
+ self._pending_del: Dict[DeviceMesh, List[int]] = defaultdict(list)
84
+ self._shutdown = False
85
+ self.controller_status_ttl = TTL(_CONTROLLER_STATUS_INTERVAL)
86
+ self._aliases: WeakKeyDictionary[torch.UntypedStorage, StorageAliases] = (
87
+ WeakKeyDictionary()
88
+ )
89
+
90
+ # stream._active = Stream("main2", _default=True)
91
+
92
+ self._backend_network_init = False
93
+ self._backend_network_init_point_to_point: Set[
94
+ Tuple["StreamRef", "StreamRef"]
95
+ ] = set()
96
+
97
+ self.seq_gen = itertools.count()
98
+ # seq of the most recent message that was sent to controller
99
+ self.last_assigned_seq = -1
100
+ # seq of the last acked message from controller, ack message is initiated
101
+ # by the _request_status() call. By comparing last_processed_seq and
102
+ # last_assigned_seq, we can tell if all messages are processed by all
103
+ # workers.
104
+ self.last_processed_seq = -1
105
+
106
+ self.recorder = Recorder()
107
+
108
+ self.pending_results: Dict[
109
+ Seq, # seq of an invocation
110
+ Tuple[
111
+ Optional["Future"], # future to set
112
+ List[List[traceback.FrameSummary]], # local call stacks
113
+ ],
114
+ ] = {}
115
+ atexit.register(self._atexit)
116
+ self.created_communicators = set()
117
+
118
+ def send(
119
+ self,
120
+ ranks: Union[NDSlice, List[NDSlice]],
121
+ msg: NamedTuple,
122
+ ) -> None:
123
+ if not _coalescing.is_active(self):
124
+ return self.send_nocoalesce(ranks, msg)
125
+ if _coalescing.is_recording(self):
126
+ match msg:
127
+ case messages.BorrowFirstUse() if msg.borrow not in self.recorder.borrow_entries_created:
128
+ return self.send_nocoalesce(ranks, msg)
129
+ case messages.BorrowLastUse() if msg.borrow not in self.recorder.borrow_entries_created:
130
+ raise ValueError(
131
+ "cannot explicitly drop a tensor inside a compiled block that was borrowed outside of it."
132
+ )
133
+ self.recorder.add_message(ranks, msg)
134
+
135
+ def send_nocoalesce(
136
+ self,
137
+ ranks: Union[NDSlice, List[NDSlice]],
138
+ msg: NamedTuple,
139
+ ) -> None:
140
+ self.inner.send(ranks, msg)
141
+
142
+ def reset_recorder(self) -> "Recorder":
143
+ old, self.recorder = self.recorder, Recorder()
144
+ return old
145
+
146
+ def drop_borrow(self, borrow: "Borrow") -> None:
147
+ if not _coalescing.is_active(self):
148
+ return
149
+ if borrow._id not in self.recorder.borrow_entries_created:
150
+ tb = borrow.traceback_string
151
+ raise RuntimeError(
152
+ f"Borrow Traceback:\n{tb}Cannot drop a borrow while repeating a coalesced block because it would cause the borrow to drop multiple times. "
153
+ )
154
+ del self.recorder.borrow_entries_created[borrow._id]
155
+
156
+ def new_borrow(self, borrow_entry: "Borrow") -> None:
157
+ if not _coalescing.is_active(self):
158
+ return
159
+ self.recorder.borrow_entries_created[borrow_entry._id] = borrow_entry
160
+
161
+ @property
162
+ def all_ranks(self) -> NDSlice:
163
+ return NDSlice(offset=0, sizes=[self._world_size], strides=[1])
164
+
165
+ @property
166
+ def gpu_per_host(self) -> int:
167
+ return self._gpu_per_host
168
+
169
+ # shut down everything, including client/system/controller/workers.
170
+ # the shutdown procedure will wait for all messages to be processed
171
+ # by the worker, then stop the system.
172
+ def shutdown(
173
+ self,
174
+ destroy_pg: bool = True,
175
+ error_reason: Optional[RemoteException | DeviceException | Exception] = None,
176
+ ) -> None:
177
+ logger.info("shutting down the client gracefully")
178
+
179
+ atexit.unregister(self._atexit)
180
+ self._shutdown = True
181
+
182
+ # request status for the last sent seq, and wait for the result to make sure all
183
+ # seqs are processed.
184
+ if self.last_assigned_seq > self.last_processed_seq:
185
+ self._request_status()
186
+
187
+ # send Exit message to stop the workers, wait for a bit for the workers to Exit
188
+ # with the correct exit code before we stop the system.
189
+ self.send(self.all_ranks, messages.Exit(destroy_pg, error_reason))
190
+ time.sleep(2)
191
+
192
+ # put a overall timeout on the shutdown waiting for now, better shutdown for
193
+ # multi-mesh setup will be implemented later.
194
+ timeout = 60
195
+ start_time = time.time()
196
+
197
+ try:
198
+ while (
199
+ time.time() - start_time < timeout
200
+ and self.last_assigned_seq > self.last_processed_seq
201
+ ):
202
+ # TODO(T216336422): retire client::drain_and_stop() as it doesn't
203
+ # really drain all messages
204
+ output = self.inner.next_message(1.0)
205
+ if output is not None:
206
+ if isinstance(output, MessageResult):
207
+ # restart the timer as we got new result back
208
+ start_time = time.time()
209
+ self._handle_pending_result(output)
210
+ elif isinstance(output, LogMessage):
211
+ self._log_message(output)
212
+
213
+ # Drain any remaining message in client queue (if any)
214
+ for output in self.inner.drain_and_stop():
215
+ if isinstance(output, MessageResult):
216
+ self._handle_pending_result(output)
217
+ elif isinstance(output, LogMessage):
218
+ self._log_message(output)
219
+ except DeviceException:
220
+ # exception in message draining should be ignored during shutdown, as
221
+ # we are shutting down the system anyway
222
+ logger.warning(
223
+ "exception in message draining during shutdown, "
224
+ "ignoring and continue to stop the system"
225
+ )
226
+ pass
227
+
228
+ # all messages are processed, we can now stop the system
229
+ if time.time() - start_time >= timeout:
230
+ logger.warning(
231
+ "timeout waiting for all messages to be processed, "
232
+ "stop the mesh anyway"
233
+ )
234
+ else:
235
+ logger.info("all messages are processed, stop the mesh")
236
+ self.inner.stop_mesh()
237
+
238
+ @property
239
+ def has_shutdown(self) -> bool:
240
+ return self._shutdown
241
+
242
+ def new_ref(self) -> int:
243
+ r = next(self.next_ref)
244
+ if _coalescing.is_active(self):
245
+ self.recorder.first_ref = min(self.recorder.first_ref, r)
246
+ return r
247
+
248
+ def handle_deletes(
249
+ self,
250
+ ranks: Union[NDSlice, List[NDSlice]],
251
+ refs: List[int],
252
+ coalesce: bool = True,
253
+ ):
254
+ if coalesce:
255
+ self.send(ranks, messages.DeleteRefs(refs))
256
+ else:
257
+ self.send_nocoalesce(ranks, messages.DeleteRefs(refs))
258
+ self.inner.drop_refs([tensor_worker.Ref(id=ref) for ref in refs])
259
+
260
+ def flush_deletes(self, coalesce: bool = True):
261
+ for mesh, refs in self._pending_del.items():
262
+ self.handle_deletes(mesh.processes, refs, coalesce)
263
+ self._pending_del.clear()
264
+
265
+ def delete_ref(self, device_mesh: DeviceMesh, ref: int) -> None:
266
+ self._pending_del[device_mesh].append(ref)
267
+
268
+ @property
269
+ def aliases(self) -> WeakKeyDictionary[torch.UntypedStorage, StorageAliases]:
270
+ return self._aliases
271
+
272
+ def _request_status(self):
273
+ self.send(
274
+ self.all_ranks,
275
+ messages.RequestStatus(self.last_assigned_seq, False),
276
+ )
277
+
278
+ def handle_next_message(self, timeout: Optional[float]) -> bool:
279
+ output = self.inner.next_message(timeout)
280
+ if output is not None:
281
+ if isinstance(output, MessageResult):
282
+ self._handle_pending_result(output)
283
+ elif isinstance(output, LogMessage):
284
+ self._log_message(output)
285
+ return True
286
+ return False
287
+
288
+ def _log_message(self, msg: LogMessage) -> None:
289
+ match msg.level:
290
+ case LogLevel.INFO:
291
+ logger.info(msg.message)
292
+ case LogLevel.WARNING:
293
+ logger.warning(msg.message)
294
+ case LogLevel.ERROR:
295
+ logger.error(msg.message)
296
+
297
+ def _handle_pending_result(self, output: MessageResult) -> None:
298
+ result = output.result
299
+ seq = output.seq
300
+ error = output.error
301
+
302
+ self.last_processed_seq = max(self.last_processed_seq, seq)
303
+
304
+ if error is not None:
305
+ logging.error("Received error for seq %s: %s", seq, error)
306
+ # We should not have set result if we have an error.
307
+ assert result is None
308
+ if not isinstance(error, RemoteException):
309
+ raise error
310
+
311
+ # Populate controller tracebacks for the remote failure
312
+ original_frame_seq = error.seq
313
+ index = error.controller_frame_index
314
+ assert index is not None
315
+ # TODO: Populate tracebacks for dependent invocations
316
+ if original_frame_seq == seq:
317
+ # The current invocation is the one causing the remote failure.
318
+ # We should have not populated the tracebacks yet.
319
+ assert error.controller_frames is None
320
+ _, tracebacks = self.pending_results[original_frame_seq]
321
+ assert tracebacks is not None
322
+ assert (
323
+ len(tracebacks) > index
324
+ ), f"tracebacks contains {len(tracebacks)} frames, but index is {index}"
325
+ error.controller_frames = tracebacks[index]
326
+
327
+ fut, _ = self.pending_results[seq]
328
+ if fut is not None:
329
+ fut._set_result(result if error is None else error)
330
+ elif result is not None:
331
+ logger.debug(f"{seq}: unused result {result}")
332
+ elif error is not None:
333
+ # errors get reported as results even if they
334
+ # do not have futures attached.
335
+ logger.warning(
336
+ f"Error encountered for this instruction {seq}. Proceeding forward because error is unused and unhandled. Error details:\n{error}."
337
+ )
338
+
339
+ # We can safely delete the seq as tracebacks have been saved to the remote failure itself.
340
+ del self.pending_results[seq]
341
+
342
+ def split_comm(self, dims, device_mesh, stream_ref) -> None:
343
+ """Create a split communicator group with the specified ranks, and
344
+ associate it with a specific device mesh and stream.
345
+ """
346
+ # For simplicity, just send this message to all ranks and split from the
347
+ # global communicator. As an optimization, the client could remember
348
+ # which comms have already been created and issue a message to a smaller
349
+ # set of ranks.
350
+ if not self._backend_network_init:
351
+ raise AssertionError(
352
+ "split_comm called before backend network initialization"
353
+ )
354
+
355
+ msg = messages.SplitComm(tuple(sorted(dims)), device_mesh, stream_ref)
356
+ if msg in self.created_communicators:
357
+ return
358
+
359
+ self.send_nocoalesce(self.all_ranks, msg)
360
+ self.created_communicators.add(msg)
361
+
362
+ def backend_network_init(self) -> None:
363
+ if self._backend_network_init:
364
+ return
365
+ self._backend_network_init = True
366
+ logger.info("Initializing backend network")
367
+ self.send_nocoalesce(self.all_ranks, messages.BackendNetworkInit())
368
+
369
+ def backend_network_point_to_point_init(
370
+ self, from_stream_ref: "StreamRef", to_stream_ref: "StreamRef"
371
+ ) -> None:
372
+ key = (from_stream_ref, to_stream_ref)
373
+
374
+ if key in self._backend_network_init_point_to_point:
375
+ return
376
+ self._backend_network_init_point_to_point.add(key)
377
+ self.send_nocoalesce(
378
+ self.all_ranks,
379
+ messages.BackendNetworkPointToPointInit(from_stream_ref, to_stream_ref),
380
+ )
381
+
382
+ def new_node(
383
+ self,
384
+ defs: Sequence["Tensor"],
385
+ uses: Sequence["Tensor"],
386
+ future: Optional["Future"] = None,
387
+ tracebacks: Optional[List[List[traceback.FrameSummary]]] = None,
388
+ ) -> Seq:
389
+ for t in uses:
390
+ t._use()
391
+
392
+ if tracebacks is None:
393
+ tracebacks = [traceback.extract_stack()[:-2]]
394
+ if _coalescing.is_recording(self):
395
+ assert future is None, "this should have been checked in fetch shard"
396
+ return self.recorder.add(defs, uses, tracebacks[0])
397
+ else:
398
+ return self.new_node_nocoalesce(defs, uses, future, tracebacks)
399
+
400
+ def new_node_nocoalesce(
401
+ self,
402
+ defs: Sequence["Tensor"],
403
+ uses: Sequence["Tensor"],
404
+ future: Optional["Future"],
405
+ tracebacks: List[List[traceback.FrameSummary]],
406
+ ) -> Seq:
407
+ seq = self._next_seq()
408
+ self.pending_results[seq] = (future, tracebacks)
409
+ for d in defs:
410
+ d._seq = seq
411
+ self.inner.node(seq, defs, uses)
412
+ return seq
413
+
414
+ def _next_seq(self) -> Seq:
415
+ self.last_assigned_seq = next(self.seq_gen)
416
+ return self.last_assigned_seq
417
+
418
+ def _atexit(self) -> None:
419
+ logger.warning(
420
+ "Client is not shutting down properly before atexit. "
421
+ "This may be due to an exception or because device_mesh.exit() "
422
+ "was not called."
423
+ )
424
+ # Calling self.shutdown may cause a deadlock if something is wrong with
425
+ # the networking. Or should we make shutdown() not wait indefinitely?
426
+ self._shutdown = True
427
+
428
+ # send shutdown message to stop other processes.
429
+ self.inner.stop_mesh()
430
+
431
+ def no_coalescing(self, reason):
432
+ if _coalescing.is_active(self):
433
+ raise NotImplementedError(f"NYI: {reason} during a coalescing block")
434
+
435
+ def mesh_state(self) -> WorldState:
436
+ return self.inner.worker_world_state()
437
+
438
+
439
def tree_map_refs(first_ref: int, tree):
    """Normalize reference ids in a message tree for comparison.

    Refs allocated during a recording (id >= ``first_ref``) are remapped to
    negative, recording-relative ids so that two traces of the same recorded
    function compare equal even when their absolute ref counters differ.
    Refs created before the recording are left unchanged.
    """

    def translate_id(ref: int) -> int:
        diff = ref - first_ref
        if diff >= 0:
            # Recording-relative: first_ref -> -1, first_ref + 1 -> -2, ...
            return -1 - diff
        return ref

    def translate_ref(obj):
        match obj:
            case Ref():
                return translate_id(obj.id)
            case Referenceable():
                return None if obj.ref is None else translate_id(obj.ref)
            case messages.DeleteRefs():
                # Python destructors may not run in a deterministic order across
                # traces of a recorded function, so we need to sort the refs to ensure
                # a fair comparison during validation.
                return messages.DeleteRefs(sorted([translate_id(r) for r in obj.refs]))
            case messages.BorrowCreate():
                # Translate every field, then additionally translate the
                # borrow id itself (it is a raw int, not a Ref).
                result, borrow, *rest = [translate_ref(x) for x in obj]
                return messages.BorrowCreate(result, translate_id(borrow), *rest)
            case messages.BorrowDrop():
                return messages.BorrowDrop(translate_id(obj.borrow))
            case messages.BorrowFirstUse():
                return messages.BorrowFirstUse(translate_id(obj.borrow))
            case messages.BorrowLastUse():
                return messages.BorrowLastUse(translate_id(obj.borrow))
            case _:
                return obj

    return tree_map(
        translate_ref,
        tree,
        # Treat ref-bearing objects and whole borrow/delete messages as
        # leaves so translate_ref sees them intact.
        is_leaf=lambda x: isinstance(
            x,
            (
                Ref,
                Referenceable,
                messages.DeleteRefs,
                messages.BorrowCreate,
                messages.BorrowDrop,
                messages.BorrowFirstUse,
                messages.BorrowLastUse,
            ),
        ),
    )
485
+
486
+
487
class Recorder:
    """Accumulates the messages and tensor bookkeeping for one coalesced
    (recorded) block, from which a ``Recording`` can be defined via
    ``define_recording`` or the buffered commands replayed via ``run_once``.
    """

    def __init__(self):
        # Borrows created inside the recording, keyed by borrow id; all of
        # them must be dropped before the block ends (enforced by _check).
        self.borrow_entries_created: Dict[int, Borrow] = {}
        # Buffered (ranks, message) pairs in send order.
        # NOTE(review): original annotation was the malformed
        # List[Union[NDSlice, List[NDSlice]], NamedTuple]; entries are
        # 2-tuples (see add_message), annotated accordingly here.
        self.messages: List[Tuple[Union[NDSlice, List[NDSlice]], NamedTuple]] = []
        # these tables track the externally captured tensors that we
        # use and mutate whenever this recording is run.
        self.uses = {}  # ordered set
        self.mutates = {}  # ordered set
        # Weak refs to tensors created inside the recording (see abandon /
        # define_recording).
        self.creates: List[weakref.ref] = []
        self.tracebacks = []
        # Smallest ref id allocated during the recording; stays math.inf
        # until the first ref is seen (hence float, not int).
        self.first_ref: float = math.inf
        self.reference_recording: Optional["Recording"] = None
        # Map from formal tensor storage to its corresponding argument indices
        # in the recording input (there may be multiple aliases of the same
        # tensor in the recording input).
        self.formal_storages_to_indices: defaultdict[
            torch.UntypedStorage, List[int]
        ] = defaultdict(list)
        # Set of tensor storages for formals that are mutated during the recording.
        self.mutated_formal_storages: Set[torch.UntypedStorage] = set()

    def add_formal(self, formal: Tensor, argument_index: int) -> None:
        """Register a formal (input) tensor of the recording by its storage."""
        self.formal_storages_to_indices[formal._fake.untyped_storage()].append(
            argument_index
        )

    def add(
        self,
        defs: Sequence["Tensor"],
        uses: Sequence["Tensor"],
        # NOTE(review): this parameter shadows the `traceback` module inside
        # the method body; kept as-is since the name is part of the signature.
        traceback: List[traceback.FrameSummary],
    ):
        """Record one invocation's defs/uses; return its index in tracebacks."""
        for u in uses:
            if u._seq is None:
                # a lack of sequence num on a tensor means it was created within
                # the recording and does not have to be tracked as a use
                continue
            self.uses[u] = None
        for d in defs:
            # a lack of sequence num means the tensor doesn't need to be tracked
            # as a mutates, unless that tensor is an alias of a formal tensor
            if d._seq is None:
                self.creates.append(weakref.ref(d))
                storage = d._fake.untyped_storage()
                if storage in self.formal_storages_to_indices:
                    self.mutated_formal_storages.add(storage)
            else:
                self.mutates[d] = None
        self.tracebacks.append(traceback)
        return len(self.tracebacks) - 1

    def _check(self):
        # All borrows created inside the block must have been dropped by now.
        if self.borrow_entries_created:
            tbs = "------------\n".join(
                b.traceback_string for b in self.borrow_entries_created.values()
            )
            raise RuntimeError(
                f"Borrows created during recorded coalesced block need to be dropped before the block ends. Tracebacks of where the blocks were created: {tbs}"
            )

    @property
    def flat_messages(self):
        # Per-rank flattening of the buffered (ranks, msg) pairs.
        return flatten_messages(self.messages)

    def run_once(self, client: "Client"):
        """Replay the buffered messages once, as one CommandGroup per rank."""
        self._check()
        for rank, msgs in self.flat_messages.items():
            client.send_nocoalesce(
                NDSlice(offset=rank, sizes=[], strides=[]), messages.CommandGroup(msgs)
            )

    def abandon(self):
        # an error happened and we will not use this recording. Every tensor created
        # as part of this recording has never been defined, so we blank out the
        # .ref to disarm the deletions.
        for w in self.creates:
            v = w()
            if v is not None:
                v.ref = None

    def add_message(self, ranks: Union[NDSlice, List[NDSlice]], msg: NamedTuple):
        """Buffer one message; if verifying, compare against the reference.

        Tensors inside the message are converted to plain Refs, and when a
        reference recording is attached (verify_against), the new message is
        checked for equality — modulo recording-relative ref ids — against
        the corresponding message of the reference, raising on divergence.
        """
        if isinstance(msg, messages.RecordingFormal):
            self.add_formal(cast(Tensor, msg.result), msg.argument_index)

        # this is pretty expensive, but we can't hold tensor references without
        # extending their lifetime unnecessarily, so they must be converted to
        # references here. It also prevents a bug when a tensor is dropped,
        # after a message is recorded and will no longer have a ref field.
        msg = tree_map(
            lambda x: (
                Ref(x.ref) if isinstance(x, Tensor) and x.ref is not None else x
            ),
            msg,
        )
        self.messages.append((ranks, msg))
        reference_recording = self.reference_recording
        if reference_recording is not None:
            last_index = len(self.messages) - 1
            reference_messages = reference_recording.buffered_messages
            mine = self.messages[last_index]
            theirs = (
                reference_messages[last_index]
                if len(reference_messages) > last_index
                else None
            )
            # Normalize both sides to recording-relative ref ids before
            # comparing, so differing absolute counters don't matter.
            mine = tree_map_refs(self.first_ref, mine)
            theirs = tree_map_refs(reference_recording.first_ref, theirs)
            if mine != theirs:
                traceback_index = len(self.tracebacks) - 1

                # Trim both tracebacks down to the user frames below
                # _record_and_define, for a readable diff.
                tb_mine = traceback.format_list(self.tracebacks[traceback_index])
                while tb_mine and "in _record_and_define" not in tb_mine[0]:
                    tb_mine.pop(0)

                tb_theirs = traceback.format_list(
                    reference_recording.tracebacks[traceback_index]
                )
                while tb_theirs and "in _record_and_define" not in tb_theirs[0]:
                    tb_theirs.pop(0)

                the_diff = "\n".join(difflib.ndiff([str(theirs)], [str(mine)]))
                raise RuntimeError(
                    f"monarch.compiled failed to verify recording. Recording diverges at operation {last_index}.\n{the_diff}\n\nTraceback of original recording\n{''.join(tb_theirs)}\n\nTraceback of second recording\n{''.join(tb_mine)}\n"
                )

    def verify_against(self, reference: Recording):
        """Attach a reference recording to compare against in add_message."""
        self.reference_recording = reference

    def define_recording(
        self,
        client: "Client",
        nresults: int,
        nformals: int,
    ) -> Recording:
        """Finalize this recorder's state into a Recording object."""
        self._check()
        # any remaining references to tensors we defined in the recording are
        # not valid for future use outside the recording, so drop them
        # such that we report an error if they are used.
        for w in self.creates:
            v = w()
            if v is not None:
                v._drop_ref()
        # It should be safe to use a list instead of a set here, since
        # no entry in formal_storages_to_indices should have any overlap
        # with any other entry. So mutated_formal_indices should automatically
        # have unique elements.
        mutated_formal_indices = []
        for storage in self.mutated_formal_storages:
            mutated_formal_indices.extend(self.formal_storages_to_indices[storage])
        return Recording(
            client,
            list(self.uses.keys()),
            list(self.mutates.keys()),
            sorted(mutated_formal_indices),
            self.tracebacks,
            self.messages,
            nresults,
            nformals,
            self.first_ref,
        )
@@ -0,0 +1,10 @@
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

# Timeout for sim-mesh client operations.
# NOTE(review): presumably seconds — confirm at the call sites.
SIM_MESH_CLIENT_TIMEOUT = 5
# Interval between sim-mesh client supervision updates.
# NOTE(review): presumably seconds — confirm at the call sites.
SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL = 5
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ from functools import wraps
9
+
10
+
11
+ class _ContextManager:
12
+ def __init__(self, generator):
13
+ self.generator = generator
14
+ self.generator.send(None)
15
+
16
+ def __enter__(self):
17
+ return
18
+
19
+ def __exit__(self, *args):
20
+ try:
21
+ self.generator.send(None)
22
+ except StopIteration:
23
+ pass
24
+ else:
25
+ raise RuntimeError("context manager generator did not exit")
26
+
27
+
28
+ def activate_first_context_manager(func):
29
+ """
30
+ Similar to contextlib.contextmanager but it
31
+ starts the context when the function is called rather than
32
+ than at the start of the with statement. Useful for things where
33
+ you want to optionally activate the context without a guard.
34
+ """
35
+
36
+ @wraps(func)
37
+ def helper(*args, **kwargs):
38
+ return _ContextManager(func(*args, **kwargs))
39
+
40
+ return helper