torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,572 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ from __future__ import annotations
10
+
11
+ from traceback import FrameSummary
12
+ from typing import (
13
+ cast,
14
+ Dict,
15
+ List,
16
+ Literal,
17
+ NamedTuple,
18
+ Optional,
19
+ Protocol,
20
+ Tuple,
21
+ TYPE_CHECKING,
22
+ )
23
+
24
+ from monarch._rust_bindings.monarch_extension import tensor_worker
25
+ from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
26
+ from monarch.common.invocation import DeviceException, RemoteException
27
+ from monarch.common.reference import Referenceable
28
+ from monarch.common.stream import StreamRef
29
+ from monarch.common.tree import flattener
30
+ from pyre_extensions import none_throws
31
+
32
+ from .shape import NDSlice
33
+ from .tensor_factory import TensorFactory
34
+
35
+ if TYPE_CHECKING:
36
+ from .device_mesh import DeviceMesh, RemoteProcessGroup
37
+ from .pipe import Pipe
38
+ from .recording import Recording
39
+ from .tensor import Tensor
40
+
41
+
42
# Tuple of device-mesh dimension names.
Dims = Tuple[str, ...]
43
+
44
+
45
def _to_rust_function(
    x: ResolvableFunction,
) -> tensor_worker.ResolvableFunction:
    """Convert a python ResolvableFunction into its rust wire representation."""
    if not isinstance(x, ResolvableFromCloudpickle):
        # Addressable functions travel as their import-path string.
        return tensor_worker.FunctionPath(path=str(x))
    # Cloudpickled functions travel as raw bytes.
    return tensor_worker.Cloudpickle(bytes=x.data)
51
+
52
+
53
def _result_to_references(result: object) -> List[tensor_worker.Ref | None]:
    """
    Flatten the result pytree into a list of refs.

    Only referenceable leaves (and already-converted rust Refs) are kept;
    every other leaf position is None. The workers will generate the full
    result list, so the None/ref positions tell them which outputs get
    assigned to which references.
    """
    leaves = flattener(result, lambda x: True)(result)
    # Idiom fix: single isinstance with a tuple instead of two chained checks.
    return [
        _ref(leaf) if isinstance(leaf, (Referenceable, tensor_worker.Ref)) else None
        for leaf in leaves
    ]
67
+
68
+
69
def _ref(r: Referenceable | tensor_worker.Ref) -> tensor_worker.Ref:
    """Coerce a Referenceable to a rust Ref; already-rust Refs pass through."""
    return (
        tensor_worker.Ref(id=none_throws(r.ref))
        if isinstance(r, Referenceable)
        else r
    )
73
+
74
+
75
# We can't do inheritance with NamedTuple, so we use this protocol for
# type casting for now until we can move to rust messages entirely.
# Preferring this over a massive if/else to keep everything co-located and
# easier to identify drift.
class SupportsToRustMessage(Protocol):
    """Structural type for python messages convertible to rust WorkerMessages."""

    def to_rust_message(self) -> tensor_worker.WorkerMessage: ...
81
+
82
+
83
class CreateDeviceMesh(NamedTuple):
    """Controller -> worker: create a device mesh over `ranks` with named dims."""

    result: DeviceMesh  # controller-side mesh whose ref the worker binds
    names: Dims
    ranks: NDSlice

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.CreateDeviceMesh(
            result=tensor_worker.Ref(id=self.result.ref),
            names=self.names,
            # Rebuild the slice so the rust side receives a fresh NDSlice value.
            ranks=NDSlice(
                offset=self.ranks.offset,
                sizes=self.ranks.sizes,
                strides=self.ranks.strides,
            ),
        )
98
+
99
+
100
class CreateStream(NamedTuple):
    """Controller -> worker: create (or adopt) a stream on the worker."""

    result: StreamRef  # controller-side handle for the new stream
    default: bool      # True -> use the device's default stream

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        if self.default:
            creation = tensor_worker.StreamCreationMode.UseDefaultStream
        else:
            creation = tensor_worker.StreamCreationMode.CreateNewStream
        return tensor_worker.CreateStream(
            id=tensor_worker.StreamRef(id=self.result.ref),
            stream_creation=creation,
        )
