torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/allocator.py
ADDED
@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import abc
import logging
from typing import final, Optional

from monarch import ActorFuture as Future
from monarch._rust_bindings.hyperactor_extension.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
    Alloc,
    AllocSpec,
)

from monarch._rust_bindings.monarch_hyperactor.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
    LocalAllocatorBase,
    ProcessAllocatorBase,
    RemoteAllocatorBase,
)

ALLOC_LABEL_PROC_MESH_NAME = "procmesh.monarch.meta.com/name"

logger: logging.Logger = logging.getLogger(__name__)


@final
class ProcessAllocator(ProcessAllocatorBase):
    """
    An allocator that allocates by spawning local processes.
    """

    def allocate(self, spec: AllocSpec) -> Future[Alloc]:
        """
        Allocate a process according to the provided spec.

        Arguments:
        - `spec`: The spec to allocate according to.

        Returns:
        - A future that will be fulfilled when the requested allocation is fulfilled.
        """
        return Future(
            lambda: self.allocate_nonblocking(spec),
            lambda: self.allocate_blocking(spec),
        )


@final
class LocalAllocator(LocalAllocatorBase):
    """
    An allocator that allocates by spawning actors into the current process.
    """

    def allocate(self, spec: AllocSpec) -> Future[Alloc]:
        """
        Allocate a process according to the provided spec.

        Arguments:
        - `spec`: The spec to allocate according to.

        Returns:
        - A future that will be fulfilled when the requested allocation is fulfilled.
        """
        return Future(
            lambda: self.allocate_nonblocking(spec),
            lambda: self.allocate_blocking(spec),
        )


class RemoteAllocInitializer(abc.ABC):
    """Subclass-able Python interface for `hyperactor_mesh::alloc::remoteprocess:RemoteProcessAllocInitializer`.

    NOTE: changes to method signatures of this class must be made to the call-site at
    `PyRemoteProcessAllocInitializer.py_initialize_alloc()` in `monarch/monarch_hyperactor/src/alloc.rs`
    """

    @abc.abstractmethod
    async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
        """
        Return the addresses of the servers that should be used to allocate processes
        for the proc mesh. The addresses should be running hyperactor's RemoteProcessAllocator.

        Each address is of the form `{transport}!{addr}(:{port})`.
        This is the string form of `hyperactor::channel::ChannelAddr` (Rust).
        For example, `tcp!127.0.0.1:1234`.

        NOTE: Currently, all the addresses must have the same transport type and port.
        NOTE: Although this method is currently called once at the initialization of the Allocator,
        in the future this method can be called multiple times and should return the current set of
        addresses that are eligible to handle allocation requests.

        Arguments:
        - `match_labels`: The match labels specified in `AllocSpec.AllocConstraints`. Initializer implementations
          can read specific labels for matching a set of hosts that will service `allocate()` requests.

        """
        ...


class StaticRemoteAllocInitializer(RemoteAllocInitializer):
    """
    Returns the static list of server addresses that this initializer
    was constructed with on each `initialize_alloc()` call.
    """

    def __init__(self, *addrs: str) -> None:
        super().__init__()
        self.addrs: list[str] = list(addrs)

    async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
        _ = match_labels  # Suppress unused variable warning
        return list(self.addrs)


class TorchXRemoteAllocInitializer(RemoteAllocInitializer):
    """
    For monarch runtimes running as a job on a supported scheduler.
    Such runtimes are typically launched using the monarch CLI (e.g. `monarch create --scheduler slurm ...`).

    Returns the server addresses of a specific monarch runtime by using TorchX's status API
    to get the hostnames of the nodes.
    """

    def __init__(
        self,
        server_handle: str,
        /,
        transport: Optional[str] = None,
        port: Optional[int] = None,
    ) -> None:
        """
        NOTE: If `transport` and `port` are specified, they are used over the `transport` and `port`
        information that is tagged as metadata on the server's job. This is useful in two specific
        situations:
        1) The job was NOT created with the monarch CLI (hence no metadata tags exist)
        2) The scheduler does not support job metadata tagging

        Arguments:
        - `server_handle`: points to a monarch runtime. Of the form `{scheduler}://{namespace}/{job_id}`.
          The `{namespace}` can be empty if not configured (e.g. `slurm:///1234` - notice the triple slashes).
        - `transport`: the channel transport that should be used to connect to the remote process allocator address
        - `port`: the port that the remote process allocator is running on

        """
        self.server_handle = server_handle
        self.transport = transport
        self.port = port

    async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
        # lazy import since torchx-fb is not included in `fbcode//monarch/python/monarch:monarch.whl`
        # nor any of the base conda environments
        from monarch.tools.commands import server_ready

        mesh_name = match_labels.get(ALLOC_LABEL_PROC_MESH_NAME)

        server = await server_ready(self.server_handle)

        # job does not exist or it is in a terminal state (SUCCEEDED, FAILED, CANCELLED)
        if not (server and server.is_running):
            raise ValueError(
                f"{self.server_handle} does not exist or is in a terminal state"
            )

        if not mesh_name:
            logger.info(
                "no match label `%s` specified in alloc constraints",
                ALLOC_LABEL_PROC_MESH_NAME,
            )

            num_meshes = len(server.meshes)

            if num_meshes == 1:
                logger.info(
                    "found a single proc mesh `%s` in %s, will allocate on it",
                    server.meshes[0].name,
                    self.server_handle,
                )
            else:
                raise RuntimeError(
                    f"{num_meshes} proc meshes in {self.server_handle},"
                    f" please specify the mesh name as a match label `{ALLOC_LABEL_PROC_MESH_NAME}`"
                    f" in allocation constraints of the alloc spec"
                )
            mesh = server.meshes[0]
        else:
            mesh = server.get_mesh_spec(mesh_name)

        server_addrs = mesh.server_addrs(self.transport, self.port)

        logger.info(
            "initializing alloc on remote allocator addresses: %s", server_addrs
        )
        return server_addrs


@final
class RemoteAllocator(RemoteAllocatorBase):
    """
    An allocator that allocates by spawning actors on a remote host.
    The remote host must be running hyperactor's remote-process-allocator.
    """

    def allocate(self, spec: AllocSpec) -> Future[Alloc]:
        """
        Allocate a process according to the provided spec.

        Arguments:
        - `spec`: The spec to allocate according to.

        Returns:
        - A future that will be fulfilled when the requested allocation is fulfilled.
        """
        return Future(
            lambda: self.allocate_nonblocking(spec),
            lambda: self.allocate_blocking(spec),
        )
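For illustration, a custom initializer only needs to implement `initialize_alloc()` and return channel addresses in the `{transport}!{addr}(:{port})` form described in the docstring above. The sketch below is not part of the package; the `EnvRemoteAllocInitializer` class and the `MONARCH_ALLOC_ADDRS` environment variable are assumptions made for the example.

import os

from monarch.allocator import RemoteAllocInitializer


class EnvRemoteAllocInitializer(RemoteAllocInitializer):
    """Hypothetical initializer that reads a comma-separated address list,
    e.g. "tcp!10.0.0.1:26600,tcp!10.0.0.2:26600", from an environment variable."""

    def __init__(self, env_var: str = "MONARCH_ALLOC_ADDRS") -> None:
        super().__init__()
        self.env_var = env_var

    async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
        _ = match_labels  # labels are not used by this simple initializer
        addrs = os.environ.get(self.env_var, "")
        return [a for a in addrs.split(",") if a]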
monarch/bootstrap_main.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This is the main function for bootstrapping a new process using a ProcessAllocator.
"""

import asyncio
import importlib.resources
import logging
import os
import sys

# Import torch to avoid import-time races if a spawned actor tries to import torch.
import torch  # noqa[F401]


async def main():
    from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main

    await bootstrap_main()


def invoke_main():
    # if this is invoked with the stdout piped somewhere, then print
    # changes its buffering behavior. So we default to the standard
    # behavior of std out as if it were a terminal.
    sys.stdout.reconfigure(line_buffering=True)
    global bootstrap_main

    # TODO: figure out what from worker_main.py we should reproduce here.
    from monarch.telemetry import TracingForwarder

    if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
        raise RuntimeError("Error during bootstrap for testing")

    # forward logs to rust tracing. Defaults to on.
    if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
        logging.root.addHandler(TracingForwarder(level=logging.DEBUG))

    try:
        with (
            importlib.resources.path("monarch", "py-spy") as pyspy,
        ):
            if pyspy.exists():
                os.environ["PYSPY_BIN"] = str(pyspy)
            # fallback to using local py-spy
    except Exception as e:
        logging.warning(f"Failed to set up py-spy: {e}")

    # Start an event loop for PythonActors to use.
    asyncio.run(main())


if __name__ == "__main__":
    invoke_main()  # pragma: no cover
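The bootstrap entry point is normally spawned by a `ProcessAllocator`, but its environment-variable behavior can be exercised directly. A minimal sketch (an assumption about how one might launch it by hand; without the allocator handshake it will simply wait or fail):

import os
import subprocess
import sys

# MONARCH_PYTHON_LOG_TRACING defaults to "1" (forward Python logs to Rust tracing);
# setting it to "0" skips installing the TracingForwarder handler.
env = dict(os.environ, MONARCH_PYTHON_LOG_TRACING="0")
subprocess.run([sys.executable, "-m", "monarch.bootstrap_main"], env=env, check=True)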
monarch/builtins/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
"""
Builtins for Monarch are a set of remote function definitions for PyTorch functions and other utilities.
"""

from .log import log_remote, set_logging_level_remote

__all__ = ["log_remote", "set_logging_level_remote"]
monarch/builtins/log.py
ADDED
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

from monarch.common.remote import remote


logger = logging.getLogger(__name__)


@remote(propagate="inspect")
def log_remote(*args, level: int = logging.WARNING, **kwargs) -> None:
    logger.log(level, *args, **kwargs)


@remote(propagate="inspect")
def set_logging_level_remote(level: int) -> None:
    logger.setLevel(level)
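As a usage sketch (assuming a device mesh is already active so that remote functions dispatch to its workers; the helper function is illustrative, not part of the package), the two builtins can be combined to raise worker verbosity and emit a message on the workers:

import logging

from monarch.builtins import log_remote, set_logging_level_remote


def enable_worker_info_logs() -> None:
    # Both calls run on the workers of the currently active mesh, not on the controller.
    set_logging_level_remote(logging.INFO)
    log_remote("worker logging raised to %s", "INFO", level=logging.INFO)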
monarch/builtins/random.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from typing import Callable

import torch
from monarch.common.remote import remote


@remote(propagate="inspect")
def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
    torch.manual_seed(seed ^ process_idx)


@remote(propagate=lambda: torch.zeros(1))
def get_rng_state_remote() -> torch.Tensor:
    return torch.get_rng_state()


@remote(propagate="inspect")
def set_rng_state_remote(new_state: torch.Tensor) -> None:
    torch.set_rng_state(new_state)


def _run_no_return(f: Callable) -> None:
    f()
    return None


# TODO: return result when uint64 is supported from remote function
@remote(propagate=lambda: _run_no_return(torch.seed))
def seed_remote() -> None:
    torch.seed()


# same underlying implementation as seed_remote (torch.seed)
# TODO: return result when uint64 is supported from remote function
@remote(propagate=lambda: _run_no_return(torch.random.seed))
def random_seed_remote() -> None:
    torch.random.seed()


@remote(propagate="inspect")
def manual_seed_cuda_remote(seed: int) -> None:
    torch.cuda.manual_seed(seed)


@remote(propagate="inspect")
def manual_seed_all_cuda_remote(seed: int) -> None:
    torch.cuda.manual_seed_all(seed)


@remote(propagate=lambda: [torch.zeros(1)])
def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
    return torch.cuda.get_rng_state_all()


@remote(propagate="inspect")
def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
    torch.cuda.set_rng_state_all(states)


# initial_seed may sometimes return a uint64 which currently can't be unwrapped by the framework
# def initial_seed_remote() -> int: ...
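A short sketch of how these helpers are typically combined, again assuming an active mesh (the wrapper function is illustrative only; the returned state is whatever the `propagate` placeholder describes on the controller side):

import torch

from monarch.builtins.random import (
    get_rng_state_remote,
    set_manual_seed_remote,
    set_rng_state_remote,
)


def checkpoint_and_reseed(seed: int) -> torch.Tensor:
    state = get_rng_state_remote()   # capture the current worker RNG state
    set_manual_seed_remote(seed)     # reseed workers (seed ^ process_idx, process_idx defaults to 0)
    # ... run some randomized remote work ...
    set_rng_state_remote(state)      # restore the captured state afterwards
    return state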
monarch/cached_remote_function.py
ADDED
@@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import importlib
import logging

from contextlib import contextmanager
from typing import Dict, List, Optional, Type, Union

import torch
from monarch.common.process_group import SingleControllerProcessGroupWrapper

from monarch.common.remote import DummyProcessGroup, remote, RemoteProcessGroup

from torch import autograd
from torch.utils._pytree import tree_flatten, tree_unflatten

logger = logging.getLogger(__name__)


def _controller_autograd_function_forward(
    autograd_function_class: Type[autograd.Function],
):
    """
    Decorator for authoring a controller remote function wrapper that wraps an autograd.Function forward.
    Sets up the autograd.function.FunctionCtx() to send over the wire and sets up the original ctx
    with the ctx_tensors and ctx attributes.
    """

    def decorator(func):
        def wrapper(ctx, *args):
            # Need dummy context because cannot pickle autograd.FunctionBackward
            wire_ctx = autograd.function.FunctionCtx()
            # Track arg tensors that have requires_grad
            arg_tensors, _ = tree_flatten(args)
            wire_ctx.args_requires_grads = []
            for i, arg in enumerate(arg_tensors):
                if isinstance(arg, torch.Tensor) and arg.requires_grad:
                    wire_ctx.args_requires_grads.append(i)
            out, ctx_attrs, ctx_tensors = func(
                autograd_function_class.__module__,
                autograd_function_class.__name__,
                wire_ctx,
                *args,
            )
            if ctx is None:
                return out
            ctx.save_for_backward(*ctx_tensors)
            ctx.attr_names = ctx_attrs.keys()
            ctx.pg_names = []
            dim_to_remote_group = {}
            for arg in args:
                if isinstance(arg, RemoteProcessGroup):
                    dim_to_remote_group[arg.dims] = arg
            for name, v in ctx_attrs.items():
                if isinstance(v, DummyProcessGroup):
                    setattr(ctx, name, dim_to_remote_group[v.dims])
                    ctx.pg_names.append(name)
                else:
                    setattr(ctx, name, v)

            return out

        return wrapper

    return decorator


def _controller_autograd_function_backward(
    autograd_function_class: Type[autograd.Function],
):
    """
    Decorator for authoring a controller remote function wrapper that wraps an autograd.Function backward.
    Manually sets up wire_ctx with ctx tensors and attributes.
    """

    def decorator(func):
        def wrapper(ctx, *grad_outputs):
            # Manually set up wire_ctx with ctx tensors and attributes
            wire_ctx = autograd.function.FunctionCtx()
            # send over tensor references with ctx_tensors
            ctx_tensors = ctx.saved_tensors
            wire_ctx.save_for_backward(ctx_tensors)
            for name in ctx.attr_names:
                setattr(wire_ctx, name, getattr(ctx, name))
            process_groups = {name: getattr(ctx, name) for name in ctx.pg_names}

            return func(
                autograd_function_class.__module__,
                autograd_function_class.__name__,
                wire_ctx,
                ctx_tensors,
                # explicitly pass pg to worker
                process_groups,
                *grad_outputs,
            )

        return wrapper

    return decorator


@contextmanager
def manage_grads(list_of_tensors, indices):
    try:
        for i in indices:
            assert list_of_tensors[i].is_leaf, "can't have non-leaf tensors on worker"
            list_of_tensors[i].requires_grad = True
        yield list_of_tensors
    finally:
        for i in indices:
            list_of_tensors[i].requires_grad = False


def worker_autograd_function_forward(
    module_name: str,
    class_name: str,
    ctx: autograd.function.FunctionCtx,
    *args,
    **kwargs,
):
    # Capture initial state of ctx attributes
    before = set()
    before.add("to_save")
    for attr in dir(ctx):
        if not attr.startswith("_"):
            before.add(attr)

    # Set tensors that require grad from additional arg
    flatten_args, spec = tree_flatten(args)
    # pyre-ignore
    with manage_grads(flatten_args, ctx.args_requires_grads) as args_with_grad:
        args = tree_unflatten(args_with_grad, spec)

        # Call the original forward function
        module = importlib.import_module(module_name)
        class_ = getattr(module, class_name)
        with torch.no_grad():
            out = class_.forward(ctx, *args, **kwargs)

        # Capture state of ctx attributes after the function call
        after = set()
        for attr in dir(ctx):
            if not attr.startswith("_"):
                after.add(attr)
        ctx_attrs = {attr: getattr(ctx, attr) for attr in after - before}
        ctx_attrs["ctx_requires_grads"] = []

        if not hasattr(ctx, "to_save"):
            to_save = []
        else:
            # pyre-ignore
            for idx, t in enumerate(ctx.to_save):
                # generally, workers should not have requires_grad set. Set to correct state after
                # but record requires_grad for next forward
                if isinstance(t, torch.Tensor) and t.requires_grad and t.is_leaf:
                    t.requires_grad = False
                    ctx_attrs["ctx_requires_grads"].append(idx)
            to_save = ctx.to_save
    return out, ctx_attrs, to_save


def worker_autograd_function_backward(
    module_name: str,
    class_name: str,
    ctx: autograd.function.FunctionCtx,
    ctx_tensors: List[torch.Tensor],
    process_groups: Dict[
        str, Union[SingleControllerProcessGroupWrapper, DummyProcessGroup]
    ],
    *grad_outputs: torch.Tensor,
):
    # set correct requires_grad state pre backward
    # pyre-ignore
    with manage_grads(ctx_tensors, ctx.ctx_requires_grads) as ctx_grad_tensors:
        # for i in ctx.ctx_requires_grads:
        #     ctx_tensors[i].requires_grad = True
        if ctx_grad_tensors:
            # pyre-ignore
            ctx.saved_tensors = ctx_grad_tensors
        for name, v in process_groups.items():
            setattr(ctx, name, v)
        # Call the original backward function
        module = importlib.import_module(module_name)
        class_ = getattr(module, class_name)
        with torch.no_grad():
            out = class_.backward(ctx, *grad_outputs)
    return out


forward_remote_fn = remote(
    "monarch.cached_remote_function.worker_autograd_function_forward"
)

backward_remote_fn = remote(
    "monarch.cached_remote_function.worker_autograd_function_backward"
)


class RemoteAutogradFunction(autograd.Function):
    """
    New autograd.Function (custom forward/backward) that will run on the worker as a UDF RemoteFunction


    Example::
        my_remote_autograd_function = remote_autograd_function(my_custom_autograd_function)
    """

    @staticmethod
    def forward(ctx, *args):
        raise NotImplementedError()

    @staticmethod
    def backward(ctx, *grads):
        raise NotImplementedError()


def remote_autograd_function(
    target_class: Type[autograd.Function], name: Optional[str] = None
) -> Type[RemoteAutogradFunction]:
    """
    Returns a new autograd.Function (custom forward/backward) that will run on the worker as a UDF RemoteFunction.
    Logic is done on the controller (e.g., DTensors set up and saved for backward).
    The autograd.function.FunctionCtx() is sent over the wire to the worker.
    Special handling is done for ctx_tensors, requires_grad of tensors, and process groups.

    Args:
        target_class: autograd.Function class to be run remotely
        name: name of the new autograd.Function to be called on the worker
    """
    if issubclass(target_class, RemoteAutogradFunction):
        logging.warning(
            f"{target_class} is already an autograd.Function UDF! You are likely monkey-patching too many times"
        )
        return target_class
    assert issubclass(
        target_class, autograd.Function
    ), f"{target_class} is not a torch.autograd.Function!"
    if name is None:
        name = f"Remote_{target_class.__name__}"

    return type(
        name,
        (RemoteAutogradFunction,),
        {
            "forward": staticmethod(
                _controller_autograd_function_forward(target_class)(forward_remote_fn)
            ),
            "backward": staticmethod(
                _controller_autograd_function_backward(target_class)(backward_remote_fn)
            ),
        },
    )
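Following the `Example::` in the docstring above, a toy custom `autograd.Function` could be wrapped like this (illustrative only; `MyScale` is not part of the package):

import torch
from torch import autograd

from monarch.cached_remote_function import remote_autograd_function


class MyScale(autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale  # stashed on ctx and shipped to/from the worker
        return x * scale

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        return grad_out * ctx.scale, None


# Forward/backward of the wrapped class execute on the workers as UDF remote functions.
RemoteMyScale = remote_autograd_function(MyScale)
# y = RemoteMyScale.apply(x, 2.0)  # used like any other autograd.Function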
monarch/code_sync.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from monarch._rust_bindings.monarch_extension.code_sync import (  # noqa: F401
    RemoteWorkspace,
    RsyncMeshClient,
)
monarch/common/_C.pyi
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

def patch_cuda() -> None: ...
def mock_cuda() -> None: ...
def unmock_cuda() -> None: ...
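These stubs correspond to the CUDA-mocking helpers used by `monarch/common/mock_cuda.py` in the listing above. A hedged sketch of the bracket pattern the names suggest (the exact call sequence and semantics are assumptions, not documented by the package):

from monarch.common import _C

_C.patch_cuda()       # install the interception hooks once per process (assumed)
_C.mock_cuda()        # begin mocking CUDA calls
try:
    pass              # run code whose CUDA work should be intercepted
finally:
    _C.unmock_cuda()  # restore real CUDA behavior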
monarch/common/_C.so
ADDED
Binary file