torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/allocator.py ADDED
@@ -0,0 +1,220 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ import abc
+ import logging
+ from typing import final, Optional
+
+ from monarch import ActorFuture as Future
+ from monarch._rust_bindings.hyperactor_extension.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     Alloc,
+     AllocSpec,
+ )
+
+ from monarch._rust_bindings.monarch_hyperactor.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     LocalAllocatorBase,
+     ProcessAllocatorBase,
+     RemoteAllocatorBase,
+ )
+
+ ALLOC_LABEL_PROC_MESH_NAME = "procmesh.monarch.meta.com/name"
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ @final
+ class ProcessAllocator(ProcessAllocatorBase):
+     """
+     An allocator that allocates by spawning local processes.
+     """
+
+     def allocate(self, spec: AllocSpec) -> Future[Alloc]:
+         """
+         Allocate a process according to the provided spec.
+
+         Arguments:
+         - `spec`: The spec to allocate according to.
+
+         Returns:
+         - A future that will be fulfilled when the requested allocation is fulfilled.
+         """
+         return Future(
+             lambda: self.allocate_nonblocking(spec),
+             lambda: self.allocate_blocking(spec),
+         )
+
+
+ @final
+ class LocalAllocator(LocalAllocatorBase):
+     """
+     An allocator that allocates by spawning actors into the current process.
+     """
+
+     def allocate(self, spec: AllocSpec) -> Future[Alloc]:
+         """
+         Allocate a process according to the provided spec.
+
+         Arguments:
+         - `spec`: The spec to allocate according to.
+
+         Returns:
+         - A future that will be fulfilled when the requested allocation is fulfilled.
+         """
+         return Future(
+             lambda: self.allocate_nonblocking(spec),
+             lambda: self.allocate_blocking(spec),
+         )
+
+
+ class RemoteAllocInitializer(abc.ABC):
+     """Subclass-able Python interface for `hyperactor_mesh::alloc::remoteprocess:RemoteProcessAllocInitializer`.
+
+     NOTE: changes to method signatures of this class must be made to the call-site at
+     `PyRemoteProcessAllocInitializer.py_initialize_alloc()` in `monarch/monarch_hyperactor/src/alloc.rs`
+     """
+
+     @abc.abstractmethod
+     async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
+         """
+         Return the addresses of the servers that should be used to allocate processes
+         for the proc mesh. The addresses should be running hyperactor's RemoteProcessAllocator.
+
+         Each address is of the form `{transport}!{addr}(:{port})`.
+         This is the string form of `hyperactor::channel::ChannelAddr` (Rust).
+         For example, `tcp!127.0.0.1:1234`.
+
+         NOTE: Currently, all the addresses must have the same transport type and port.
+         NOTE: Although this method is currently called once at the initialization of the Allocator,
+         in the future this method can be called multiple times and should return the current set of
+         addresses that are eligible to handle allocation requests.
+
+         Arguments:
+         - `match_labels`: The match labels specified in `AllocSpec.AllocConstraints`. Initializer implementations
+           can read specific labels for matching a set of hosts that will service `allocate()` requests.
+
+         """
+         ...
+
+
+ class StaticRemoteAllocInitializer(RemoteAllocInitializer):
+     """
+     Returns the static list of server addresses that this initializer
+     was constructed with on each `initialize_alloc()` call.
+     """
+
+     def __init__(self, *addrs: str) -> None:
+         super().__init__()
+         self.addrs: list[str] = list(addrs)
+
+     async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
+         _ = match_labels  # Suppress unused variable warning
+         return list(self.addrs)
+
+
+ class TorchXRemoteAllocInitializer(RemoteAllocInitializer):
+     """
+     For monarch runtimes running as a job on a supported scheduler.
+     Such runtimes are typically launched using the monarch CLI (e.g. `monarch create --scheduler slurm ...`).
+
+     Returns the server addresses of a specific monarch runtime by using TorchX's status API
+     to get the hostnames of the nodes.
+     """
+
+     def __init__(
+         self,
+         server_handle: str,
+         /,
+         transport: Optional[str] = None,
+         port: Optional[int] = None,
+     ) -> None:
+         """
+         NOTE: If `transport` and `port` are specified, they are used over the `transport` and `port`
+         information that is tagged as metadata on the server's job. This is useful in two specific
+         situations:
+           1) The job was NOT created with the monarch CLI (hence no metadata tags exist)
+           2) The scheduler does not support job metadata tagging
+
+         Arguments:
+         - `server_handle`: points to a monarch runtime. Of the form `{scheduler}://{namespace}/{job_id}`.
+           The `{namespace}` can be empty if not configured (e.g. `slurm:///1234` - notice the triple slashes).
+         - `transport`: the channel transport that should be used to connect to the remote process allocator address
+         - `port`: the port that the remote process allocator is running on
+
+         """
+         self.server_handle = server_handle
+         self.transport = transport
+         self.port = port
+
+     async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
+         # lazy import since torchx-fb is not included in `fbcode//monarch/python/monarch:monarch.whl`
+         # nor any of the base conda environments
+         from monarch.tools.commands import server_ready
+
+         mesh_name = match_labels.get(ALLOC_LABEL_PROC_MESH_NAME)
+
+         server = await server_ready(self.server_handle)
+
+         # job does not exist or it is in a terminal state (SUCCEEDED, FAILED, CANCELLED)
+         if not (server and server.is_running):
+             raise ValueError(
+                 f"{self.server_handle} does not exist or is in a terminal state"
+             )
+
+         if not mesh_name:
+             logger.info(
+                 "no match label `%s` specified in alloc constraints",
+                 ALLOC_LABEL_PROC_MESH_NAME,
+             )
+
+             num_meshes = len(server.meshes)
+
+             if num_meshes == 1:
+                 logger.info(
+                     "found a single proc mesh `%s` in %s, will allocate on it",
+                     server.meshes[0].name,
+                     self.server_handle,
+                 )
+             else:
+                 raise RuntimeError(
+                     f"{num_meshes} proc meshes in {self.server_handle},"
+                     f" please specify the mesh name as a match label `{ALLOC_LABEL_PROC_MESH_NAME}`"
+                     f" in allocation constraints of the alloc spec"
+                 )
+             mesh = server.meshes[0]
+         else:
+             mesh = server.get_mesh_spec(mesh_name)
+
+         server_addrs = mesh.server_addrs(self.transport, self.port)
+
+         logger.info(
+             "initializing alloc on remote allocator addresses: %s", server_addrs
+         )
+         return server_addrs
+
+
+ @final
+ class RemoteAllocator(RemoteAllocatorBase):
+     """
+     An allocator that allocates by spawning actors on a remote host.
+     The remote host must be running hyperactor's remote-process-allocator.
+     """
+
+     def allocate(self, spec: AllocSpec) -> Future[Alloc]:
+         """
+         Allocate a process according to the provided spec.
+
+         Arguments:
+         - `spec`: The spec to allocate according to.
+
+         Returns:
+         - A future that will be fulfilled when the requested allocation is fulfilled.
+         """
+         return Future(
+             lambda: self.allocate_nonblocking(spec),
+             lambda: self.allocate_blocking(spec),
+         )
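
The `RemoteAllocInitializer` ABC above is the extension point for custom host discovery; `StaticRemoteAllocInitializer` and `TorchXRemoteAllocInitializer` are the two concrete cases shipped here. Below is a minimal sketch of another subclass, relying only on the `initialize_alloc()` contract documented above; the environment variable name and default port are hypothetical, not part of the package.

import os

from monarch.allocator import RemoteAllocInitializer


class EnvRemoteAllocInitializer(RemoteAllocInitializer):
    """Resolves allocator hosts from a comma-separated MY_ALLOC_HOSTS env var (hypothetical)."""

    def __init__(self, port: int = 26600) -> None:  # port value is an arbitrary example
        super().__init__()
        self.port = port

    async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
        _ = match_labels  # this sketch ignores match labels, like StaticRemoteAllocInitializer
        hosts = [h for h in os.environ.get("MY_ALLOC_HOSTS", "").split(",") if h]
        # each address must be a hyperactor ChannelAddr string, e.g. `tcp!10.0.0.1:26600`
        return [f"tcp!{host}:{self.port}" for host in hosts]

Such an initializer would then be handed to `RemoteAllocator`, whose constructor comes from the Rust `RemoteAllocatorBase` binding and is not shown in this file.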
monarch/bootstrap_main.py ADDED
@@ -0,0 +1,59 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This is the main function for bootstrapping a new process using a ProcessAllocator.
+ """
+
+ import asyncio
+ import importlib.resources
+ import logging
+ import os
+ import sys
+
+ # Import torch to avoid import-time races if a spawned actor tries to import torch.
+ import torch  # noqa[F401]
+
+
+ async def main():
+     from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main
+
+     await bootstrap_main()
+
+
+ def invoke_main():
+     # If this is invoked with stdout piped somewhere, then print changes its
+     # buffering behavior. So we default to the standard behavior of stdout
+     # as if it were a terminal.
+     sys.stdout.reconfigure(line_buffering=True)
+     global bootstrap_main
+
+     # TODO: figure out what from worker_main.py we should reproduce here.
+     from monarch.telemetry import TracingForwarder
+
+     if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
+         raise RuntimeError("Error during bootstrap for testing")
+
+     # Forward logs to rust tracing. Defaults to on.
+     if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
+         logging.root.addHandler(TracingForwarder(level=logging.DEBUG))
+
+     try:
+         with (
+             importlib.resources.path("monarch", "py-spy") as pyspy,
+         ):
+             if pyspy.exists():
+                 os.environ["PYSPY_BIN"] = str(pyspy)
+     # fallback to using local py-spy
+     except Exception as e:
+         logging.warning(f"Failed to set up py-spy: {e}")
+
+     # Start an event loop for PythonActors to use.
+     asyncio.run(main())
+
+
+ if __name__ == "__main__":
+     invoke_main()  # pragma: no cover
monarch/builtins/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ """
+ Builtins for Monarch: a set of remote function definitions for PyTorch functions and other utilities.
+ """
+
+ from .log import log_remote, set_logging_level_remote
+
+ __all__ = ["log_remote", "set_logging_level_remote"]
monarch/builtins/log.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+
+ from monarch.common.remote import remote
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @remote(propagate="inspect")
+ def log_remote(*args, level: int = logging.WARNING, **kwargs) -> None:
+     logger.log(level, *args, **kwargs)
+
+
+ @remote(propagate="inspect")
+ def set_logging_level_remote(level: int) -> None:
+     logger.setLevel(level)
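
Both helpers are remote functions: calling them on the controller runs the body on the workers of the currently active mesh rather than locally. A usage sketch; mesh creation and activation are not shown in this file (see `monarch/device_mesh.py` and `monarch/python_local_mesh.py` in the file list above), so that part is assumed:

import logging

from monarch.builtins import log_remote, set_logging_level_remote


def quiet_workers() -> None:
    # Assumes a device mesh is already active on the controller.
    set_logging_level_remote(logging.ERROR)
    log_remote("worker logger level set to ERROR", level=logging.ERROR)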
monarch/builtins/random.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ from typing import Callable
+
+ import torch
+ from monarch.common.remote import remote
+
+
+ @remote(propagate="inspect")
+ def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
+     torch.manual_seed(seed ^ process_idx)
+
+
+ @remote(propagate=lambda: torch.zeros(1))
+ def get_rng_state_remote() -> torch.Tensor:
+     return torch.get_rng_state()
+
+
+ @remote(propagate="inspect")
+ def set_rng_state_remote(new_state: torch.Tensor) -> None:
+     torch.set_rng_state(new_state)
+
+
+ def _run_no_return(f: Callable) -> None:
+     f()
+     return None
+
+
+ # TODO: return result when uint64 is supported from remote function
+ @remote(propagate=lambda: _run_no_return(torch.seed))
+ def seed_remote() -> None:
+     torch.seed()
+
+
+ # same underlying implementation as seed_remote (torch.seed)
+ # TODO: return result when uint64 is supported from remote function
+ @remote(propagate=lambda: _run_no_return(torch.random.seed))
+ def random_seed_remote() -> None:
+     torch.random.seed()
+
+
+ @remote(propagate="inspect")
+ def manual_seed_cuda_remote(seed: int) -> None:
+     torch.cuda.manual_seed(seed)
+
+
+ @remote(propagate="inspect")
+ def manual_seed_all_cuda_remote(seed: int) -> None:
+     torch.cuda.manual_seed_all(seed)
+
+
+ @remote(propagate=lambda: [torch.zeros(1)])
+ def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
+     return torch.cuda.get_rng_state_all()
+
+
+ @remote(propagate="inspect")
+ def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
+     torch.cuda.set_rng_state_all(states)
+
+
+ # initial_seed may sometimes return a uint64 which currently can't be unwrapped by the framework
+ # def initial_seed_remote() -> int: ...
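
Judging from this file, `propagate="inspect"` is used for side-effect-only setters, while result-returning getters pass a lambda that produces a placeholder of the right type for the controller (e.g. `lambda: torch.zeros(1)`). A hypothetical additional helper written in the same style, not part of the package:

import torch

from monarch.common.remote import remote


# Hypothetical helper (illustration only): seed CPU and all CUDA generators together,
# using "inspect" propagation because it is side-effect-only, like the setters above.
@remote(propagate="inspect")
def manual_seed_everywhere_remote(seed: int) -> None:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)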
monarch/cached_remote_function.py ADDED
@@ -0,0 +1,257 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import importlib
+ import logging
+
+ from contextlib import contextmanager
+ from typing import Dict, List, Optional, Type, Union
+
+ import torch
+ from monarch.common.process_group import SingleControllerProcessGroupWrapper
+
+ from monarch.common.remote import DummyProcessGroup, remote, RemoteProcessGroup
+
+ from torch import autograd
+ from torch.utils._pytree import tree_flatten, tree_unflatten
+
+ logger = logging.getLogger(__name__)
+
+
+ def _controller_autograd_function_forward(
+     autograd_function_class: Type[autograd.Function],
+ ):
+     """
+     Decorator for authoring a controller remote function wrapper that wraps an autograd.Function forward.
+     Sets up the autograd.function.FunctionCtx() to send over the wire and sets up the original ctx
+     with the ctx_tensors and ctx attributes.
+     """
+
+     def decorator(func):
+         def wrapper(ctx, *args):
+             # Need a dummy context because we cannot pickle autograd.FunctionBackward
+             wire_ctx = autograd.function.FunctionCtx()
+             # Track arg tensors that have requires_grad
+             arg_tensors, _ = tree_flatten(args)
+             wire_ctx.args_requires_grads = []
+             for i, arg in enumerate(arg_tensors):
+                 if isinstance(arg, torch.Tensor) and arg.requires_grad:
+                     wire_ctx.args_requires_grads.append(i)
+             out, ctx_attrs, ctx_tensors = func(
+                 autograd_function_class.__module__,
+                 autograd_function_class.__name__,
+                 wire_ctx,
+                 *args,
+             )
+             if ctx is None:
+                 return out
+             ctx.save_for_backward(*ctx_tensors)
+             ctx.attr_names = ctx_attrs.keys()
+             ctx.pg_names = []
+             dim_to_remote_group = {}
+             for arg in args:
+                 if isinstance(arg, RemoteProcessGroup):
+                     dim_to_remote_group[arg.dims] = arg
+             for name, v in ctx_attrs.items():
+                 if isinstance(v, DummyProcessGroup):
+                     setattr(ctx, name, dim_to_remote_group[v.dims])
+                     ctx.pg_names.append(name)
+                 else:
+                     setattr(ctx, name, v)
+
+             return out
+
+         return wrapper
+
+     return decorator
+
+
+ def _controller_autograd_function_backward(
+     autograd_function_class: Type[autograd.Function],
+ ):
+     """
+     Decorator for authoring a controller remote function wrapper that wraps an autograd.Function backward.
+     Manually sets up wire_ctx with ctx tensors and attributes.
+     """
+
+     def decorator(func):
+         def wrapper(ctx, *grad_outputs):
+             # Manually set up wire_ctx with ctx tensors and attributes
+             wire_ctx = autograd.function.FunctionCtx()
+             # send over tensor references with ctx_tensors
+             ctx_tensors = ctx.saved_tensors
+             wire_ctx.save_for_backward(ctx_tensors)
+             for name in ctx.attr_names:
+                 setattr(wire_ctx, name, getattr(ctx, name))
+             process_groups = {name: getattr(ctx, name) for name in ctx.pg_names}
+
+             return func(
+                 autograd_function_class.__module__,
+                 autograd_function_class.__name__,
+                 wire_ctx,
+                 ctx_tensors,
+                 # explicitly pass pg to worker
+                 process_groups,
+                 *grad_outputs,
+             )
+
+         return wrapper
+
+     return decorator
+
+
+ @contextmanager
+ def manage_grads(list_of_tensors, indices):
+     try:
+         for i in indices:
+             assert list_of_tensors[i].is_leaf, "can't have non-leaf tensors on worker"
+             list_of_tensors[i].requires_grad = True
+         yield list_of_tensors
+     finally:
+         for i in indices:
+             list_of_tensors[i].requires_grad = False
+
+
+ def worker_autograd_function_forward(
+     module_name: str,
+     class_name: str,
+     ctx: autograd.function.FunctionCtx,
+     *args,
+     **kwargs,
+ ):
+     # Capture initial state of ctx attributes
+     before = set()
+     before.add("to_save")
+     for attr in dir(ctx):
+         if not attr.startswith("_"):
+             before.add(attr)
+
+     # Set tensors that require grad from additional arg
+     flatten_args, spec = tree_flatten(args)
+     # pyre-ignore
+     with manage_grads(flatten_args, ctx.args_requires_grads) as args_with_grad:
+         args = tree_unflatten(args_with_grad, spec)
+
+         # Call the original forward function
+         module = importlib.import_module(module_name)
+         class_ = getattr(module, class_name)
+         with torch.no_grad():
+             out = class_.forward(ctx, *args, **kwargs)
+
+     # Capture state of ctx attributes after the function call
+     after = set()
+     for attr in dir(ctx):
+         if not attr.startswith("_"):
+             after.add(attr)
+     ctx_attrs = {attr: getattr(ctx, attr) for attr in after - before}
+     ctx_attrs["ctx_requires_grads"] = []
+
+     if not hasattr(ctx, "to_save"):
+         to_save = []
+     else:
+         # pyre-ignore
+         for idx, t in enumerate(ctx.to_save):
+             # Generally, workers should not have requires_grad set. Reset to the correct state here,
+             # but record requires_grad for the next forward.
+             if isinstance(t, torch.Tensor) and t.requires_grad and t.is_leaf:
+                 t.requires_grad = False
+                 ctx_attrs["ctx_requires_grads"].append(idx)
+         to_save = ctx.to_save
+     return out, ctx_attrs, to_save
+
+
+ def worker_autograd_function_backward(
+     module_name: str,
+     class_name: str,
+     ctx: autograd.function.FunctionCtx,
+     ctx_tensors: List[torch.Tensor],
+     process_groups: Dict[
+         str, Union[SingleControllerProcessGroupWrapper, DummyProcessGroup]
+     ],
+     *grad_outputs: torch.Tensor,
+ ):
+     # set correct requires_grad state pre backward
+     # pyre-ignore
+     with manage_grads(ctx_tensors, ctx.ctx_requires_grads) as ctx_grad_tensors:
+         # for i in ctx.ctx_requires_grads:
+         #     ctx_tensors[i].requires_grad = True
+         if ctx_grad_tensors:
+             # pyre-ignore
+             ctx.saved_tensors = ctx_grad_tensors
+         for name, v in process_groups.items():
+             setattr(ctx, name, v)
+         # Call the original backward function
+         module = importlib.import_module(module_name)
+         class_ = getattr(module, class_name)
+         with torch.no_grad():
+             out = class_.backward(ctx, *grad_outputs)
+     return out
+
+
+ forward_remote_fn = remote(
+     "monarch.cached_remote_function.worker_autograd_function_forward"
+ )
+
+ backward_remote_fn = remote(
+     "monarch.cached_remote_function.worker_autograd_function_backward"
+ )
+
+
+ class RemoteAutogradFunction(autograd.Function):
+     """
+     New autograd.Function (custom forward/backward) that will run on the worker as a UDF RemoteFunction.
+
+
+     Example::
+         my_remote_autograd_function = remote_autograd_function(my_custom_autograd_function)
+     """
+
+     @staticmethod
+     def forward(ctx, *args):
+         raise NotImplementedError()
+
+     @staticmethod
+     def backward(ctx, *grads):
+         raise NotImplementedError()
+
+
+ def remote_autograd_function(
+     target_class: Type[autograd.Function], name: Optional[str] = None
+ ) -> Type[RemoteAutogradFunction]:
+     """
+     Returns a new autograd.Function (custom forward/backward) that will run on the worker as a UDF RemoteFunction.
+     Logic is done on the controller (e.g., DTensors set up and saved for backward).
+     The autograd.function.FunctionCtx() is sent over the wire to the worker.
+     Special handling is done for ctx_tensors, requires_grad of tensors, and process groups.
+
+     Args:
+         target_class: autograd.Function class to be run remotely
+         name: name of the new autograd.Function to be called on the worker
+     """
+     if issubclass(target_class, RemoteAutogradFunction):
+         logging.warning(
+             f"{target_class} is already an autograd.Function UDF! You are likely monkey-patching too many times"
+         )
+         return target_class
+     assert issubclass(
+         target_class, autograd.Function
+     ), f"{target_class} is not a torch.autograd.Function!"
+     if name is None:
+         name = f"Remote_{target_class.__name__}"
+
+     return type(
+         name,
+         (RemoteAutogradFunction,),
+         {
+             "forward": staticmethod(
+                 _controller_autograd_function_forward(target_class)(forward_remote_fn)
+             ),
+             "backward": staticmethod(
+                 _controller_autograd_function_backward(target_class)(backward_remote_fn)
+             ),
+         },
+     )
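
`remote_autograd_function` is the public entry point: it takes an existing `torch.autograd.Function` and returns a subclass of `RemoteAutogradFunction` whose `forward`/`backward` are the wrapped remote calls defined above. A sketch with a toy function; only `remote_autograd_function` comes from this file, the function itself is illustrative:

import torch

from monarch.cached_remote_function import remote_autograd_function


class ScaleByAlpha(torch.autograd.Function):
    """Toy example: y = alpha * x, with a non-tensor attribute saved on ctx."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        ctx.alpha = alpha
        return x * alpha

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # gradient w.r.t. x only; alpha is a non-tensor argument
        return grad_out * ctx.alpha, None


# The returned class is used like any autograd.Function, but its forward/backward
# bodies execute on the workers as remote UDFs:
RemoteScaleByAlpha = remote_autograd_function(ScaleByAlpha, name="Remote_ScaleByAlpha")
# out = RemoteScaleByAlpha.apply(x, 2.0)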
monarch/code_sync.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from monarch._rust_bindings.monarch_extension.code_sync import (  # noqa: F401
+     RemoteWorkspace,
+     RsyncMeshClient,
+ )
monarch/common/_C.pyi ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ def patch_cuda() -> None: ...
+ def mock_cuda() -> None: ...
+ def unmock_cuda() -> None: ...
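
The stub only pins down the signatures of the three extension functions. Judging from `monarch/common/mock_cuda.py` and `tests/test_mock_cuda.py` in the file list above, they back a mock-CUDA mode; a guarded usage sketch follows, with the semantics of each call assumed from its name rather than documented in this stub:

from monarch.common._C import mock_cuda, patch_cuda, unmock_cuda

patch_cuda()   # assumed: install the CUDA patching hooks once per process
mock_cuda()    # assumed: start intercepting CUDA calls
try:
    pass       # run code that would otherwise require a real GPU
finally:
    unmock_cuda()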
monarch/common/_C.so ADDED
Binary file (no diff shown)