113
+
114
+
115
class CreateRemoteProcessGroup(NamedTuple):
    """Controller -> worker: create a process group over `dims` of a device mesh."""

    result: Referenceable
    device_mesh: DeviceMesh
    dims: Dims

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.CreateRemoteProcessGroup(
            result=tensor_worker.Ref(id=none_throws(self.result.ref)),
            device_mesh=tensor_worker.Ref(id=self.device_mesh.ref),
            dims=self.dims,
        )
126
+
127
+
128
class CallFunction(NamedTuple):
    """Controller -> worker: invoke `function` on the worker.

    `result` is a pytree whose referenceable leaves receive the outputs;
    `mutates` lists tensors the call writes in place.
    """

    ident: int
    result: object  # pytree with tensors in it
    mutates: Tuple[Tensor | tensor_worker.Ref, ...]
    function: ResolvableFunction
    args: Tuple[object, ...]
    kwargs: Dict[str, object]
    stream: StreamRef
    device_mesh: DeviceMesh
    remote_process_groups: List[RemoteProcessGroup]

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        # NOTE(review): device_mesh is not forwarded in the rust message —
        # presumably the worker derives it elsewhere; confirm before relying on it.
        return tensor_worker.CallFunction(
            seq=self.ident,
            results=_result_to_references(self.result),
            mutates=[_ref(r) for r in self.mutates],
            function=_to_rust_function(self.function),
            args=self.args,
            kwargs=self.kwargs,
            stream=tensor_worker.StreamRef(id=self.stream.ref),
            remote_process_groups=[
                tensor_worker.Ref(id=none_throws(remote_process_group.ref))
                for remote_process_group in self.remote_process_groups
            ],
        )
153
+
154
+
155
class Exit(NamedTuple):
    """Tell the worker to shut down, optionally reporting the triggering error."""

    destroy_pg: bool
    error: Optional[RemoteException | DeviceException | Exception]

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        source_actor = None
        message = None
        # Monarch exceptions carry a source actor id; plain exceptions only text.
        if isinstance(self.error, (RemoteException, DeviceException)):
            source_actor = self.error.source_actor_id
            message = self.error.message
        elif self.error is not None:
            message = str(self.error)

        # No message means a clean exit.
        reason = (source_actor, message) if message is not None else None
        return tensor_worker.Exit(error_reason=reason)
170
+
171
+
172
class CommandGroup(NamedTuple):
    """A batch of python messages delivered to the worker as one unit."""

    commands: List[NamedTuple]

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        def convert(command: NamedTuple) -> tensor_worker.WorkerMessage:
            # Only commands that know how to serialize themselves are allowed.
            if not hasattr(command, "to_rust_message"):
                raise NotImplementedError(f"Unsupported command {command}")
            return cast(SupportsToRustMessage, command).to_rust_message()

        return tensor_worker.CommandGroup(
            commands=[convert(c) for c in self.commands]
        )
184
+
185
+
186
class RecordingFormal(NamedTuple):
    """Marks a formal argument slot of a recording (see DefineRecording)."""

    result: Tensor | tensor_worker.Ref  # ref the actual argument will bind to
    argument_index: int
    stream: "StreamRef"

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.RecordingFormal(
            result=_ref(self.result),
            argument_index=self.argument_index,
            stream=tensor_worker.StreamRef(id=self.stream.ref),
        )
197
+
198
+
199
class RecordingResult(NamedTuple):
    """Marks an output slot of a recording (see DefineRecording)."""

    input: Tensor | tensor_worker.Ref  # value produced at `output_index`
    output_index: int
    stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.RecordingResult(
            result=_ref(self.input),
            output_index=self.output_index,
            stream=tensor_worker.StreamRef(id=self.stream.ref),
        )
210
+
211
+
212
class DefineRecording(NamedTuple):
    """Defines (a chunk of) a recording: a replayable sequence of commands.

    Large recordings are split across `ntotal_messages` DefineRecording
    messages; `message_index` orders this chunk within the sequence.
    """

    result: Recording
    nresults: int
    nformals: int
    commands: List[NamedTuple]
    ntotal_messages: int
    message_index: int

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        define_recording = tensor_worker.DefineRecording(
            result=tensor_worker.Ref(id=none_throws(self.result.ref)),
            nresults=self.nresults,
            nformals=self.nformals,
            commands=[],
            ntotal_messages=self.ntotal_messages,
            index=self.message_index,
        )
        for c in self.commands:
            if hasattr(c, "to_rust_message"):
                c = cast(SupportsToRustMessage, c)
                if isinstance(c, CallFunction):
                    # CallFunction has a dedicated append on the rust builder.
                    define_recording.append_call_function(
                        seq=c.ident,
                        results=_result_to_references(c.result),
                        mutates=[_ref(r) for r in c.mutates],
                        function=_to_rust_function(c.function),
                        args=c.args,
                        kwargs=c.kwargs,
                        stream=tensor_worker.StreamRef(id=c.stream.ref),
                        remote_process_groups=[
                            tensor_worker.Ref(id=none_throws(remote_process_group.ref))
                            for remote_process_group in c.remote_process_groups
                        ],
                    )
                else:
                    define_recording.append(c.to_rust_message())
            else:
                raise NotImplementedError(f"Unsupported command {c}")
        return define_recording
