torchmonarch-nightly 2025.6.27 (torchmonarch_nightly-2025.6.27-cp312-cp312-manylinux2014_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/opaque_module.py
@@ -0,0 +1,235 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import List, Optional
+
+ import torch
+ from monarch.common.function_caching import TensorGroup, TensorGroupPattern
+ from monarch.common.opaque_ref import OpaqueRef
+ from monarch.common.remote import remote
+ from monarch.common.tensor_factory import TensorFactory
+ from monarch.common.tree import flatten
+ from monarch.opaque_object import _fresh_opaque_ref, OpaqueObject
+ from torch.autograd.graph import get_gradient_edge
+
+
+ def _get_parameters_shape(module: OpaqueRef) -> TensorGroupPattern:
+     the_module: torch.nn.Module = module.value
+     group = TensorGroup(list(the_module.parameters()))
+     return group.pattern
+
+
+ def _get_parameters(module: OpaqueRef) -> List[torch.Tensor]:
+     # XXX - we do not want worker tensor refs to have requires_grad on,
+     # because then any compute will create a backward graph
+     # which will never get used.
+     # This should be enforced at the worker level, but I think we are
+     # hijacking the requires_grad bit to communicate information in
+     # the autograd controller wrapper. We need to use a different
+     # side-channel to do that.
+     return [p.detach() for p in module.value.parameters()]
+
+
+ def _remote_forward(require_grads: List[bool], module: OpaqueRef, args, kwargs):
+     # forward on the worker
+
+     # the parameter tensors inside the module will be require_grad_(True),
+     # but the input worker tensors, like all worker state, do not have
+     # autograd recording on. We have to turn it on inside just this function
+     # to do an autograd pass.
+
+     # parameters has to match what _get_parameters returns for this to work.
+     parameters = list(module.value.parameters())
+     all_inputs, unflatten_inputs = flatten(
+         (args, kwargs, parameters), lambda x: isinstance(x, torch.Tensor)
+     )
+     # set requires_grad on inputs. We skip the parameters because they
+     # will already have requires_grad set, and we can't detach them
+     # here otherwise grad won't flow to them.
+     for i in range(len(all_inputs) - len(parameters)):
+         if require_grads[i]:
+             all_inputs[i] = all_inputs[i].detach().requires_grad_(True)
+
+     # we have to create these just in case the module doesn't actually
+     # _use_ a parameter, in which case we have to create the zero.
+     # we can't really tell a priori if it will be used or not.
+     input_factories = [TensorFactory.from_tensor(t) for t in all_inputs]
+
+     # we have to be careful to save just the autograd graph edges and not
+     # the input/output tensors. Otherwise we might keep them longer than they
+     # are truly needed.
+     all_inputs_require_grad_edges = [
+         get_gradient_edge(input) for input, rg in zip(all_inputs, require_grads) if rg
+     ]
+
+     args, kwargs, _ = unflatten_inputs(all_inputs)
+
+     # the real module gets called here.
+     result = module.value(*args, **kwargs)
+
+     all_outputs_requires_grad, unflatten_outputs = flatten(
+         result, lambda x: isinstance(x, torch.Tensor) and x.requires_grad
+     )
+
+     all_output_edges = [
+         get_gradient_edge(output) for output in all_outputs_requires_grad
+     ]
+
+     # this backward closure keeps the state around to invoke backward
+     # and is held as the OpaqueRef we return to the controller.
+     def backward(all_grad_outputs: List[torch.Tensor]):
+         # careful, do not capture any input/output tensors.
+         # they might not be required for gradient, and will waste memory.
+         with torch.no_grad():
+             grad_inputs = torch.autograd.grad(
+                 inputs=all_inputs_require_grad_edges,
+                 outputs=all_output_edges,
+                 grad_outputs=all_grad_outputs,
+                 allow_unused=True,
+             )
+         grad_inputs_iter = iter(grad_inputs)
+         all_grad_inputs = [
+             next(grad_inputs_iter) if rg else None for rg in require_grads
+         ]
+         for i, rg in enumerate(require_grads):
+             # if the grad turned out unused we have to make a zero tensor here
+             # because the controller is expecting tensors, not None.
+             if rg and all_grad_inputs[i] is None:
+                 all_grad_inputs[i] = input_factories[i].zeros()
+         return all_grad_inputs
+
+     # detach outputs, because worker tensors do not keep gradient state;
+     # the only gradient state on the worker is localized to the backward closure.
+     result = unflatten_outputs(t.detach() for t in all_outputs_requires_grad)
+     return OpaqueRef(backward), result
+
+
+ def _remote_backward(backward_closure: OpaqueRef, all_grad_outputs: List[torch.Tensor]):
+     # this is just a small trampoline that calls the closure that forward defined.
+     return backward_closure.value(all_grad_outputs)
+
+
+ class OpaqueModule:
+     """
+     Provides an _unsafe_ wrapper around a stateful module object that lives on a remote mesh.
+
+         linear = OpaqueModule("torch.nn.Linear", 3, 3, device="cuda")
+         output = linear(input, propagator=lambda self, x: x.clone())
+         r = output.sum()
+         with torch.no_grad():
+             r.backward()
+
+     It supports:
+
+     * Accessing parameters of the module on the controller via m.parameters(), which will
+       use remote functions to figure out the shape of the parameters and get a reference to them.
+     * Invoking the forward of the module by providing inputs and a manual shape propagation function:
+           m(input, propagator=lambda self, x: x.clone())
+       Trying to do a cached function in this situation is very tricky because of the boundaries
+       between autograd/no-autograd, so it is not implemented yet.
+     * Calculating gradients through the module invocation as if this module was a normal controller module.
+
+     In the future we should consider whether we want this to actually be a subclass of torch.nn.Module,
+     such that it could have hooks and other features. If we do this, we need to implement most of
+     the existing torch.nn.Module API so that it behaves in the expected way.
+
+     """
+
+     def __init__(self, *args, **kwargs):
+         self._object = OpaqueObject(*args, **kwargs)
+         self._parameters: Optional[List[torch.Tensor]] = None
+
+     def parameters(self):
+         if self._parameters is None:
+             tensor_group_pattern = (
+                 remote(_get_parameters_shape)
+                 .call_on_shard_and_fetch(self._object)
+                 .result()
+             )
+             self._parameters = [
+                 p.requires_grad_(True)
+                 for p in remote(
+                     _get_parameters,
+                     propagate=lambda self: tensor_group_pattern.empty([]),
+                 )(self._object)
+             ]
+
+         return self._parameters
+
+     def call_method(self, *args, **kwargs):
+         return self._object.call_method(*args, **kwargs)
+
+     def __call__(self, *args, propagator, **kwargs):
+         parameters = self.parameters()
+         # torch.autograd.Function only supports flat lists of input/output tensors,
+         # so we have to do a bunch of flattening/unflattening to call it.
+         all_inputs, unflatten_inputs = flatten(
+             (args, kwargs, parameters), lambda x: isinstance(x, torch.Tensor)
+         )
+
+         # the worker will need to understand which gradients to calculate,
+         # which we pass in as a flag array here.
+         requires_grad = [t.requires_grad for t in all_inputs]
+         if not sum(requires_grad):
+             # early exit if we do not have gradients (including toward the parameters)
+             return self._object.call_method("__call__", propagator, *args, **kwargs)
+
+         # these will be used to describe the shape of gradients to the inputs,
+         # so we cannot use TensorGroup to recover alias information. Having
+         # gradient tensors that alias each other coming out of one of these functions
+         # will break things.
+         input_factories = [TensorFactory.from_tensor(i) for i in all_inputs]
+
+         unflatten_outputs = None
+         backward_ctx = None
+
+         # we use this autograd function to define how to hook up the gradient
+         # calculated on the worker to the gradient graph _on the client_.
+
+         # This code runs entirely on the client.
+         class F(torch.autograd.Function):
+             @staticmethod
+             def forward(ctx, *all_inputs):
+                 nonlocal backward_ctx, unflatten_outputs
+                 args, kwargs, parameters = unflatten_inputs(all_inputs)
+                 # this remote call invokes the forward pass on the worker.
+                 # notice it returns the (non-gradient-recording) result of the
+                 # forward pass, and a backward_ctx opaque ref that we will
+                 # call in the backward pass to flow controller gradients
+                 # through the worker-saved autograd state. Holding
+                 # backward_ctx alive on the worker is what keeps
+                 # the worker autograd state alive. We should check there is
+                 # no funny business with class lifetimes.
+                 backward_ctx, result = remote(
+                     _remote_forward,
+                     propagate=lambda requires_grad, obj, args, kwargs: (
+                         _fresh_opaque_ref(),
+                         propagator(obj, *args, **kwargs),
+                     ),
+                 )(requires_grad, self._object, args, kwargs)
+
+                 flat_outputs, unflatten_outputs = flatten(
+                     result, lambda x: isinstance(x, torch.Tensor)
+                 )
+                 return (*flat_outputs,)
+
+             @staticmethod
+             def backward(ctx, *all_grad_outputs):
+                 # this instructs the worker to propagate output grads back to our input
+                 # grads; all_grad_inputs has to match all_inputs of forward.
+                 all_grad_inputs = remote(
+                     _remote_backward,
+                     propagate=lambda _ctx, _all_grad_outputs: tuple(
+                         f.empty() if rg else None
+                         for f, rg in zip(input_factories, requires_grad)
+                     ),
+                 )(backward_ctx, all_grad_outputs)
+                 return all_grad_inputs
+
+         # apply unwraps the gradient tensors and inserts our custom block.
+         flat_outputs = F.apply(*all_inputs)
+         result = unflatten_outputs(flat_outputs)
+         return result
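
For orientation, here is a minimal controller-side sketch of how OpaqueModule is meant to be driven, following the class docstring above. It assumes a monarch device mesh is already active so that the tensors created below live on the workers and "torch.nn.Linear" is resolvable there; the input shape and the clone-based propagator are illustrative assumptions, not part of the package.

import torch
from monarch.opaque_module import OpaqueModule

# Assumption: a monarch device mesh is already active on the controller,
# so the tensors below are mesh tensors. Shapes are illustrative.
linear = OpaqueModule("torch.nn.Linear", 3, 3, device="cuda")
x = torch.rand(4, 3, device="cuda", requires_grad=True)

# The propagator describes the shape of the remote forward's output without
# running the real module on the controller; Linear(3, 3) maps (4, 3) -> (4, 3),
# so a clone of the input has the right shape.
out = linear(x, propagator=lambda obj, t: t.clone())

r = out.sum()
with torch.no_grad():  # mirrors the class docstring above
    r.backward()

# linear.parameters() returns controller-side references to the worker
# parameters; their gradients are produced by the remote backward closure.
grads = [p.grad for p in linear.parameters()]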
monarch/opaque_object.py
@@ -0,0 +1,88 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import functools
+
+ import torch
+ from monarch.common.function import (
+     ConvertsToResolvable,
+     resolvable_function,
+     ResolvableFunction,
+ )
+
+ from monarch.common.opaque_ref import OpaqueRef
+ from monarch.common.remote import remote
+
+
+ def _invoke_method(obj: OpaqueRef, method_name: str, *args, **kwargs):
+     return getattr(obj.value, method_name)(*args, **kwargs)
+
+
+ def _fresh_opaque_ref():
+     return OpaqueRef(torch.zeros(0, dtype=torch.int64))
+
+
+ @remote(propagate=lambda *args, **kwargs: _fresh_opaque_ref())
+ def _construct_object(
+     constructor_resolver: ResolvableFunction, *args, **kwargs
+ ) -> OpaqueRef:
+     constructor = constructor_resolver.resolve()
+     return OpaqueRef(constructor(*args, **kwargs))
+
+
+ def opaque_method(fn):
+     method_name = fn.__name__
+
+     @functools.wraps(fn)
+     def impl(self, *args, **kwargs):
+         return self.call_method(method_name, fn, *args, **kwargs)
+
+     return impl
+
+
+ class OpaqueObject(OpaqueRef):
+     """
+     Provides syntax sugar for working with OpaqueRef objects on the controller.
+
+         class MyWrapperObject(OpaqueObject):
+
+             # Declare that the object has an a_remote_add method.
+             # The definition provides the shape propagation rule.
+             @opaque_method
+             def a_remote_add(self, t: torch.Tensor):
+                 return t + t
+
+         # on the controller you can now create the wrapper
+         obj: MyWrapperObject = MyWrapperObject.construct("path.to.worker.constructor", torch.rand(3, 4))
+
+         # and call its methods
+         t: monarch.Tensor = obj.a_remote_add(torch.rand(3, 4))
+
+     This interface can be used to build (unsafe) wrappers around stateful things such as torch.nn.Modules
+     in order to make porting them to monarch-first structures easier.
+     """
+
+     def __init__(self, constructor: ConvertsToResolvable | OpaqueRef, *args, **kwargs):
+         if isinstance(constructor, OpaqueRef):
+             self._key = constructor._key
+         else:
+             self._key = _construct_object(
+                 resolvable_function(constructor), *args, **kwargs
+             )._key
+
+     def call_method(self, method_name, propagation, *args, **kwargs):
+         endpoint = remote(
+             _invoke_method,
+             propagate=lambda self, method_name, *args, **kwargs: propagation(
+                 self, *args, **kwargs
+             ),
+         )
+         return endpoint(self, method_name, *args, **kwargs)
+
+     def call_method_on_shard_and_fetch(self, method_name, *args, **kwargs):
+         return remote(_invoke_method).call_on_shard_and_fetch(
+             self, method_name, *args, **kwargs
+         )
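
A short sketch of the wrapper pattern OpaqueObject is meant to enable, adapted from the docstring above. The LinearRef wrapper and the chosen shapes are illustrative assumptions; only OpaqueObject, opaque_method, call_method_on_shard_and_fetch, and the idea of naming a worker-resolvable constructor ("torch.nn.Linear") come from the code itself, and it likewise assumes an active device mesh.

import torch
from monarch.opaque_object import OpaqueObject, opaque_method


class LinearRef(OpaqueObject):
    # The decorated body is only the shape-propagation rule, run on the
    # controller; the real nn.Linear.forward runs on the worker via
    # _invoke_method. Linear(3, 3) maps (N, 3) -> (N, 3), so a clone suffices.
    @opaque_method
    def forward(self, x: torch.Tensor):
        return x.clone()


# With a device mesh active, name a constructor that is importable on the
# workers and pass its arguments; the constructed object never leaves the worker.
lin = LinearRef("torch.nn.Linear", 3, 3)
y = lin.forward(torch.rand(4, 3))

# To pull a single concrete result back to the controller, the shard-and-fetch
# variant can be used and resolved with .result(), as OpaqueModule.parameters()
# does above.
fetched = lin.call_method_on_shard_and_fetch("forward", torch.rand(4, 3))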
monarch/parallel/__init__.py
@@ -0,0 +1,9 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from monarch.parallel.pipelining.runtime import get_parameter_udf, PipelineParallelism
+
+ __all__ = ["PipelineParallelism", "get_parameter_udf"]
@@ -0,0 +1,7 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict