torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/messages.py
@@ -0,0 +1,573 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ from __future__ import annotations
+
+ from traceback import FrameSummary
+ from typing import (
+     cast,
+     Dict,
+     List,
+     Literal,
+     NamedTuple,
+     Optional,
+     Protocol,
+     Tuple,
+     TYPE_CHECKING,
+ )
+
+ from monarch._rust_bindings.monarch_extension import tensor_worker
+ from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
+ from monarch.common.invocation import DeviceException, RemoteException
+ from monarch.common.reference import Referenceable
+ from monarch.common.tree import flattener
+ from pyre_extensions import none_throws
+
+ from .shape import NDSlice
+ from .tensor_factory import TensorFactory
+
+ if TYPE_CHECKING:
+     from monarch.common.stream import StreamRef
+
+     from .device_mesh import DeviceMesh, RemoteProcessGroup
+     from .pipe import Pipe
+     from .recording import Recording
+     from .tensor import Tensor
+
+
+ Dims = Tuple[str, ...]
+
+
+ def _to_rust_function(
+     x: ResolvableFunction,
+ ) -> tensor_worker.ResolvableFunction:
+     if isinstance(x, ResolvableFromCloudpickle):
+         return tensor_worker.Cloudpickle(bytes=x.data)
+     return tensor_worker.FunctionPath(path=str(x))
+
+
+ def _result_to_references(result: object) -> List[tensor_worker.Ref | None]:
+     """
+     Flatten the result pytree.
+     Only keep the referenceables and leave the rest as None.
+     The workers will generate the full result list, so we know
+     which referenceables need to be assigned to.
+     """
+     leaves = flattener(result, lambda x: True)(result)
+     return [
+         _ref(leaf)
+         if isinstance(leaf, Referenceable) or isinstance(leaf, tensor_worker.Ref)
+         else None
+         for leaf in leaves
+     ]
+
+
+ def _ref(r: Referenceable | tensor_worker.Ref) -> tensor_worker.Ref:
+     if isinstance(r, Referenceable):
+         return tensor_worker.Ref(id=none_throws(r.ref))
+     return r
+
+
+ # We can't use inheritance with NamedTuple, so we use this protocol for
+ # type casting for now, until we can move to Rust messages entirely.
+ # Preferring this over a massive if/else to keep everything co-located and
+ # make drift easier to identify.
+ class SupportsToRustMessage(Protocol):
+     def to_rust_message(self) -> tensor_worker.WorkerMessage: ...
+
+
+ class CreateDeviceMesh(NamedTuple):
+     result: DeviceMesh
+     names: Dims
+     ranks: NDSlice
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.CreateDeviceMesh(
+             result=tensor_worker.Ref(id=self.result.ref),
+             names=self.names,
+             ranks=NDSlice(
+                 offset=self.ranks.offset,
+                 sizes=self.ranks.sizes,
+                 strides=self.ranks.strides,
+             ),
+         )
+
+
+ class CreateStream(NamedTuple):
+     result: "StreamRef"
+     default: bool
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.CreateStream(
+             id=tensor_worker.StreamRef(id=self.result.ref),
+             stream_creation=(
+                 tensor_worker.StreamCreationMode.UseDefaultStream
+                 if self.default
+                 else tensor_worker.StreamCreationMode.CreateNewStream
+             ),
+         )
+
+
+ class CreateRemoteProcessGroup(NamedTuple):
+     result: Referenceable
+     device_mesh: DeviceMesh
+     dims: Dims
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.CreateRemoteProcessGroup(
+             result=tensor_worker.Ref(id=none_throws(self.result.ref)),
+             device_mesh=tensor_worker.Ref(id=self.device_mesh.ref),
+             dims=self.dims,
+         )
+
+
+ class CallFunction(NamedTuple):
+     ident: int
+     result: object  # pytree with tensors in it
+     mutates: Tuple[Tensor | tensor_worker.Ref, ...]
+     function: ResolvableFunction
+     args: Tuple[object, ...]
+     kwargs: Dict[str, object]
+     stream: "StreamRef"
+     device_mesh: DeviceMesh
+     remote_process_groups: List[RemoteProcessGroup]
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.CallFunction(
+             seq=self.ident,
+             results=_result_to_references(self.result),
+             mutates=[_ref(r) for r in self.mutates],
+             function=_to_rust_function(self.function),
+             args=self.args,
+             kwargs=self.kwargs,
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+             remote_process_groups=[
+                 tensor_worker.Ref(id=none_throws(remote_process_group.ref))
+                 for remote_process_group in self.remote_process_groups
+             ],
+         )
+
+
+ class Exit(NamedTuple):
+     destroy_pg: bool
+     error: Optional[RemoteException | DeviceException | Exception]
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         actor_id = None
+         error_message = None
+         if isinstance(self.error, (RemoteException, DeviceException)):
+             actor_id = self.error.source_actor_id
+             error_message = self.error.message
+         elif self.error is not None:
+             error_message = str(self.error)
+
+         error_reason = None if error_message is None else (actor_id, error_message)
+         return tensor_worker.Exit(error_reason=error_reason)
+
+
+ class CommandGroup(NamedTuple):
+     commands: List[NamedTuple]
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         rust_commands = []
+         for c in self.commands:
+             if hasattr(c, "to_rust_message"):
+                 c = cast(SupportsToRustMessage, c)
+                 rust_commands.append(c.to_rust_message())
+             else:
+                 raise NotImplementedError(f"Unsupported command {c}")
+         return tensor_worker.CommandGroup(commands=rust_commands)
+
+
+ class RecordingFormal(NamedTuple):
+     result: Tensor | tensor_worker.Ref
+     argument_index: int
+     stream: "StreamRef"
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.RecordingFormal(
+             result=_ref(self.result),
+             argument_index=self.argument_index,
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+         )
+
+
+ class RecordingResult(NamedTuple):
+     input: Tensor | tensor_worker.Ref
+     output_index: int
+     stream: "StreamRef"
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.RecordingResult(
+             result=_ref(self.input),
+             output_index=self.output_index,
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+         )
+
+
+ class DefineRecording(NamedTuple):
+     result: Recording
+     nresults: int
+     nformals: int
+     commands: List[NamedTuple]
+     ntotal_messages: int
+     message_index: int
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         define_recording = tensor_worker.DefineRecording(
+             result=tensor_worker.Ref(id=none_throws(self.result.ref)),
+             nresults=self.nresults,
+             nformals=self.nformals,
+             commands=[],
+             ntotal_messages=self.ntotal_messages,
+             index=self.message_index,
+         )
+         for c in self.commands:
+             if hasattr(c, "to_rust_message"):
+                 c = cast(SupportsToRustMessage, c)
+                 if isinstance(c, CallFunction):
+                     define_recording.append_call_function(
+                         seq=c.ident,
+                         results=_result_to_references(c.result),
+                         mutates=[_ref(r) for r in c.mutates],
+                         function=_to_rust_function(c.function),
+                         args=c.args,
+                         kwargs=c.kwargs,
+                         stream=tensor_worker.StreamRef(id=c.stream.ref),
+                         remote_process_groups=[
+                             tensor_worker.Ref(id=none_throws(remote_process_group.ref))
+                             for remote_process_group in c.remote_process_groups
+                         ],
+                     )
+                 else:
+                     define_recording.append(c.to_rust_message())
+             else:
+                 raise NotImplementedError(f"Unsupported command {c}")
+         return define_recording
+
+
+ class CallRecording(NamedTuple):
+     ident: int
+     recording: Recording
+     results: List[Tensor | tensor_worker.Ref]
+     actuals: List[Tensor | tensor_worker.Ref]
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.CallRecording(
+             seq=self.ident,
+             recording=tensor_worker.Ref(id=none_throws(self.recording.ref)),
+             results=[_ref(r) for r in self.results],
+             actuals=[_ref(r) for r in self.actuals],
+         )
+
+
+ class DeleteRefs(NamedTuple):
+     refs: List[int]
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.DeleteRefs(
+             refs=[tensor_worker.Ref(id=r) for r in self.refs]
+         )
+
+
+ # This is worker <-> controller/backend comms only; it will be supported differently.
+ class Restarted(NamedTuple):
+     result: int
+
+
+ class SendValue(NamedTuple):
+     ident: int
+     destination: Pipe | None  # if present, the pipe along which to send the result;
+     # otherwise send FetchResult to the controller
+     mutates: Tuple[Tensor | tensor_worker.Ref, ...]
+     function: ResolvableFunction | None  # None is equivalent to lambda x: x
+     args: Tuple[object, ...]
+     kwargs: Dict[str, object]
+     stream: StreamRef
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.SendValue(
+             seq=self.ident,
+             destination=(
+                 tensor_worker.Ref(id=self.destination.ref) if self.destination else None
+             ),
+             mutates=[_ref(r) for r in self.mutates],
+             function=_to_rust_function(self.function) if self.function else None,
+             args=self.args,
+             kwargs=self.kwargs,
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+         )
+
+
+ # Worker -> controller comms only; handled differently
+ class FetchResult(NamedTuple):
+     ident: int
+     value: object
+
+
+ # Worker -> controller comms only; handled differently
+ class RemoteFunctionFailed(NamedTuple):
+     failing_ident: int
+     stack_offset: int
+     exception: Exception
+     worker_frames: List[FrameSummary]
+
+
+ # Worker -> controller comms only; handled differently
+ class InternalException(NamedTuple):
+     exception: Exception
+     frames: List[FrameSummary]
+
+
+ # Worker -> controller comms only; handled differently
+ class RemoteGeneratorFailed(NamedTuple):
+     exception: Exception
+     frames: List[FrameSummary]
+
+
+ # Worker -> controller comms only; handled differently
+ class Status(NamedTuple):
+     first_uncompleted_ident: int
+
+
+ # When the controller is waiting on a status update,
+ # it will request one even if it is before the
+ # periodic one.
+ class RequestStatus(NamedTuple):
+     ident: int
+     controller: bool
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.RequestStatus(seq=self.ident, controller=self.controller)
+
+
+ class BorrowCreate(NamedTuple):
+     result: Tensor | tensor_worker.Ref
+     borrow: int
+     tensor: Tensor | tensor_worker.Ref
+     from_stream: StreamRef
+     to_stream: StreamRef
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.BorrowCreate(
+             result=_ref(self.result),
+             borrow=self.borrow,
+             tensor=_ref(self.tensor),
+             from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
+             to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
+         )
+
+
+ class BorrowDrop(NamedTuple):
+     borrow: int  # id of borrowed tensor
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.BorrowDrop(
+             borrow=self.borrow,
+         )
+
+
+ class BorrowFirstUse(NamedTuple):
+     borrow: int  # id of borrowed tensor
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.BorrowFirstUse(
+             borrow=self.borrow,
+         )
+
+
+ class BorrowLastUse(NamedTuple):
+     borrow: int  # id of borrowed tensor
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.BorrowLastUse(
+             borrow=self.borrow,
+         )
+
+
+ class SendTensor(NamedTuple):
+     result: Tensor | tensor_worker.Ref
+     from_ranks: NDSlice
+     to_ranks: NDSlice
+     tensor: Tensor | tensor_worker.Ref
+     factory: TensorFactory
+     from_stream: StreamRef
+     to_stream: StreamRef
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.SendTensor(
+             result=_ref(self.result),
+             from_ranks=NDSlice(
+                 offset=self.from_ranks.offset,
+                 sizes=self.from_ranks.sizes,
+                 strides=self.from_ranks.strides,
+             ),
+             to_ranks=NDSlice(
+                 offset=self.to_ranks.offset,
+                 sizes=self.to_ranks.sizes,
+                 strides=self.to_ranks.strides,
+             ),
+             tensor=_ref(self.tensor),
+             factory=tensor_worker.TensorFactory(
+                 size=self.factory.size,
+                 dtype=self.factory.dtype,
+                 device=self.factory.device,
+                 layout=self.factory.layout,
+             ),
+             from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
+             to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
+         )
+
+
+ class SplitComm(NamedTuple):
+     dims: Dims
+     device_mesh: DeviceMesh
+     stream: StreamRef
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.SplitComm(
+             dims=self.dims,
+             device_mesh=tensor_worker.Ref(id=self.device_mesh.ref),
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+         )
+
+
+ class SplitCommForProcessGroup(NamedTuple):
+     remote_process_group: DeviceMesh
+     stream: StreamRef
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.SplitCommForProcessGroup(
+             remote_process_group=tensor_worker.Ref(id=self.remote_process_group.ref),
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+         )
+
+
+ class Reduce(NamedTuple):
+     result: Tensor | tensor_worker.Ref
+     local_tensor: Tensor | tensor_worker.Ref
+     factory: TensorFactory
+     source_mesh: DeviceMesh
+     stream: StreamRef
+     dims: Dims
+     reduction: str
+     scatter: bool
+     inplace: bool
+     out: Tensor | tensor_worker.Ref | None
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         match self.reduction:
+             case "sum":
+                 reduction = tensor_worker.ReductionType.Sum
+             case "prod":
+                 reduction = tensor_worker.ReductionType.Prod
+             case "stack":
+                 reduction = tensor_worker.ReductionType.Stack
+             case "avg":
+                 reduction = tensor_worker.ReductionType.Avg
+             case "min":
+                 reduction = tensor_worker.ReductionType.Min
+             case "max":
+                 reduction = tensor_worker.ReductionType.Max
+             case _:
+                 raise ValueError(f"Unsupported reduction {self.reduction}")
+
+         return tensor_worker.Reduce(
+             result=_ref(self.result),
+             tensor=_ref(self.local_tensor),
+             factory=tensor_worker.TensorFactory(
+                 size=self.factory.size,
+                 dtype=self.factory.dtype,
+                 device=self.factory.device,
+                 layout=self.factory.layout,
+             ),
+             mesh=tensor_worker.Ref(id=self.source_mesh.ref),
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+             dims=self.dims,
+             reduction=reduction,
+             scatter=self.scatter,
+             in_place=self.inplace,
+             out=_ref(self.out) if self.out is not None else None,
+         )
+
+
+ class CreatePipe(NamedTuple):
+     result: Pipe
+     key: str
+     function: ResolvableFunction
+     max_messages: int
+     device_mesh: DeviceMesh
+     args: Tuple[object, ...]
+     kwargs: Dict[str, object]
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.CreatePipe(
+             result=tensor_worker.Ref(id=self.result.ref),
+             key=self.key,
+             function=_to_rust_function(self.function),
+             max_messages=self.max_messages,
+             mesh=tensor_worker.Ref(id=self.device_mesh.ref),
+             args=self.args,
+             kwargs=self.kwargs,
+         )
+
+
+ class PipeRecv(NamedTuple):
+     ident: int
+     result: object  # pytree with tensors in it
+     pipe: Pipe
+     stream: StreamRef
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.PipeRecv(
+             seq=self.ident,
+             results=_result_to_references(self.result),
+             pipe=tensor_worker.Ref(id=self.pipe.ref),
+             stream=tensor_worker.StreamRef(id=self.stream.ref),
+         )
+
+
+ class BackendNetworkInit(NamedTuple):
+     hostname: str | None = None
+     port: int | None = None
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.BackendNetworkInit()
+
+
+ class BackendNetworkPointToPointInit(NamedTuple):
+     from_stream: StreamRef
+     to_stream: StreamRef
+
+     def to_rust_message(self) -> tensor_worker.WorkerMessage:
+         return tensor_worker.BackendNetworkPointToPointInit(
+             from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
+             to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
+         )
+
+
+ # TODO: This is not supported on the Rust side and might only be needed for remote funcs
+ class DebuggerRead(NamedTuple):
+     requested: int
+
+
+ # TODO: This is not supported on the Rust side and might only be needed for remote funcs
+ class DebuggerWrite(NamedTuple):
+     payload: bytes
+
+
+ # TODO: This is not supported on the Rust side and might only be needed for remote funcs
+ class DebuggerMessage(NamedTuple):
+     stream_id: int
+     action: Literal["paused", "attach", "detach"] | DebuggerRead | DebuggerWrite
+
+
+ # TODO: Might need to be supported differently through typed worker exceptions
+ class DependentOnError(Exception):
+     def __init__(self, ident: int) -> None:
+         self.ident = ident
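For orientation, here is a minimal sketch (not part of the packaged file) of how these controller-side message classes convert to the Rust worker messages via to_rust_message(). It assumes the wheel's bundled monarch._rust_bindings extension is importable in your environment.

from monarch.common.messages import CommandGroup, DeleteRefs

# A controller-side command expressed as a plain NamedTuple.
delete = DeleteRefs(refs=[1, 2, 3])

# Converted into the Rust-side message carried to the tensor worker.
wire_msg = delete.to_rust_message()  # tensor_worker.DeleteRefs with Refs 1..3

# Any command implementing SupportsToRustMessage can be batched the same way.
group_msg = CommandGroup(commands=[delete]).to_rust_message()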
monarch/common/mock_cuda.py
@@ -0,0 +1,41 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ from contextlib import contextmanager
+ from typing import Generator, Optional
+
+ import monarch.common._C  # @manual=//monarch/python/monarch/common:_C
+ import torch
+
+ monarch.common._C.patch_cuda()
+
+ _mock_cuda_stream: Optional[torch.cuda.Stream] = None
+
+
+ def get_mock_cuda_stream() -> torch.cuda.Stream:
+     global _mock_cuda_stream
+     if _mock_cuda_stream is None:
+         _mock_cuda_stream = torch.cuda.Stream()
+     return _mock_cuda_stream
+
+
+ @contextmanager
+ def mock_cuda_guard() -> Generator[None, None, None]:
+     try:
+         with torch.cuda.stream(get_mock_cuda_stream()):
+             monarch.common._C.mock_cuda()
+             yield
+     finally:
+         monarch.common._C.unmock_cuda()
+
+
+ def mock_cuda() -> None:
+     monarch.common._C.mock_cuda()
+
+
+ def unmock_cuda() -> None:
+     monarch.common._C.unmock_cuda()
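A hypothetical usage sketch of the guard defined above (not part of the packaged file); it assumes a CUDA-enabled PyTorch build and the bundled monarch.common._C extension.

import torch
from monarch.common.mock_cuda import mock_cuda_guard

with mock_cuda_guard():
    # Work done here runs on the dedicated mock stream while
    # monarch.common._C.mock_cuda() is active.
    t = torch.empty(16, device="cuda")
# On exit (even on error), unmock_cuda() has been called.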
monarch/common/opaque_ref.py
@@ -0,0 +1,98 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import functools
+ import itertools
+ import os
+ from typing import Any, Iterator
+
+ import torch
+ from torch._subclasses.fake_tensor import FakeTensor
+ from torch.utils._pytree import register_pytree_node
+ from torch.utils.weak import WeakTensorKeyDictionary
+
+ _key_table: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
+ _key_counter: Iterator[int] = itertools.count(1)
+
+ # check that we are for sure running on the worker process
+ _on_worker = os.environ.get("LOCAL_RANK") is not None
+
+
+ def wrap_create(create, xs):
+     return create(xs[0])
+
+
+ class OpaqueRef:
+     """
+     OpaqueRef is a reference to an object that is only resolvable on the worker.
+     It is used to pass objects from the controller to the worker across user-defined functions.
+
+     Example::
+
+         def init_udf_worker():
+             model = nn.Linear(3, 4)
+             model_ref = OpaqueRef(model)
+             return model_ref
+
+         def run_step_worker(model_ref: OpaqueRef):
+             model = model_ref.value
+             # do something with model (e.g. a forward pass)
+
+         # on the controller
+         model_ref = init_udf()
+         run_step(model_ref)
+     """
+
+     def __init__(self, value=None):
+         self._key = torch.tensor(next(_key_counter), dtype=torch.int64)
+         self.check_worker("create")
+         _key_table[self._key] = value
+
+     @classmethod
+     def _create(cls, key: torch.Tensor):
+         c = cls.__new__(cls)
+         c._key = key
+         return c
+
+     # Like NamedTuple, pass the callable used to reconstruct this object
+     # rather than the dict. This also ensures the OpaqueObject
+     # subclass degrades into this class when sent to the worker.
+     def __reduce_ex__(self, protocol):
+         return OpaqueRef._create, (self._key,)
+
+     def __repr__(self):
+         return f"OpaqueRef({repr(self._key)})"
+
+     @property
+     def value(self) -> Any:
+         self.check_worker("access")
+         return _key_table[self._key]
+
+     @value.setter
+     def value(self, v: Any) -> None:
+         self.check_worker("set")
+         _key_table[self._key] = v
+
+     def check_worker(self, what):
+         # Both checks are needed for the case where OpaqueRef() is
+         # called on the client with no mesh active.
+         in_worker_or_propagate = _on_worker or isinstance(self._key, FakeTensor)
+         if not in_worker_or_propagate:
+             raise RuntimeError(
+                 f"Client is attempting to {what} an OpaqueRef. This can only be done in a remote function."
+             )
+
+
+ def _flatten(x: OpaqueRef):
+     return (x._key,), functools.partial(wrap_create, x._create)
+
+
+ def _unflatten(xs, ctx):
+     return ctx(xs)
+
+
+ register_pytree_node(OpaqueRef, _flatten, _unflatten)
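A worker-side sketch (not part of the packaged file) of the pytree registration above: the int64 key tensor is the only leaf, so an OpaqueRef that travels through flatten/unflatten resolves back to the same worker-local object. It assumes execution in a worker process, i.e. LOCAL_RANK is set before this module is imported.

import os
os.environ.setdefault("LOCAL_RANK", "0")  # simulate a worker process for illustration

from torch.utils._pytree import tree_flatten, tree_unflatten
from monarch.common.opaque_ref import OpaqueRef

ref = OpaqueRef({"step": 0})          # value stored in the worker-local key table
leaves, spec = tree_flatten([ref])    # the only leaf is the int64 key tensor
restored = tree_unflatten(leaves, spec)[0]
assert restored.value is ref.value    # the key resolves back to the same object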