251
+
252
+
253
class CallRecording(NamedTuple):
    """Replay a previously defined recording with concrete inputs and outputs."""

    ident: int
    recording: Recording
    results: List[Tensor | tensor_worker.Ref]
    actuals: List[Tensor | tensor_worker.Ref]

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.CallRecording(
            seq=self.ident,
            recording=tensor_worker.Ref(id=none_throws(self.recording.ref)),
            results=[_ref(r) for r in self.results],
            actuals=[_ref(r) for r in self.actuals],
        )
266
+
267
+
268
class DeleteRefs(NamedTuple):
    """Ask the worker to drop its bindings for the given reference ids."""

    refs: List[int]

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        rust_refs = [tensor_worker.Ref(id=i) for i in self.refs]
        return tensor_worker.DeleteRefs(refs=rust_refs)
275
+
276
+
277
# This is worker <> controller/backend comms only; will be supported differently
class Restarted(NamedTuple):
    """Worker -> controller/backend notification that the worker restarted."""

    result: int
280
+
281
+
282
class SendValue(NamedTuple):
    """Apply `function` to args/kwargs on the worker and send the value out.

    If `destination` is a Pipe the value is sent along it; otherwise the
    worker sends a FetchResult back to the controller.
    """

    ident: int
    destination: Pipe | None  # if present the pipe along which to send the result,
    # otherwise send FetchResult to controller
    mutates: Tuple[Tensor | tensor_worker.Ref, ...]
    function: ResolvableFunction | None  # None is equivalent to lambda x: x
    args: Tuple[object, ...]
    kwargs: Dict[str, object]
    stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        # Robustness: test `is not None` explicitly instead of relying on the
        # truthiness of Pipe / ResolvableFunction objects (consistent with
        # Reduce.to_rust_message, which does the same for `out`).
        destination = (
            tensor_worker.Ref(id=self.destination.ref)
            if self.destination is not None
            else None
        )
        function = (
            _to_rust_function(self.function) if self.function is not None else None
        )
        return tensor_worker.SendValue(
            seq=self.ident,
            destination=destination,
            mutates=[_ref(r) for r in self.mutates],
            function=function,
            args=self.args,
            kwargs=self.kwargs,
            stream=tensor_worker.StreamRef(id=self.stream.ref),
        )
304
+
305
+
306
# Worker -> Controller comm only; handled differently
class FetchResult(NamedTuple):
    """Carries the value produced by a SendValue that had no pipe destination."""

    ident: int
    value: object
310
+
311
+
312
# Worker -> Controller comm only; handled differently
class RemoteFunctionFailed(NamedTuple):
    """Reports that the remote function for invocation `failing_ident` raised."""

    failing_ident: int
    stack_offset: int  # frames to trim when presenting the traceback — confirm exact semantics with the controller
    exception: Exception
    worker_frames: List[FrameSummary]
318
+
319
+
320
# Worker -> Controller comm only; handled differently
class InternalException(NamedTuple):
    """Reports an unexpected (non-user) failure inside the worker."""

    exception: Exception
    frames: List[FrameSummary]
324
+
325
+
326
# Worker -> Controller comm only; handled differently
class RemoteGeneratorFailed(NamedTuple):
    """Reports that a remote generator raised while producing values."""

    exception: Exception
    frames: List[FrameSummary]
330
+
331
+
332
# Worker -> Controller comm only; handled differently
class Status(NamedTuple):
    """Worker progress report."""

    # Per the name: smallest invocation ident not yet completed.
    first_uncompleted_ident: int
335
+
336
+
337
# When the controller is waiting on a status update, it will request one
# even if it is before the periodic one.
class RequestStatus(NamedTuple):
    """Ask the worker for an immediate status report."""

    ident: int
    controller: bool

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.RequestStatus(
            seq=self.ident,
            controller=self.controller,
        )
