torchmonarch-nightly 2025.6.27-cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/context_manager.py ADDED
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from functools import wraps
+
+
+class _ContextManager:
+    def __init__(self, generator):
+        self.generator = generator
+        self.generator.send(None)
+
+    def __enter__(self):
+        return
+
+    def __exit__(self, *args):
+        try:
+            self.generator.send(None)
+        except StopIteration:
+            pass
+        else:
+            raise RuntimeError("context manager generator did not exit")
+
+
+def activate_first_context_manager(func):
+    """
+    Similar to contextlib.contextmanager but it
+    starts the context when the function is called rather
+    than at the start of the with statement. Useful for things where
+    you want to optionally activate the context without a guard.
+    """
+
+    @wraps(func)
+    def helper(*args, **kwargs):
+        return _ContextManager(func(*args, **kwargs))
+
+    return helper
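
For context: activate_first_context_manager wraps a generator function so that everything before its yield runs as soon as the function is called, while the with statement only controls when the code after the yield runs. A minimal usage sketch (illustrative only; the _tracing generator below is hypothetical and not part of the package):

    from monarch.common.context_manager import activate_first_context_manager

    @activate_first_context_manager
    def _tracing(label):  # hypothetical example generator
        print(f"start {label}")  # runs immediately when _tracing(...) is called
        yield
        print(f"stop {label}")  # runs when the returned object's __exit__ fires

    cm = _tracing("step")  # prints "start step" here, before any `with` block
    with cm:               # entering is a no-op; exiting resumes the generator
        pass               # prints "stop step" on exit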
monarch/common/controller_api.py ADDED
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import Any, List, NamedTuple, Optional, Protocol, Sequence, Union
+
+from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+    DebuggerMessage,
+    LogLevel,
+    WorldState,
+)
+
+from monarch.common.invocation import DeviceException, RemoteException, Seq
+from monarch.common.reference import Ref
+from monarch.common.shape import NDSlice
+from monarch.common.tensor import Tensor
+
+
+class LogMessage(NamedTuple):
+    level: LogLevel
+    message: str
+
+
+class MessageResult(NamedTuple):
+    """
+    Message result given a seq id of an invocation.
+    """
+
+    seq: Seq
+    result: Any
+    error: Optional[RemoteException | DeviceException] = None
+
+
+class TController(Protocol):
+    """
+    Controller APIs
+    """
+
+    # =======================================================
+    # === APIs for the client to call into the controller ===
+    # =======================================================
+
+    def send(
+        self,
+        ranks: Union[NDSlice, List[NDSlice]],
+        msg: NamedTuple,
+    ) -> None:
+        """
+        Send a message to a set of ranks.
+        """
+        ...
+
+    def drop_refs(self, refs: Sequence[Ref]) -> None:
+        """
+        Mark references as never being used again
+        """
+        ...
+
+    # TODO: there are a few things to do to clean up the API:
+    # 2. no need to depend on Tensors, a Referenceable; a Ref is enough.
+    # 3. support mutates as another input parameter.
+    def node(
+        self, seq: Seq, defs: Sequence["Tensor"], uses: Sequence["Tensor"]
+    ) -> None:
+        """
+        Create an invocation node given a sequence id. The node provides what tensors it defines,
+        what tensors it uses, and what tensors it mutates.
+        """
+        ...
+
+    # ==============================================================
+    # == APIs for the client to read response from the controller ==
+    # ==============================================================
+
+    # TODO: remove timeout parameter; instead, return a future that can wait on a timeout
+    def next_message(
+        self, timeout: Optional[float]
+    ) -> Optional[MessageResult | LogMessage]:
+        """
+        Read a message given a timeout in seconds. Returns a message output given the seq of an invocation.
+        The output could be the returned value or an exception.
+        If the returned message is None, it means there is no message to read within the given timeout.
+        If timeout is None, it means no timeout (infinite).
+        """
+        ...
+
+    def stop_mesh(self) -> None:
+        """Stop the system."""
+        ...
+
+    def drain_and_stop(self) -> List[MessageResult | LogMessage | DebuggerMessage]:
+        """Drain all the messages in the controller upon shutdown."""
+        ...
+
+    def worker_world_state(self) -> WorldState:
+        """
+        Retrieve the worker world state.
+
+        :return: The worker WorldState.
+        """
+        ...
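
Because TController is a typing.Protocol, any backend that structurally provides these methods satisfies it (the file list above shows several controller implementations). A rough sketch of a client-side receive loop over next_message, assuming some controller object implementing the protocol; the drain_results helper is hypothetical, not from the package:

    from monarch.common.controller_api import LogMessage, MessageResult

    def drain_results(controller, timeout: float = 1.0):
        """Poll the controller until no message arrives within `timeout` seconds."""
        results = {}
        while True:
            msg = controller.next_message(timeout)
            if msg is None:  # nothing arrived within the timeout window
                return results
            if isinstance(msg, LogMessage):
                print(f"[{msg.level}] {msg.message}")
            elif isinstance(msg, MessageResult):
                if msg.error is not None:
                    # RemoteException / DeviceException surfaced by the controller
                    raise RuntimeError(f"invocation {msg.seq} failed: {msg.error}")
                results[msg.seq] = msg.result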
monarch/common/device_mesh.py ADDED
@@ -0,0 +1,417 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import logging
+
+import warnings
+from contextlib import AbstractContextManager, contextmanager
+from dataclasses import dataclass
+from enum import Enum
+from logging import Logger
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
+
+import monarch.common.messages as messages
+import torch
+from monarch.common.shape import MeshTrait
+
+from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._pytree import tree_map
+
+from ._tensor_to_table import tensor_to_table
+from .context_manager import activate_first_context_manager
+from .messages import Dims
+from .reference import Referenceable
+from .shape import NDSlice, Shape
+from .stream import Stream
+from .tensor import MeshSliceTensor, Tensor
+
+if TYPE_CHECKING:
+    from monarch.common.client import Client
+
+logger: Logger = logging.getLogger(__name__)
+
+
+class RemoteProcessGroup(Referenceable):
+    """
+    Client's view of a process group.
+    """
+
+    def __init__(self, dims, device_mesh):
+        logger.info(f"creating process group for {dims}")
+        self.dims = dims
+        self.device_mesh = device_mesh
+        self.ref = self.device_mesh.client.new_ref()
+        self._create_remotely()
+        # A set of streams for which we've sent the split-comm message.
+        self._split_comm_done = set()
+
+    def _create_remotely(self):
+        msg = messages.CreateRemoteProcessGroup(self, self.device_mesh, self.dims)
+        self.device_mesh._send(msg)
+
+    def ensure_split_comm_remotely(self, stream):
+        """
+        If we haven't already, send a message to the worker to split off a
+        communicator for this PG on the given stream.
+        """
+
+        # Currently, the worker will error if we try to do the split-comm more
+        # than once, so check for that here to allow this function to be called
+        # lazily.
+        if stream in self._split_comm_done:
+            return
+        self._split_comm_done.add(stream)
+
+        msg = messages.SplitCommForProcessGroup(
+            remote_process_group=self,
+            stream=stream,
+        )
+        self.device_mesh.client.send_nocoalesce(
+            self.device_mesh.client.all_ranks,
+            msg,
+        )
+
+    def delete_ref(self, ref: int):
+        if not self.device_mesh.client.has_shutdown:
+            self.device_mesh.client.handle_deletes(self.device_mesh.processes, [ref])
+
+    def drop(self):
+        if self.ref is None:
+            return
+        self._drop_ref()
+
+    def size(self):
+        return self.device_mesh.size(self.dims)
+
+    def _drop_ref(self):
+        if self.ref is None:
+            return
+        self.delete_ref(self.ref)
+        self.ref = None
+
+    @property
+    def dropped(self):
+        return self.ref is None
+
+
+class ActivateGuard:
+    def __init__(self, iter):
+        self.iter = iter
+        next(iter)
+
+    def __enter__(self):
+        return
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        try:
+            next(self.iter)
+        except StopIteration:
+            pass
+
+
+class DeviceMeshStatus(Enum):
+    """
+    Enum representing the status of a device mesh.
+    Attributes:
+        LIVE (str): The mesh has at least as many processes as the specified world size and all of them are healthy.
+        UNHEALTHY (str): Either the mesh does not have enough processes or some of the processes are unhealthy.
+        AWAITING_CREATION (str): The mesh is still being created by the scheduler.
+    """
+
+    LIVE = "Live"
+    UNHEALTHY = "Unhealthy"
+    AWAITING_CREATION = "Awaiting Creation"
+
+
+@dataclass
+class DeviceMeshInfo:
+    """
+    Data class representing information about a device mesh.
+
+    Attributes:
+        mesh_labels (Dict[str, str]): Maps mesh labels to values.
+        devices_labels (List[Dict[str, str]]): Maps device labels to values.
+    """
+
+    mesh_labels: Dict[str, str]
+    devices_labels: List[Dict[str, str]]
+
+
+class DeviceMesh(Referenceable, MeshTrait):
+    def __init__(
+        self,
+        client: "Client",
+        processes: "NDSlice",
+        names: Dims,
+        mesh_name: str = "default",
+    ):
+        assert isinstance(processes, NDSlice)
+        self.client = client
+        assert processes.ndim == len(names)
+        self.names = names
+        self.mesh_name = mesh_name
+        # processes are a list of processes that participate in this device mesh, encoded as an NDSlice
+        self.processes = processes
+        self.exit = lambda: None
+        self.ref = None
+        self._active_mesh_context = None
+
+    def define_remotely(self):
+        if self.ref is None:
+            self.ref = self.client.new_ref()
+            msg = messages.CreateDeviceMesh(self, self.names, self.processes)
+            self.client.send(self.processes, msg)
+
+    def process_group(self, dims: str | Dims) -> RemoteProcessGroup:
+        self.define_remotely()
+        if isinstance(dims, str):
+            dims = (dims,)
+        return RemoteProcessGroup(dims, self)
+
+    def to_tensor(self):
+        with no_mesh.activate():
+            vals = torch.tensor(list(self.processes), device="cpu", dtype=torch.int)
+            return vals.view(self.processes.sizes)
+
+    def to_table(self):
+        with no_mesh.activate():
+            tensor = self.to_tensor()
+            names = list(self.names)
+            labels = [list(str(i) for i in range(i)) for i in tensor.shape]
+            gpus_per_host = self.client.gpu_per_host
+
+            def format_data(x):
+                return f"{x//gpus_per_host}.gpu[{x%gpus_per_host}]"
+
+            return tensor_to_table(
+                tensor, format_data=format_data, axis_names=names, axis_labels=labels
+            )
+
+    def __repr__(self):
+        return f"<DeviceMesh(names({self.names}), processes({list(self.processes)})) at {hex(id(self))}>"
+
+    def delete_ref(self, ref: int):
+        if not self.client.has_shutdown:
+            self.client.handle_deletes(self.processes, [ref])
+
+    def _send(self, cmd: NamedTuple):
+        self.client.flush_deletes()
+        self.client.send(self.processes, cmd)
+
+    def stack(self, **kwargs):
+        raise NotImplementedError()
+
+    @property
+    def _ndslice(self) -> NDSlice:
+        return self.processes
+
+    @property
+    def _labels(self) -> Tuple[str, ...]:
+        return self.names
+
+    def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
+        mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
+        mesh.exit = self.exit
+        return mesh
+
+    def __call__(self, **kwargs) -> "DeviceMesh":
+        """
+        device_mesh(batch=3) or device_mesh(batch=slice(3, None))
+        """
+        warnings.warn(
+            "The use of this method is deprecated. Please use mesh.slice instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.slice(**kwargs)
+
+    def rotate(self, **kwargs: Dict[str, int]):
+        raise NotImplementedError()
+
+    def rank(self, dims: Union[str, Sequence[str]]) -> torch.Tensor:
+        self.define_remotely()
+        if isinstance(dims, str):
+            if dims not in self.names:
+                raise KeyError(f"{self} does not have dimension {repr(dims)}")
+            return _remote(
+                _rank,
+                propagate=lambda _self, _dims: torch.full((), 0, dtype=torch.long),
+            )(self, dims)
+
+        combined_rank: Any = 0
+        for dim in dims:
+            combined_rank *= self.size(dim)
+            combined_rank += self.rank(dim)
+        return combined_rank
+
+    @property
+    def ranks(self) -> dict[str, torch.Tensor]:
+        return {dim: self.rank(dim) for dim in self.names}
+
+    def process_idx(self):
+        self.define_remotely()
+        return _remote(
+            "monarch.worker.worker._process_idx",
+            propagate=lambda _self: torch.full((), 0, dtype=torch.long),
+        )(self)
+
+    def _process(self, coordinates: Optional[Dict[str, int]]) -> NDSlice:
+        if coordinates is None:
+            return NDSlice(offset=self.processes.offset, sizes=[1], strides=[1])
+        if len(coordinates) > len(self.names):
+            extra = set(coordinates.keys()) - set(self.names)
+            raise KeyError(f"{list(extra)}")
+        for name in self.names:
+            if name not in coordinates:
+                raise ValueError(
+                    f"Missing key '{name}' in shard map. Need all of {self.names}"
+                )
+        flat = [coordinates[name] for name in self.names]
+        return NDSlice(offset=self.processes.nditem(flat), sizes=[1], strides=[1])
+
+    def activate(self) -> AbstractContextManager:
+        self._active_mesh_context = _active_mesh(self)
+        return self._active_mesh_context
+
+    def deactivate(self):
+        if self._active_mesh_context is not None:
+            self._active_mesh_context.__exit__(None, None, None)
+            self._active_mesh_context = None
+
+    def get_info(self) -> DeviceMeshInfo:
+        """
+        Retrieves metadata about the device mesh and its constituent devices.
+
+        Returns:
+            DeviceMeshInfo: Contains mesh-level labels and per-device labels.
+        """
+        mesh_state = self.client.mesh_state()
+
+        return DeviceMeshInfo(
+            mesh_labels=mesh_state.labels,
+            devices_labels=[proc.labels for proc in mesh_state.procs.values()],
+        )
+
+
+_active: Optional[DeviceMesh] = None
+_dispatch_enabled = False
+
+
+def get_active_mesh():
+    if _active is None:
+        raise ValueError("no device mesh is active")
+    return _active
+
+
+class _ActiveMesh(TorchDispatchMode):
+    ignore = ["profiler._record_function_exit._RecordFunction"]
+    allowed_local_accessors = ["aten._local_scalar_dense.default"]
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if _active is None:
+            return func(*args, **kwargs)
+        fnstr = str(func)
+        if fnstr in self.ignore:
+            return func(*args, **kwargs)
+        if fnstr in self.allowed_local_accessors and not isinstance(args[0], Tensor):
+            return func(*args, **kwargs)
+        return _remote(func, propagate=func)(*args, **kwargs)
+
+
+def _rank(mesh, dim):
+    return torch.full((), mesh.dims[dim].rank, dtype=torch.long)
+
+
+@contextmanager
+def _dispatch():
+    global _dispatch_enabled
+    if _dispatch_enabled:
+        yield
+    else:
+        _dispatch_enabled = True
+        try:
+            with _ActiveMesh():
+                yield
+        finally:
+            _dispatch_enabled = False
+
+
+_on_change: List[Callable] = []
+
+
+@activate_first_context_manager
+def _active_mesh(mesh: Optional[DeviceMesh]):
+    global _active
+    for on_change in _on_change:
+        on_change(_active, mesh)
+    _active, old = mesh, _active
+    try:
+        with _dispatch():
+            yield
+    finally:
+        for on_change in _on_change:
+            on_change(_active, old)
+        _active = old
+
+
+class _NoMesh:
+    def activate(self):
+        return _active_mesh(None)
+
+
+no_mesh = _NoMesh()
+
+
+def _remote(*args, **kwargs):
+    # device_mesh <-> tensor <-> remote are mutually recursive
+    # we break the dependency to allow for separate files by
+    # having device_mesh and tensor locally import the `remote`
+    # entrypoint
+    from monarch.common.remote import remote
+
+    return remote(*args, **kwargs)
+
+
+def to_mesh(
+    tensors: Any,
+    mesh: "DeviceMesh",
+    stream: Optional[Stream] = None,
+) -> Any:
+    """
+    Move all tensors in tensors to the given mesh.
+    """
+
+    def _to_mesh(tensor: Union["Tensor", "MeshSliceTensor"]) -> "Tensor":
+        return tensor.to_mesh(mesh, stream)
+
+    return tree_map(_to_mesh, tensors)
+
+
+def slice_mesh(
+    tensors: Any,
+    **kwargs: Union[int, slice],
+) -> Any:
+    """
+    Performs the slice_mesh operation for each tensor in tensors.
+    """

+    def _slice_mesh(tensor: "Tensor") -> "MeshSliceTensor":
+        return tensor.slice_mesh(**kwargs)
+
+    return tree_map(_slice_mesh, tensors)
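
A rough picture of how the pieces above fit together, as an illustrative sketch rather than package documentation: DeviceMesh.activate() installs _ActiveMesh, a TorchDispatchMode that reroutes ordinary torch calls through _remote(...) so they execute on the mesh's workers, and no_mesh.activate() temporarily suspends that interception. How a mesh is constructed depends on the backend (see the local/rust/sim mesh helpers in the file list); get_mesh() and the dimension name "host" below are placeholders:

    import torch
    from monarch.common.device_mesh import no_mesh, slice_mesh, to_mesh

    mesh = get_mesh()  # placeholder: backend-specific mesh construction
    with mesh.activate():  # torch ops below are dispatched to the mesh's workers
        x = torch.rand(4, 4, device="cuda")  # intercepted by _ActiveMesh.__torch_dispatch__
        host_rank = mesh.rank("host")  # per-process rank along the assumed "host" dimension
        with no_mesh.activate():  # run eagerly on the client for a moment
            local = torch.zeros(2)

    # Client-side handles can be regrouped across meshes/streams:
    sub = slice_mesh({"x": x}, host=0)  # MeshSliceTensor handles for a sub-mesh
    moved = to_mesh(sub, mesh)          # move/materialize them on a target mesh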
monarch/common/fake.py ADDED
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from concurrent.futures import ThreadPoolExecutor
+from functools import cache
+
+from torch._subclasses.fake_tensor import FakeTensorMode
+
+
+@cache
+def _fake_mode_worker():
+    return ThreadPoolExecutor(max_workers=1)
+
+
+@cache
+def _fake_mode():
+    return FakeTensorMode()
+
+
+def fake_call(fn, *args, **kwargs):
+    """Execute work on a ThreadPool worker
+
+    The first call (ThreadPoolExecutor init) will take the GIL and may block for a long time!
+    TODO: this will be replaced with something more performant
+    """
+    global _fake_mode_worker, fake_mode
+
+    # # Calls FakeTensorMode while re-enabling version counter tracking
+    # # todo(chilli): I'm not totally sure why I need to disable python dispatch
+    # # key. Perhaps there's some unwrapping that should have happened further up.
+    # include_to_set = torch._C._dispatch_tls_local_include_set()
+    # exclude_to_set = (
+    #     torch._C._dispatch_tls_local_exclude_set()
+    #     | torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
+    # ) - torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
+
+    # def work():
+    #     with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
+    #         with fake_mode:
+    #             return fn(*args, **kwargs)
+
+    # return work()
+
+    def work():
+        # fake mode must be initialized in the worker thread
+        # otherwise a monarch dispatch mode may be active, causing
+        # FakeTensorMode to initialize wrong.
+        with _fake_mode():
+            return fn(*args, **kwargs)
+
+    return _fake_mode_worker().submit(work).result()
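
For context, fake_call runs a function under a cached FakeTensorMode on a dedicated single-thread executor, so shape and dtype propagation can happen without real allocations and without whatever dispatch mode is active on the calling thread. A small illustrative sketch of the kind of call it supports (not from the package):

    import torch
    from monarch.common.fake import fake_call

    def make_fake_output():
        # created inside FakeTensorMode on the worker thread -> FakeTensors
        a = torch.empty(128, 64)
        b = torch.empty(64, 32)
        return a @ b  # metadata-only matmul

    out = fake_call(make_fake_output)
    print(out.shape, out.dtype)  # torch.Size([128, 32]) torch.float32
    # `out` is a FakeTensor: it carries shape/dtype/device metadata, not real storage.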