torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,573 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from __future__ import annotations
+
+from traceback import FrameSummary
+from typing import (
+    cast,
+    Dict,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    TYPE_CHECKING,
+)
+
+from monarch._rust_bindings.monarch_extension import tensor_worker
+from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
+from monarch.common.invocation import DeviceException, RemoteException
+from monarch.common.reference import Referenceable
+from monarch.common.tree import flattener
+from pyre_extensions import none_throws
+
+from .shape import NDSlice
+from .tensor_factory import TensorFactory
+
+if TYPE_CHECKING:
+    from monarch.common.stream import StreamRef
+
+    from .device_mesh import DeviceMesh, RemoteProcessGroup
+    from .pipe import Pipe
+    from .recording import Recording
+    from .tensor import Tensor
+
+
+Dims = Tuple[str, ...]
+
+
+def _to_rust_function(
+    x: ResolvableFunction,
+) -> tensor_worker.ResolvableFunction:
+    if isinstance(x, ResolvableFromCloudpickle):
+        return tensor_worker.Cloudpickle(bytes=x.data)
+    return tensor_worker.FunctionPath(path=str(x))
+
+
+def _result_to_references(result: object) -> List[tensor_worker.Ref | None]:
+    """
+    Flatten the result pytree.
+    Only keep the referenceables and leave the rest as None.
+    The workers will generate the full result list so we know
+    what referenceables to be assigned to.
+    """
+    leaves = flattener(result, lambda x: True)(result)
+    return [
+        _ref(leaf)
+        if isinstance(leaf, Referenceable) or isinstance(leaf, tensor_worker.Ref)
+        else None
+        for leaf in leaves
+    ]
+
+
+def _ref(r: Referenceable | tensor_worker.Ref) -> tensor_worker.Ref:
+    if isinstance(r, Referenceable):
+        return tensor_worker.Ref(id=none_throws(r.ref))
+    return r
+
+
+# We cant do inheritance with NamedTuple so we can use this protocol for
+# type casting for now until we can move to rust messages entirely.
+# Preferring this over a massive if else to keep everything co-located and
+# easier to identify drift.
+class SupportsToRustMessage(Protocol):
+    def to_rust_message(self) -> tensor_worker.WorkerMessage: ...
+
+
+class CreateDeviceMesh(NamedTuple):
+    result: DeviceMesh
+    names: Dims
+    ranks: NDSlice
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.CreateDeviceMesh(
+            result=tensor_worker.Ref(id=self.result.ref),
+            names=self.names,
+            ranks=NDSlice(
+                offset=self.ranks.offset,
+                sizes=self.ranks.sizes,
+                strides=self.ranks.strides,
+            ),
+        )
+
+
+class CreateStream(NamedTuple):
+    result: "StreamRef"
+    default: bool
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.CreateStream(
+            id=tensor_worker.StreamRef(id=self.result.ref),
+            stream_creation=(
+                tensor_worker.StreamCreationMode.UseDefaultStream
+                if self.default
+                else tensor_worker.StreamCreationMode.CreateNewStream
+            ),
+        )
+
+
+class CreateRemoteProcessGroup(NamedTuple):
+    result: Referenceable
+    device_mesh: DeviceMesh
+    dims: Dims
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.CreateRemoteProcessGroup(
+            result=tensor_worker.Ref(id=none_throws(self.result.ref)),
+            device_mesh=tensor_worker.Ref(id=self.device_mesh.ref),
+            dims=self.dims,
+        )
+
+
+class CallFunction(NamedTuple):
+    ident: int
+    result: object  # pytree with tensors in it
+    mutates: Tuple[Tensor | tensor_worker.Ref, ...]
+    function: ResolvableFunction
+    args: Tuple[object, ...]
+    kwargs: Dict[str, object]
+    stream: "StreamRef"
+    device_mesh: DeviceMesh
+    remote_process_groups: List[RemoteProcessGroup]
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.CallFunction(
+            seq=self.ident,
+            results=_result_to_references(self.result),
+            mutates=[_ref(r) for r in self.mutates],
+            function=_to_rust_function(self.function),
+            args=self.args,
+            kwargs=self.kwargs,
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+            remote_process_groups=[
+                tensor_worker.Ref(id=none_throws(remote_process_group.ref))
+                for remote_process_group in self.remote_process_groups
+            ],
+        )
+
+
+class Exit(NamedTuple):
+    destroy_pg: bool
+    error: Optional[RemoteException | DeviceException | Exception]
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        actor_id = None
+        error_message = None
+        if isinstance(self.error, (RemoteException, DeviceException)):
+            actor_id = self.error.source_actor_id
+            error_message = self.error.message
+        elif self.error is not None:
+            error_message = str(self.error)
+
+        error_reason = None if error_message is None else (actor_id, error_message)
+        return tensor_worker.Exit(error_reason=error_reason)
+
+
+class CommandGroup(NamedTuple):
+    commands: List[NamedTuple]
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        rust_commands = []
+        for c in self.commands:
+            if hasattr(c, "to_rust_message"):
+                c = cast(SupportsToRustMessage, c)
+                rust_commands.append(c.to_rust_message())
+            else:
+                raise NotImplementedError(f"Unsupported command {c}")
+        return tensor_worker.CommandGroup(commands=rust_commands)
+
+
+class RecordingFormal(NamedTuple):
+    result: Tensor | tensor_worker.Ref
+    argument_index: int
+    stream: "StreamRef"
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.RecordingFormal(
+            result=_ref(self.result),
+            argument_index=self.argument_index,
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+        )
+
+
+class RecordingResult(NamedTuple):
+    input: Tensor | tensor_worker.Ref
+    output_index: int
+    stream: "StreamRef"
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.RecordingResult(
+            result=_ref(self.input),
+            output_index=self.output_index,
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+        )
+
+
+class DefineRecording(NamedTuple):
+    result: Recording
+    nresults: int
+    nformals: int
+    commands: List[NamedTuple]
+    ntotal_messages: int
+    message_index: int
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        define_recording = tensor_worker.DefineRecording(
+            result=tensor_worker.Ref(id=none_throws(self.result.ref)),
+            nresults=self.nresults,
+            nformals=self.nformals,
+            commands=[],
+            ntotal_messages=self.ntotal_messages,
+            index=self.message_index,
+        )
+        for c in self.commands:
+            if hasattr(c, "to_rust_message"):
+                c = cast(SupportsToRustMessage, c)
+                if isinstance(c, CallFunction):
+                    define_recording.append_call_function(
+                        seq=c.ident,
+                        results=_result_to_references(c.result),
+                        mutates=[_ref(r) for r in c.mutates],
+                        function=_to_rust_function(c.function),
+                        args=c.args,
+                        kwargs=c.kwargs,
+                        stream=tensor_worker.StreamRef(id=c.stream.ref),
+                        remote_process_groups=[
+                            tensor_worker.Ref(id=none_throws(remote_process_group.ref))
+                            for remote_process_group in c.remote_process_groups
+                        ],
+                    )
+                else:
+                    define_recording.append(c.to_rust_message())
+            else:
+                raise NotImplementedError(f"Unsupported command {c}")
+        return define_recording
+
+
+class CallRecording(NamedTuple):
+    ident: int
+    recording: Recording
+    results: List[Tensor | tensor_worker.Ref]
+    actuals: List[Tensor | tensor_worker.Ref]
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.CallRecording(
+            seq=self.ident,
+            recording=tensor_worker.Ref(id=none_throws(self.recording.ref)),
+            results=[_ref(r) for r in self.results],
+            actuals=[_ref(r) for r in self.actuals],
+        )
+
+
+class DeleteRefs(NamedTuple):
+    refs: List[int]
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.DeleteRefs(
+            refs=[tensor_worker.Ref(id=r) for r in self.refs]
+        )
+
+
+# This is worker <> controller/backend comms only will be supported differently
+class Restarted(NamedTuple):
+    result: int
+
+
+class SendValue(NamedTuple):
+    ident: int
+    destination: Pipe | None  # if present the pipe along which to send the result,
+    # otherwise send FetchResult to controller
+    mutates: Tuple[Tensor | tensor_worker.Ref, ...]
+    function: ResolvableFunction | None  # None is equivalent to lambda x: x
+    args: Tuple[object, ...]
+    kwargs: Dict[str, object]
+    stream: StreamRef
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.SendValue(
+            seq=self.ident,
+            destination=(
+                tensor_worker.Ref(id=self.destination.ref) if self.destination else None
+            ),
+            mutates=[_ref(r) for r in self.mutates],
+            function=_to_rust_function(self.function) if self.function else None,
+            args=self.args,
+            kwargs=self.kwargs,
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+        )
+
+
+# Worker -> Controller comm only handled differently
+class FetchResult(NamedTuple):
+    ident: int
+    value: object
+
+
+# Worker -> Controller comm only handled differently
+class RemoteFunctionFailed(NamedTuple):
+    failing_ident: int
+    stack_offset: int
+    exception: Exception
+    worker_frames: List[FrameSummary]
+
+
+# Worker -> Controller comm only handled differently
+class InternalException(NamedTuple):
+    exception: Exception
+    frames: List[FrameSummary]
+
+
+# Worker -> Controller comm only handled differently
+class RemoteGeneratorFailed(NamedTuple):
+    exception: Exception
+    frames: List[FrameSummary]
+
+
+# Worker -> Controller comm only handled differently
+class Status(NamedTuple):
+    first_uncompleted_ident: int
+
+
+# When the controller is waiting on a status update,
+# it will request one even if it is before the
+# periodic one.
+class RequestStatus(NamedTuple):
+    ident: int
+    controller: bool
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.RequestStatus(seq=self.ident, controller=self.controller)
+
+
+class BorrowCreate(NamedTuple):
+    result: Tensor | tensor_worker.Ref
+    borrow: int
+    tensor: Tensor | tensor_worker.Ref
+    from_stream: StreamRef
+    to_stream: StreamRef
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.BorrowCreate(
+            result=_ref(self.result),
+            borrow=self.borrow,
+            tensor=_ref(self.tensor),
+            from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
+            to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
+        )
+
+
+class BorrowDrop(NamedTuple):
+    borrow: int  # id of borrowed tensor
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.BorrowDrop(
+            borrow=self.borrow,
+        )
+
+
+class BorrowFirstUse(NamedTuple):
+    borrow: int  # id of borrowed tensor
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.BorrowFirstUse(
+            borrow=self.borrow,
+        )
+
+
+class BorrowLastUse(NamedTuple):
+    borrow: int  # id of borrowed tensor
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.BorrowLastUse(
+            borrow=self.borrow,
+        )
+
+
+class SendTensor(NamedTuple):
+    result: Tensor | tensor_worker.Ref
+    from_ranks: NDSlice
+    to_ranks: NDSlice
+    tensor: Tensor | tensor_worker.Ref
+    factory: TensorFactory
+    from_stream: StreamRef
+    to_stream: StreamRef
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.SendTensor(
+            result=_ref(self.result),
+            from_ranks=NDSlice(
+                offset=self.from_ranks.offset,
+                sizes=self.from_ranks.sizes,
+                strides=self.from_ranks.strides,
+            ),
+            to_ranks=NDSlice(
+                offset=self.to_ranks.offset,
+                sizes=self.to_ranks.sizes,
+                strides=self.to_ranks.strides,
+            ),
+            tensor=_ref(self.tensor),
+            factory=tensor_worker.TensorFactory(
+                size=self.factory.size,
+                dtype=self.factory.dtype,
+                device=self.factory.device,
+                layout=self.factory.layout,
+            ),
+            from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
+            to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
+        )
+
+
+class SplitComm(NamedTuple):
+    dims: Dims
+    device_mesh: DeviceMesh
+    stream: StreamRef
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.SplitComm(
+            dims=self.dims,
+            device_mesh=tensor_worker.Ref(id=self.device_mesh.ref),
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+        )
+
+
+class SplitCommForProcessGroup(NamedTuple):
+    remote_process_group: DeviceMesh
+    stream: StreamRef
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.SplitCommForProcessGroup(
+            remote_process_group=tensor_worker.Ref(id=self.remote_process_group.ref),
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+        )
+
+
+class Reduce(NamedTuple):
+    result: Tensor | tensor_worker.Ref
+    local_tensor: Tensor | tensor_worker.Ref
+    factory: TensorFactory
+    source_mesh: DeviceMesh
+    stream: StreamRef
+    dims: Dims
+    reduction: str
+    scatter: bool
+    inplace: bool
+    out: Tensor | tensor_worker.Ref | None
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        match self.reduction:
+            case "sum":
+                reduction = tensor_worker.ReductionType.Sum
+            case "prod":
+                reduction = tensor_worker.ReductionType.Prod
+            case "stack":
+                reduction = tensor_worker.ReductionType.Stack
+            case "avg":
+                reduction = tensor_worker.ReductionType.Avg
+            case "min":
+                reduction = tensor_worker.ReductionType.Min
+            case "max":
+                reduction = tensor_worker.ReductionType.Max
+            case _:
+                raise ValueError(f"Unsupported reduction {self.reduction}")
+
+        return tensor_worker.Reduce(
+            result=_ref(self.result),
+            tensor=_ref(self.local_tensor),
+            factory=tensor_worker.TensorFactory(
+                size=self.factory.size,
+                dtype=self.factory.dtype,
+                device=self.factory.device,
+                layout=self.factory.layout,
+            ),
+            mesh=tensor_worker.Ref(id=self.source_mesh.ref),
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+            dims=self.dims,
+            reduction=reduction,
+            scatter=self.scatter,
+            in_place=self.inplace,
+            out=_ref(self.out) if self.out is not None else None,
+        )
+
+
+class CreatePipe(NamedTuple):
+    result: Pipe
+    key: str
+    function: ResolvableFunction
+    max_messages: int
+    device_mesh: DeviceMesh
+    args: Tuple[object, ...]
+    kwargs: Dict[str, object]
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.CreatePipe(
+            result=tensor_worker.Ref(id=self.result.ref),
+            key=self.key,
+            function=_to_rust_function(self.function),
+            max_messages=self.max_messages,
+            mesh=tensor_worker.Ref(id=self.device_mesh.ref),
+            args=self.args,
+            kwargs=self.kwargs,
+        )
+
+
+class PipeRecv(NamedTuple):
+    ident: int
+    result: object  # pytree with tensors in it
+    pipe: Pipe
+    stream: StreamRef
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.PipeRecv(
+            seq=self.ident,
+            results=_result_to_references(self.result),
+            pipe=tensor_worker.Ref(id=self.pipe.ref),
+            stream=tensor_worker.StreamRef(id=self.stream.ref),
+        )
+
+
+class BackendNetworkInit(NamedTuple):
+    hostname: str | None = None
+    port: int | None = None
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.BackendNetworkInit()
+
+
+class BackendNetworkPointToPointInit(NamedTuple):
+    from_stream: StreamRef
+    to_stream: StreamRef
+
+    def to_rust_message(self) -> tensor_worker.WorkerMessage:
+        return tensor_worker.BackendNetworkPointToPointInit(
+            from_stream=tensor_worker.StreamRef(id=self.from_stream.ref),
+            to_stream=tensor_worker.StreamRef(id=self.to_stream.ref),
+        )
+
+
+# TODO: This is not supported on the rust side and might be only needed for remote funcs
+class DebuggerRead(NamedTuple):
+    requested: int
+
+
+# TODO: This is not supported on the rust side and might be only needed for remote funcs
+class DebuggerWrite(NamedTuple):
+    payload: bytes
+
+
+# TODO: This is not supported on the rust side and might be only needed for remote funcs
+class DebuggerMessage(NamedTuple):
+    stream_id: int
+    action: Literal["paused", "attach", "detach"] | DebuggerRead | DebuggerWrite
+
+
+# TODO: Might need to be supported differently through typed worker exceptions
+class DependentOnError(Exception):
+    def __init__(self, ident: int) -> None:
+        self.ident = ident
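The hunk above (monarch/common/messages.py, per the +573 entry in the file listing) defines controller-side NamedTuple messages that lower to Rust worker messages through their to_rust_message() methods. A minimal sketch of that conversion path follows; it assumes the wheel's Rust bindings import cleanly, and the SimpleNamespace stand-in for a StreamRef (any object exposing an integer ref attribute) is purely illustrative rather than how the real client constructs streams.

# Illustrative sketch only: lowering Python-side messages to Rust worker messages.
# The SimpleNamespace stand-in for StreamRef is hypothetical; real code uses
# monarch.common.stream.StreamRef, which also carries an integer .ref id.
from types import SimpleNamespace

from monarch.common.messages import CommandGroup, CreateStream, DeleteRefs

stream = SimpleNamespace(ref=1)
commands = [
    CreateStream(result=stream, default=True),  # -> tensor_worker.CreateStream
    DeleteRefs(refs=[2, 3]),                    # -> tensor_worker.DeleteRefs
]
# CommandGroup.to_rust_message() converts each command that implements
# SupportsToRustMessage and wraps them in one tensor_worker.CommandGroup.
rust_message = CommandGroup(commands).to_rust_message()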
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+from contextlib import contextmanager
+from typing import Generator, Optional
+
+import monarch.common._C  # @manual=//monarch/python/monarch/common:_C
+import torch
+
+monarch.common._C.patch_cuda()
+
+_mock_cuda_stream: Optional[torch.cuda.Stream] = None
+
+
+def get_mock_cuda_stream() -> torch.cuda.Stream:
+    global _mock_cuda_stream
+    if _mock_cuda_stream is None:
+        _mock_cuda_stream = torch.cuda.Stream()
+    return _mock_cuda_stream
+
+
+@contextmanager
+def mock_cuda_guard() -> Generator[None, None, None]:
+    try:
+        with torch.cuda.stream(get_mock_cuda_stream()):
+            monarch.common._C.mock_cuda()
+            yield
+    finally:
+        monarch.common._C.unmock_cuda()
+
+
+def mock_cuda() -> None:
+    monarch.common._C.mock_cuda()
+
+
+def unmock_cuda() -> None:
+    monarch.common._C.unmock_cuda()
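This module (monarch/common/mock_cuda.py, per the +41 entry in the file listing) patches CUDA through the compiled monarch.common._C extension so that work can be issued against a dedicated mock stream. A rough usage sketch under the assumption that the extension and a CUDA-enabled torch build are available; the shapes are arbitrary, and tests/test_mock_cuda.py in this wheel covers the same entry points.

# Illustrative sketch only: run a block with CUDA mocked. Assumes the compiled
# monarch.common._C extension from this wheel and a CUDA-enabled torch build.
import torch

from monarch.common.mock_cuda import mock_cuda_guard

with mock_cuda_guard():
    # Inside the guard, monarch.common._C.mock_cuda() is active and work is
    # issued on the dedicated stream returned by get_mock_cuda_stream().
    x = torch.empty(4, 4, device="cuda")
# On exit, the guard calls unmock_cuda() and normal CUDA behavior is restored.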
@@ -0,0 +1,98 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import functools
+import itertools
+import os
+from typing import Any, Iterator
+
+import torch
+from torch._subclasses.fake_tensor import FakeTensor
+from torch.utils._pytree import register_pytree_node
+from torch.utils.weak import WeakTensorKeyDictionary
+
+_key_table: WeakTensorKeyDictionary = WeakTensorKeyDictionary()
+_key_counter: Iterator[int] = itertools.count(1)
+
+# check that we are for sure running on the worker process
+_on_worker = os.environ.get("LOCAL_RANK") is not None
+
+
+def wrap_create(create, xs):
+    return create(xs[0])
+
+
+class OpaqueRef:
+    """
+    OpaqueRef is a reference to an object that is only resolvable on the worker
+    This is used to pass objects from the controller to the worker across User Defined Functions
+
+    Example::
+        def init_udf_worker():
+            model = nn.Linear(3, 4)
+            model_ref = OpaqueRef(model)
+            return model_ref
+
+        def run_step_worker(model_ref: OpaqueRef):
+            model = model_ref.value
+            # do something with model (e.g. forward pass
+
+        # on Controller
+        model_ref = init_udf()
+        run_step(model_ref)
+
+    """
+
+    def __init__(self, value=None):
+        self._key = torch.tensor(next(_key_counter), dtype=torch.int64)
+        self.check_worker("create")
+        _key_table[self._key] = value
+
+    @classmethod
+    def _create(cls, key: torch.Tensor):
+        c = cls.__new__(cls)
+        c._key = key
+        return c
+
+    # like NamedTuple, just pass the call to reconstruct this
+    # rather than the dict. This also ensures the OpaqueObject
+    # subclass degrades into this class when sent to the worker
+    def __reduce_ex__(self, protocol):
+        return OpaqueRef._create, (self._key,)
+
+    def __repr__(self):
+        return f"OpaqueRef({repr(self._key)})"
+
+    @property
+    def value(self) -> Any:
+        self.check_worker("access")
+        return _key_table[self._key]
+
+    @value.setter
+    def value(self, v: Any) -> None:
+        self.check_worker("set")
+        _key_table[self._key] = v
+
+    def check_worker(self, what):
+        # both checks are needed for the case where OpaqueRef() is
+        # called on the client with no mesh active.
+        in_worker_or_propagate = _on_worker or isinstance(self._key, FakeTensor)
+        if not in_worker_or_propagate:
+            raise RuntimeError(
+                f"Client is attempting to {what} an OpaqueRef. This can only be done in a remote function."
+            )
+
+
+def _flatten(x: OpaqueRef):
+    return (x._key,), functools.partial(wrap_create, x._create)
+
+
+def _unflatten(xs, ctx):
+    return ctx(xs)
+
+
+register_pytree_node(OpaqueRef, _flatten, _unflatten)
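OpaqueRef (monarch/common/opaque_ref.py, the +98 hunk above) keys a worker-local object by a tensor so that only the key, not the object, crosses the controller/worker boundary. Below, the pattern from the class docstring is written out as worker-side functions; registering them as remote functions and having an active mesh are assumed from the rest of the package and not shown here.

# Illustrative sketch only: the OpaqueRef pattern from the docstring above.
import torch
import torch.nn as nn

from monarch.common.opaque_ref import OpaqueRef


def init_udf_worker() -> OpaqueRef:
    # Runs on a worker: the Linear module stays in the worker-local _key_table;
    # only the tensor-keyed OpaqueRef handle travels back to the controller.
    return OpaqueRef(nn.Linear(3, 4))


def run_step_worker(model_ref: OpaqueRef) -> None:
    # Runs on a worker: resolve the handle back to the local module and use it.
    model = model_ref.value
    model(torch.zeros(3))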