346
+
347
+
348
class BorrowCreate(NamedTuple):
    """Borrow `tensor` from one stream for use on another."""

    result: Tensor | tensor_worker.Ref  # the borrowed alias
    borrow: int  # borrow id
    tensor: Tensor | tensor_worker.Ref
    from_stream: StreamRef
    to_stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.BorrowCreate(
            result=_ref(self.result),
            borrow=self.borrow,
            tensor=_ref(self.tensor),
            from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
            to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
        )
363
+
364
+
365
class BorrowDrop(NamedTuple):
    """Release a borrow previously created with BorrowCreate."""

    borrow: int  # id of borrowed tensor

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.BorrowDrop(borrow=self.borrow)
372
+
373
+
374
class BorrowFirstUse(NamedTuple):
    """Marks the first use of a borrowed tensor."""

    borrow: int  # id of borrowed tensor

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.BorrowFirstUse(
            borrow=self.borrow,
        )
381
+
382
+
383
class BorrowLastUse(NamedTuple):
    """Marks the last use of a borrowed tensor."""

    borrow: int  # id of borrowed tensor

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.BorrowLastUse(
            borrow=self.borrow,
        )
390
+
391
+
392
class SendTensor(NamedTuple):
    """Copy `tensor` from the ranks in `from_ranks` to those in `to_ranks`."""

    result: Tensor | tensor_worker.Ref  # where receivers bind the tensor
    from_ranks: NDSlice
    to_ranks: NDSlice
    tensor: Tensor | tensor_worker.Ref
    factory: TensorFactory  # size/dtype/device/layout needed to materialize it
    from_stream: StreamRef
    to_stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.SendTensor(
            result=_ref(self.result),
            from_ranks=NDSlice(
                offset=self.from_ranks.offset,
                sizes=self.from_ranks.sizes,
                strides=self.from_ranks.strides,
            ),
            to_ranks=NDSlice(
                offset=self.to_ranks.offset,
                sizes=self.to_ranks.sizes,
                strides=self.to_ranks.strides,
            ),
            tensor=_ref(self.tensor),
            factory=tensor_worker.TensorFactory(
                size=self.factory.size,
                dtype=self.factory.dtype,
                device=self.factory.device,
                layout=self.factory.layout,
            ),
            from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
            to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
        )
424
+
425
+
426
class SplitComm(NamedTuple):
    """Create a communicator split over the given dims of a device mesh."""

    dims: Dims
    device_mesh: DeviceMesh
    stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        mesh_ref = tensor_worker.Ref(id=self.device_mesh.ref)
        stream_ref = tensor_worker.StreamRef(id=self.stream.ref)
        return tensor_worker.SplitComm(
            dims=self.dims,
            device_mesh=mesh_ref,
            stream=stream_ref,
        )
437
+
438
+
439
class SplitCommForProcessGroup(NamedTuple):
    """Create a communicator split for an existing remote process group."""

    remote_process_group: DeviceMesh
    stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.SplitCommForProcessGroup(
            remote_process_group=tensor_worker.Ref(id=self.remote_process_group.ref),
            stream=tensor_worker.StreamRef(id=self.stream.ref),
        )
448
+
449
+
450
class Reduce(NamedTuple):
    """Reduce `local_tensor` across `dims` of `source_mesh` into `result`."""

    result: Tensor | tensor_worker.Ref
    local_tensor: Tensor | tensor_worker.Ref
    factory: TensorFactory
    source_mesh: DeviceMesh
    stream: StreamRef
    dims: Dims
    reduction: str  # one of: "sum", "prod", "stack", "avg", "min", "max"
    scatter: bool
    inplace: bool
    out: Tensor | tensor_worker.Ref | None

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        # Map the reduction name onto the rust enum; any other string is a bug.
        match self.reduction:
            case "sum":
                reduction = tensor_worker.ReductionType.Sum
            case "prod":
                reduction = tensor_worker.ReductionType.Prod
            case "stack":
                reduction = tensor_worker.ReductionType.Stack
            case "avg":
                reduction = tensor_worker.ReductionType.Avg
            case "min":
                reduction = tensor_worker.ReductionType.Min
            case "max":
                reduction = tensor_worker.ReductionType.Max
            case _:
                raise ValueError(f"Unsupported reduction {self.reduction}")

        return tensor_worker.Reduce(
            result=_ref(self.result),
            tensor=_ref(self.local_tensor),
            factory=tensor_worker.TensorFactory(
                size=self.factory.size,
                dtype=self.factory.dtype,
                device=self.factory.device,
                layout=self.factory.layout,
            ),
            mesh=tensor_worker.Ref(id=self.source_mesh.ref),
            stream=tensor_worker.StreamRef(id=self.stream.ref),
            dims=self.dims,
            reduction=reduction,
            scatter=self.scatter,
            in_place=self.inplace,
            out=_ref(self.out) if self.out is not None else None,
        )
