torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/allocator.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import final
+
+ from monarch import ActorFuture as Future
+ from monarch._rust_bindings.hyperactor_extension.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     Alloc,
+     AllocSpec,
+ )
+
+ from monarch._rust_bindings.monarch_hyperactor.alloc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     LocalAllocatorBase,
+     ProcessAllocatorBase,
+ )
+
+
+ @final
+ class ProcessAllocator(ProcessAllocatorBase):
+     """
+     An allocator that allocates by spawning local processes.
+     """
+
+     def allocate(self, spec: AllocSpec) -> Future[Alloc]:
+         """
+         Allocate a process according to the provided spec.
+
+         Arguments:
+         - `spec`: The spec to allocate according to.
+
+         Returns:
+         - A future that will be fulfilled when the requested allocation is fulfilled.
+         """
+         return Future(
+             lambda: self.allocate_nonblocking(spec),
+             lambda: self.allocate_blocking(spec),
+         )
+
+
+ @final
+ class LocalAllocator(LocalAllocatorBase):
+     """
+     An allocator that allocates by spawning actors into the current process.
+     """
+
+     def allocate(self, spec: AllocSpec) -> Future[Alloc]:
+         """
+         Allocate a process according to the provided spec.
+
+         Arguments:
+         - `spec`: The spec to allocate according to.
+
+         Returns:
+         - A future that will be fulfilled when the requested allocation is fulfilled.
+         """
+         return Future(
+             lambda: self.allocate_nonblocking(spec),
+             lambda: self.allocate_blocking(spec),
+         )
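For orientation, a minimal usage sketch of the allocators above. Only `allocate()` and the `Future` wrapper appear in this diff; the `ProcessAllocator` constructor arguments, the `AllocSpec` keyword arguments, and the `Future.get()` call are assumptions made for illustration.

# Hypothetical usage sketch -- not part of the published package.
from monarch._rust_bindings.hyperactor_extension.alloc import AllocSpec
from monarch.allocator import ProcessAllocator

allocator = ProcessAllocator("python", ["-m", "monarch.bootstrap_main"])  # assumed constructor args
spec = AllocSpec(hosts=1, gpus=2)   # assumed keyword arguments
future = allocator.allocate(spec)   # returns a Future[Alloc], as shown above
alloc = future.get()                # assumed blocking accessor on the Future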
monarch/bootstrap_main.py ADDED
@@ -0,0 +1,75 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This is the main function for bootstrapping a new process using a ProcessAllocator.
+ """
+
+ import asyncio
+ import importlib.resources
+ import logging
+ import os
+ import sys
+
+ # Import torch to avoid import-time races if a spawned actor tries to import torch.
+ import torch  # noqa[F401]
+
+
+ async def main():
+     from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main
+
+     await bootstrap_main()
+
+
+ def invoke_main():
+     # If stdout is piped somewhere, print changes its buffering behavior,
+     # so default to the line-buffered behavior stdout has on a terminal.
+     sys.stdout.reconfigure(line_buffering=True)
+     global bootstrap_main
+     from monarch._rust_bindings.hyperactor_extension.telemetry import (  # @manual=//monarch/monarch_extension:monarch_extension
+         forward_to_tracing,
+     )
+
+     # TODO: figure out what from worker_main.py we should reproduce here.
+
+     class TracingForwarder(logging.Handler):
+         def emit(self, record: logging.LogRecord) -> None:
+             try:
+                 forward_to_tracing(
+                     record.getMessage(),
+                     record.filename or "",
+                     record.lineno or 0,
+                     record.levelno,
+                 )
+             except AttributeError:
+                 forward_to_tracing(
+                     record.__str__(),
+                     record.filename or "",
+                     record.lineno or 0,
+                     record.levelno,
+                 )
+
+     # Forward Python logs to Rust tracing. Defaults to on.
+     if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
+         logging.root.addHandler(TracingForwarder())
+
+     try:
+         with (
+             importlib.resources.path("monarch", "py-spy") as pyspy,
+         ):
+             if pyspy.exists():
+                 os.environ["PYSPY_BIN"] = str(pyspy)
+             # otherwise fall back to a locally installed py-spy
+     except Exception as e:
+         logging.warning(f"Failed to set up py-spy: {e}")
+
+     # Start an event loop for PythonActors to use.
+     asyncio.run(main())
+
+
+ if __name__ == "__main__":
+     invoke_main()  # pragma: no cover
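The TracingForwarder above follows a common pattern: a logging.Handler subclass that relays each record to an external sink. A standalone sketch of the same pattern, using a stand-in `forward` function in place of the Rust forward_to_tracing binding:

# Standalone sketch of the handler pattern used by TracingForwarder above.
# `forward` is a stand-in for the Rust forward_to_tracing binding.
import logging


def forward(message: str, filename: str, lineno: int, levelno: int) -> None:
    print(f"[{logging.getLevelName(levelno)}] {filename}:{lineno} {message}")


class ForwardingHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        # Render the record and hand it off to the external sink.
        forward(record.getMessage(), record.filename or "", record.lineno or 0, record.levelno)


logging.root.addHandler(ForwardingHandler())
logging.warning("forwarded through the handler")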
monarch/builtins/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ """
+ Monarch builtins are a set of remote function definitions for PyTorch functions and other utilities.
+ """
+
+ from .log import log_remote, set_logging_level_remote
+
+ __all__ = ["log_remote", "set_logging_level_remote"]
monarch/builtins/log.py ADDED
@@ -0,0 +1,22 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+
+ from monarch.common.remote import remote
+
+
+ logger = logging.getLogger(__name__)
+
+
+ @remote(propagate="inspect")
+ def log_remote(*args, level: int = logging.WARNING, **kwargs) -> None:
+     logger.log(level, *args, **kwargs)
+
+
+ @remote(propagate="inspect")
+ def set_logging_level_remote(level: int) -> None:
+     logger.setLevel(level)
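Both functions are declared with `propagate="inspect"`, so invoking them from the controller runs the log call on the workers rather than locally. A hedged usage sketch, assuming a Monarch device mesh has already been created and activated (not shown in this diff):

# Usage sketch, assuming an active Monarch device mesh (setup not shown here).
import logging
from monarch.builtins import log_remote, set_logging_level_remote

set_logging_level_remote(logging.INFO)                   # raise worker logger verbosity
log_remote("step %d finished", 10, level=logging.INFO)   # logged on each worker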
monarch/builtins/random.py ADDED
@@ -0,0 +1,69 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ from typing import Callable
+
+ import torch
+ from monarch.common.remote import remote
+
+
+ @remote(propagate="inspect")
+ def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
+     torch.manual_seed(seed ^ process_idx)
+
+
+ @remote(propagate=lambda: 0)
+ def initial_seed_remote() -> int:
+     return torch.initial_seed()
+
+
+ @remote(propagate=lambda: torch.zeros(1))
+ def get_rng_state_remote() -> torch.Tensor:
+     return torch.get_rng_state()
+
+
+ @remote(propagate="inspect")
+ def set_rng_state_remote(new_state: torch.Tensor) -> None:
+     torch.set_rng_state(new_state)
+
+
+ def _run_no_return(f: Callable) -> None:
+     f()
+     return None
+
+
+ # TODO: return the result once uint64 is supported as a remote function return type.
+ @remote(propagate=lambda: _run_no_return(torch.seed))
+ def seed_remote() -> None:
+     torch.seed()
+
+
+ # Same underlying implementation as seed_remote (torch.seed).
+ # TODO: return the result once uint64 is supported as a remote function return type.
+ @remote(propagate=lambda: _run_no_return(torch.random.seed))
+ def random_seed_remote() -> None:
+     torch.random.seed()
+
+
+ @remote(propagate="inspect")
+ def manual_seed_cuda_remote(seed: int) -> None:
+     torch.cuda.manual_seed(seed)
+
+
+ @remote(propagate="inspect")
+ def manual_seed_all_cuda_remote(seed: int) -> None:
+     torch.cuda.manual_seed_all(seed)
+
+
+ @remote(propagate=lambda: [torch.zeros(1)])
+ def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
+     return torch.cuda.get_rng_state_all()
+
+
+ @remote(propagate="inspect")
+ def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
+     torch.cuda.set_rng_state_all(states)
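The `propagate` argument in the decorators above tells the controller what a call produces without running the real body there: `"inspect"` derives it from the signature, while a lambda (e.g. `lambda: torch.zeros(1)`) supplies a stand-in return value for the controller while the real body runs on the workers. A sketch of a new remote definition following the same pattern; the function itself is illustrative and not part of the package:

# Illustrative definition mirroring the pattern above; not part of the package.
import torch
from monarch.common.remote import remote


@remote(propagate=lambda: torch.zeros(3))
def draw_gaussian_remote() -> torch.Tensor:
    # Runs on the workers; the controller only sees the stand-in value
    # produced by the propagate lambda.
    return torch.randn(3)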
monarch/cached_remote_function.py ADDED
@@ -0,0 +1,257 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import importlib
+ import logging
+
+ from contextlib import contextmanager
+ from typing import Dict, List, Optional, Type, Union
+
+ import torch
+ from monarch.common.process_group import SingleControllerProcessGroupWrapper
+
+ from monarch.common.remote import DummyProcessGroup, remote, RemoteProcessGroup
+
+ from torch import autograd
+ from torch.utils._pytree import tree_flatten, tree_unflatten
+
+ logger = logging.getLogger(__name__)
+
+
+ def _controller_autograd_function_forward(
+     autograd_function_class: Type[autograd.Function],
+ ):
+     """
+     Decorator for authoring a controller remote function wrapper around an autograd.Function forward.
+     Sets up an autograd.function.FunctionCtx() to send over the wire, and populates the original ctx
+     with the ctx_tensors and ctx attributes.
+     """
+
+     def decorator(func):
+         def wrapper(ctx, *args):
+             # Need a dummy context because autograd.FunctionBackward cannot be pickled.
+             wire_ctx = autograd.function.FunctionCtx()
+             # Track arg tensors that have requires_grad set.
+             arg_tensors, _ = tree_flatten(args)
+             wire_ctx.args_requires_grads = []
+             for i, arg in enumerate(arg_tensors):
+                 if isinstance(arg, torch.Tensor) and arg.requires_grad:
+                     wire_ctx.args_requires_grads.append(i)
+             out, ctx_attrs, ctx_tensors = func(
+                 autograd_function_class.__module__,
+                 autograd_function_class.__name__,
+                 wire_ctx,
+                 *args,
+             )
+             if ctx is None:
+                 return out
+             ctx.save_for_backward(*ctx_tensors)
+             ctx.attr_names = ctx_attrs.keys()
+             ctx.pg_names = []
+             dim_to_remote_group = {}
+             for arg in args:
+                 if isinstance(arg, RemoteProcessGroup):
+                     dim_to_remote_group[arg.dims] = arg
+             for name, v in ctx_attrs.items():
+                 if isinstance(v, DummyProcessGroup):
+                     setattr(ctx, name, dim_to_remote_group[v.dims])
+                     ctx.pg_names.append(name)
+                 else:
+                     setattr(ctx, name, v)
+
+             return out
+
+         return wrapper
+
+     return decorator
+
+
+ def _controller_autograd_function_backward(
+     autograd_function_class: Type[autograd.Function],
+ ):
+     """
+     Decorator for authoring a controller remote function wrapper around an autograd.Function backward.
+     Manually sets up wire_ctx with the ctx tensors and attributes.
+     """
+
+     def decorator(func):
+         def wrapper(ctx, *grad_outputs):
+             # Manually set up wire_ctx with ctx tensors and attributes.
+             wire_ctx = autograd.function.FunctionCtx()
+             # Send over tensor references with ctx_tensors.
+             ctx_tensors = ctx.saved_tensors
+             wire_ctx.save_for_backward(ctx_tensors)
+             for name in ctx.attr_names:
+                 setattr(wire_ctx, name, getattr(ctx, name))
+             process_groups = {name: getattr(ctx, name) for name in ctx.pg_names}
+
+             return func(
+                 autograd_function_class.__module__,
+                 autograd_function_class.__name__,
+                 wire_ctx,
+                 ctx_tensors,
+                 # explicitly pass process groups to the worker
+                 process_groups,
+                 *grad_outputs,
+             )
+
+         return wrapper
+
+     return decorator
+
+
+ @contextmanager
+ def manage_grads(list_of_tensors, indices):
+     try:
+         for i in indices:
+             assert list_of_tensors[i].is_leaf, "can't have non-leaf tensors on worker"
+             list_of_tensors[i].requires_grad = True
+         yield list_of_tensors
+     finally:
+         for i in indices:
+             list_of_tensors[i].requires_grad = False
+
+
+ def worker_autograd_function_forward(
+     module_name: str,
+     class_name: str,
+     ctx: autograd.function.FunctionCtx,
+     *args,
+     **kwargs,
+ ):
+     # Capture the initial state of ctx attributes.
+     before = set()
+     before.add("to_save")
+     for attr in dir(ctx):
+         if not attr.startswith("_"):
+             before.add(attr)
+
+     # Set tensors that require grad from the additional arg.
+     flatten_args, spec = tree_flatten(args)
+     # pyre-ignore
+     with manage_grads(flatten_args, ctx.args_requires_grads) as args_with_grad:
+         args = tree_unflatten(args_with_grad, spec)
+
+         # Call the original forward function.
+         module = importlib.import_module(module_name)
+         class_ = getattr(module, class_name)
+         with torch.no_grad():
+             out = class_.forward(ctx, *args, **kwargs)
+
+         # Capture the state of ctx attributes after the function call.
+         after = set()
+         for attr in dir(ctx):
+             if not attr.startswith("_"):
+                 after.add(attr)
+         ctx_attrs = {attr: getattr(ctx, attr) for attr in after - before}
+         ctx_attrs["ctx_requires_grads"] = []
+
+         if not hasattr(ctx, "to_save"):
+             to_save = []
+         else:
+             # pyre-ignore
+             for idx, t in enumerate(ctx.to_save):
+                 # Generally, workers should not have requires_grad set. Restore the correct
+                 # state afterwards, but record requires_grad for the next forward.
+                 if isinstance(t, torch.Tensor) and t.requires_grad and t.is_leaf:
+                     t.requires_grad = False
+                     ctx_attrs["ctx_requires_grads"].append(idx)
+             to_save = ctx.to_save
+     return out, ctx_attrs, to_save
+
+
+ def worker_autograd_function_backward(
+     module_name: str,
+     class_name: str,
+     ctx: autograd.function.FunctionCtx,
+     ctx_tensors: List[torch.Tensor],
+     process_groups: Dict[
+         str, Union[SingleControllerProcessGroupWrapper, DummyProcessGroup]
+     ],
+     *grad_outputs: torch.Tensor,
+ ):
+     # Set the correct requires_grad state before backward.
+     # pyre-ignore
+     with manage_grads(ctx_tensors, ctx.ctx_requires_grads) as ctx_grad_tensors:
+         # for i in ctx.ctx_requires_grads:
+         #     ctx_tensors[i].requires_grad = True
+         if ctx_grad_tensors:
+             # pyre-ignore
+             ctx.saved_tensors = ctx_grad_tensors
+         for name, v in process_groups.items():
+             setattr(ctx, name, v)
+         # Call the original backward function.
+         module = importlib.import_module(module_name)
+         class_ = getattr(module, class_name)
+         with torch.no_grad():
+             out = class_.backward(ctx, *grad_outputs)
+     return out
+
+
+ forward_remote_fn = remote(
+     "monarch.cached_remote_function.worker_autograd_function_forward"
+ )
+
+ backward_remote_fn = remote(
+     "monarch.cached_remote_function.worker_autograd_function_backward"
+ )
+
+
+ class RemoteAutogradFunction(autograd.Function):
+     """
+     A new autograd.Function (custom forward/backward) that will run on the worker as a UDF RemoteFunction.
+
+     Example::
+
+         my_remote_autograd_function = remote_autograd_function(my_custom_autograd_function)
+     """
+
+     @staticmethod
+     def forward(ctx, *args):
+         raise NotImplementedError()
+
+     @staticmethod
+     def backward(ctx, *grads):
+         raise NotImplementedError()
+
+
+ def remote_autograd_function(
+     target_class: Type[autograd.Function], name: Optional[str] = None
+ ) -> Type[RemoteAutogradFunction]:
+     """
+     Returns a new autograd.Function (custom forward/backward) that will run on the worker as a UDF RemoteFunction.
+     Logic is done on the controller (e.g., DTensors are set up and saved for backward).
+     The autograd.function.FunctionCtx() is sent over the wire to the worker.
+     Special handling is done for ctx_tensors, requires_grad of tensors, and process groups.
+
+     Args:
+         target_class: autograd.Function class to be run remotely
+         name: name of the new autograd.Function to be called on the worker
+     """
+     if issubclass(target_class, RemoteAutogradFunction):
+         logging.warning(
+             f"{target_class} is already an autograd.Function UDF! You are likely monkey-patching too many times"
+         )
+         return target_class
+     assert issubclass(
+         target_class, autograd.Function
+     ), f"{target_class} is not a torch.autograd.Function!"
+     if name is None:
+         name = f"Remote_{target_class.__name__}"
+
+     return type(
+         name,
+         (RemoteAutogradFunction,),
+         {
+             "forward": staticmethod(
+                 _controller_autograd_function_forward(target_class)(forward_remote_fn)
+             ),
+             "backward": staticmethod(
+                 _controller_autograd_function_backward(target_class)(backward_remote_fn)
+             ),
+         },
+     )
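As the docstring's example indicates, `remote_autograd_function` wraps an existing `autograd.Function` so that its forward and backward run on the workers. A sketch with a toy function; the `Scale2` class is illustrative, and an active Monarch controller/device mesh is assumed (setup not shown in this diff):

# Illustrative wrapping of a toy autograd.Function; assumes an active Monarch
# controller/device mesh.
import torch
from torch import autograd
from monarch.cached_remote_function import remote_autograd_function


class Scale2(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2


RemoteScale2 = remote_autograd_function(Scale2)  # forward/backward now run as worker UDFs
# y = RemoteScale2.apply(x)  # used like any autograd.Function on mesh tensors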
monarch/common/_C.pyi ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+
+ def patch_cuda() -> None: ...
+ def mock_cuda() -> None: ...
+ def unmock_cuda() -> None: ...
monarch/common/_C.so ADDED
Binary file
File without changes