496
+
497
+
498
class CreatePipe(NamedTuple):
    """Controller -> worker: create a pipe addressed by `key` on a mesh."""

    result: Pipe
    key: str
    function: ResolvableFunction  # producer resolved and run worker-side
    max_messages: int
    device_mesh: DeviceMesh
    args: Tuple[object, ...]
    kwargs: Dict[str, object]

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.CreatePipe(
            result=tensor_worker.Ref(id=self.result.ref),
            key=self.key,
            function=_to_rust_function(self.function),
            max_messages=self.max_messages,
            mesh=tensor_worker.Ref(id=self.device_mesh.ref),
            args=self.args,
            kwargs=self.kwargs,
        )
517
+
518
+
519
class PipeRecv(NamedTuple):
    """Receive the next value from `pipe` into the `result` pytree."""

    ident: int
    result: object  # pytree with tensors in it
    pipe: Pipe
    stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.PipeRecv(
            seq=self.ident,
            results=_result_to_references(self.result),
            pipe=tensor_worker.Ref(id=self.pipe.ref),
            stream=tensor_worker.StreamRef(id=self.stream.ref),
        )
532
+
533
+
534
class BackendNetworkInit(NamedTuple):
    """Initialize the worker's backend network."""

    hostname: str | None = None
    port: int | None = None

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        # The rust message takes no fields: hostname/port are not forwarded.
        return tensor_worker.BackendNetworkInit()
540
+
541
+
542
class BackendNetworkPointToPointInit(NamedTuple):
    """Initialize point-to-point backend connectivity between two streams."""

    from_stream: StreamRef
    to_stream: StreamRef

    def to_rust_message(self) -> tensor_worker.WorkerMessage:
        return tensor_worker.BackendNetworkPointToPointInit(
            from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
            to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
        )
550
+ )
551
+
552
+
553
# TODO: This is not supported on the rust side and might be only needed for remote funcs
class DebuggerRead(NamedTuple):
    """Debugger request to read up to `requested` bytes."""

    requested: int
556
+
557
+
558
# TODO: This is not supported on the rust side and might be only needed for remote funcs
class DebuggerWrite(NamedTuple):
    """Debugger request to write `payload` bytes."""

    payload: bytes
561
+
562
+
563
# TODO: This is not supported on the rust side and might be only needed for remote funcs
class DebuggerMessage(NamedTuple):
    """Debugger event for a stream: lifecycle marker or a read/write request."""

    stream_id: int
    action: Literal["paused", "attach", "detach"] | DebuggerRead | DebuggerWrite
567
+
568
+
569
# TODO: Might need to be supported differently through typed worker exceptions
class DependentOnError(Exception):
    """Raised when a value cannot be computed because the invocation it
    depends on (identified by `ident`) already failed."""

    def __init__(self, ident: int) -> None:
        # Bug fix: forward ident to Exception so `args` is populated. Without
        # this, pickling round-trips fail (unpickling re-calls __init__ with
        # the empty args tuple) and str(e) carries no information.
        super().__init__(ident)
        self.ident = ident
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+ from contextlib import contextmanager
9
+ from typing import Generator, Optional
10
+
11
+ import monarch.common._C # @manual=//monarch/python/monarch/common:_C
12
+ import torch
13
+
14
# Patch CUDA at import time so the C extension can intercept CUDA calls later.
monarch.common._C.patch_cuda()

# Lazily-created stream shared by all mocked CUDA work (see get_mock_cuda_stream).
_mock_cuda_stream: Optional[torch.cuda.Stream] = None
17
+
18
+
19
def get_mock_cuda_stream() -> torch.cuda.Stream:
    """Return the process-wide stream for mocked CUDA work, creating it on first use."""
    global _mock_cuda_stream
    stream = _mock_cuda_stream
    if stream is None:
        stream = _mock_cuda_stream = torch.cuda.Stream()
    return stream
24
+
25
+
26
@contextmanager
def mock_cuda_guard() -> Generator[None, None, None]:
    """Run the body with CUDA mocked, on the dedicated mock stream.

    Mocking is always undone on exit, even if the body raises.
    """
    try:
        with torch.cuda.stream(get_mock_cuda_stream()):
            monarch.common._C.mock_cuda()
            yield
    finally:
        monarch.common._C.unmock_cuda()
34
+
35
+
36
def mock_cuda() -> None:
    """Enable CUDA mocking via the C extension (undo with unmock_cuda)."""
    monarch.common._C.mock_cuda()
38
+
39
+
40
def unmock_cuda() -> None:
    """Disable CUDA mocking previously enabled by mock_cuda/mock_cuda_guard."""
    monarch.common._C.unmock_cuda()
@@ -0,0 +1,98 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import functools
9
+ import itertools
10
+ import os
11
+ from typing import Any, Iterator
12
+
13
+ import torch
14
+ from torch._subclasses.fake_tensor import FakeTensor
15
+ from torch.utils._pytree import register_pytree_node
16
+ from torch.utils.weak import WeakTensorKeyDictionary
17
+
18
# Worker-side table mapping key tensors -> opaque python values. Weak keys,
# so an entry dies when its key tensor is garbage collected.
_key_table: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
# Monotonic source of fresh key ids, starting at 1.
_key_counter: Iterator[int] = itertools.count(1)

# check that we are for sure running on the worker process
# (LOCAL_RANK is presumably set by the worker launcher — confirm)
_on_worker = os.environ.get("LOCAL_RANK") is not None
23
+
24
+
25
def wrap_create(create, xs):
    """Pytree helper: rebuild an OpaqueRef from its flattened single-leaf list."""
    key = xs[0]
    return create(key)
27
+
28
+
29
class OpaqueRef:
    """
    A reference to an object that is only resolvable on the worker process.

    Used to pass objects from the controller to workers across user-defined
    functions. The payload lives in the worker-side `_key_table`, keyed by an
    int64 tensor; only that key tensor travels between processes.

    Example::
        def init_udf_worker():
            model = nn.Linear(3, 4)
            model_ref = OpaqueRef(model)
            return model_ref

        def run_step_worker(model_ref: OpaqueRef):
            model = model_ref.value
            # do something with model (e.g. a forward pass)

        # on Controller
        model_ref = init_udf()
        run_step(model_ref)
    """

    def __init__(self, value=None):
        # The key is a tensor so it can flow through the tensor machinery;
        # check_worker relies on it becoming a FakeTensor during propagation.
        self._key = torch.tensor(next(_key_counter), dtype=torch.int64)
        self.check_worker("create")
        _key_table[self._key] = value

    @classmethod
    def _create(cls, key: torch.Tensor):
        # Rebuild a ref around an existing key without touching the table.
        c = cls.__new__(cls)
        c._key = key
        return c

    # like NamedTuple, just pass the call to reconstruct this
    # rather than the dict. This also ensures the OpaqueObject
    # subclass degrades into this class when sent to the worker
    def __reduce_ex__(self, protocol):
        return OpaqueRef._create, (self._key,)

    def __repr__(self):
        return f"OpaqueRef({repr(self._key)})"

    @property
    def value(self) -> Any:
        # Resolve this ref's key in the worker-side table.
        self.check_worker("access")
        return _key_table[self._key]

    @value.setter
    def value(self, v: Any) -> None:
        self.check_worker("set")
        _key_table[self._key] = v

    def check_worker(self, what):
        # both checks are needed for the case where OpaqueRef() is
        # called on the client with no mesh active.
        in_worker_or_propagate = _on_worker or isinstance(self._key, FakeTensor)
        if not in_worker_or_propagate:
            raise RuntimeError(
                f"Client is attempting to {what} an OpaqueRef. This can only be done in a remote function."
            )
88
+
89
+
90
def _flatten(x: OpaqueRef):
    # Single leaf (the key tensor) plus a context that rebuilds via _create.
    return (x._key,), functools.partial(wrap_create, x._create)
92
+
93
+
94
def _unflatten(xs, ctx):
    # ctx is the partial produced by _flatten; it reconstructs the OpaqueRef.
    return ctx(xs)
96
+
97
+
98
+ register_pytree_node(OpaqueRef, _flatten, _unflatten)