torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,40 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from functools import wraps
|
9
|
+
|
10
|
+
|
11
|
+
class _ContextManager:
    """
    Context manager driven by a generator that is primed at construction.

    The generator is advanced to its first ``yield`` inside ``__init__``, so
    the "enter" half of the context has already run before any ``with``
    statement is reached. ``__exit__`` resumes the generator once and
    requires it to finish.
    """

    def __init__(self, generator):
        self.generator = generator
        # Run the setup portion of the generator right away.
        self.generator.send(None)

    def __enter__(self):
        # Nothing left to do: the context was entered in __init__.
        return None

    def __exit__(self, *args):
        # The exception details in ``args`` are not forwarded; the generator is
        # simply resumed so that its teardown code runs.
        try:
            self.generator.send(None)
        except StopIteration:
            return
        raise RuntimeError("context manager generator did not exit")
|
26
|
+
|
27
|
+
|
28
|
+
def activate_first_context_manager(func):
    """
    Decorator similar to contextlib.contextmanager, except that the context
    is entered when the decorated function is called rather than at the
    start of the with statement. Useful when you want to optionally
    activate the context without a guard.

    ``func`` must be a generator function; the returned wrapper calls it and
    hands the resulting generator to ``_ContextManager``.
    """

    @wraps(func)
    def helper(*args, **kwargs):
        generator = func(*args, **kwargs)
        return _ContextManager(generator)

    return helper
|
@@ -0,0 +1,104 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from typing import Any, List, NamedTuple, Optional, Protocol, Sequence, Union
|
9
|
+
|
10
|
+
from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension
|
11
|
+
DebuggerMessage,
|
12
|
+
LogLevel,
|
13
|
+
WorldState,
|
14
|
+
)
|
15
|
+
|
16
|
+
from monarch.common.invocation import DeviceException, RemoteException, Seq
|
17
|
+
from monarch.common.reference import Ref
|
18
|
+
from monarch.common.shape import NDSlice
|
19
|
+
from monarch.common.tensor import Tensor
|
20
|
+
|
21
|
+
|
22
|
+
class LogMessage(NamedTuple):
    """
    A log record: a severity level paired with its text.
    """

    # Severity of the record.
    level: LogLevel
    # Text of the record.
    message: str
|
25
|
+
|
26
|
+
|
27
|
+
class MessageResult(NamedTuple):
    """
    Outcome of an invocation, identified by the invocation's seq id.
    """

    # Sequence id of the invocation this result belongs to.
    seq: Seq
    # Value produced by the invocation.
    result: Any
    # Exception reported for the invocation; None when no error is attached.
    error: Optional[RemoteException | DeviceException] = None
|
35
|
+
|
36
|
+
|
37
|
+
class TController(Protocol):
    """
    Structural interface of a controller as seen by the client.
    """

    # =======================================================
    # === APIs for the client to call into the controller ===
    # =======================================================

    def send(
        self,
        ranks: Union[NDSlice, List[NDSlice]],
        msg: NamedTuple,
    ) -> None:
        """
        Deliver ``msg`` to the set of ranks described by ``ranks``.
        """
        ...

    def drop_refs(self, refs: Sequence[Ref]) -> None:
        """
        Declare that the given references will never be used again.
        """
        ...

    # TODO: there are a few things to do to clean up the API:
    # 2. no need to depend on Tensors, a Referenceable; a Ref is enough.
    # 3. support mutates as another input parameter.
    def node(
        self, seq: Seq, defs: Sequence["Tensor"], uses: Sequence["Tensor"]
    ) -> None:
        """
        Register an invocation node for sequence id ``seq``, listing the
        tensors it defines (``defs``) and the tensors it uses (``uses``).
        Mutated tensors are not a parameter yet (see the TODO above).
        """
        ...

    # ==============================================================
    # == APIs for the client to read response from the controller ==
    # ==============================================================

    # TODO: remove timeout parameter; instead, return a future that can wait on a timeout
    def next_message(
        self, timeout: Optional[float]
    ) -> Optional[MessageResult | LogMessage]:
        """
        Read one message, waiting at most ``timeout`` seconds.

        The message carries the output for the seq of an invocation, which may
        be either a returned value or an exception. ``None`` is returned when
        no message arrived within the timeout. A ``timeout`` of ``None`` means
        wait without limit.
        """
        ...

    def stop_mesh(self) -> None:
        """Stop the system."""
        ...

    def drain_and_stop(self) -> List[MessageResult | LogMessage | DebuggerMessage]:
        """Drain every message still held by the controller upon shutdown."""
        ...

    def worker_world_state(self) -> WorldState:
        """
        Retrieve the worker world state.

        :return: The worker WorldState.
        """
        ...
|
@@ -0,0 +1,417 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import logging
|
10
|
+
|
11
|
+
import warnings
|
12
|
+
from contextlib import AbstractContextManager, contextmanager
|
13
|
+
from dataclasses import dataclass
|
14
|
+
from enum import Enum
|
15
|
+
from logging import Logger
|
16
|
+
from typing import (
|
17
|
+
Any,
|
18
|
+
Callable,
|
19
|
+
Dict,
|
20
|
+
List,
|
21
|
+
NamedTuple,
|
22
|
+
Optional,
|
23
|
+
Sequence,
|
24
|
+
Tuple,
|
25
|
+
TYPE_CHECKING,
|
26
|
+
Union,
|
27
|
+
)
|
28
|
+
|
29
|
+
import monarch.common.messages as messages
|
30
|
+
import torch
|
31
|
+
from monarch.common.shape import MeshTrait
|
32
|
+
|
33
|
+
from torch.utils._python_dispatch import TorchDispatchMode
|
34
|
+
from torch.utils._pytree import tree_map
|
35
|
+
|
36
|
+
from ._tensor_to_table import tensor_to_table
|
37
|
+
from .context_manager import activate_first_context_manager
|
38
|
+
from .messages import Dims
|
39
|
+
from .reference import Referenceable
|
40
|
+
from .shape import NDSlice, Shape
|
41
|
+
from .stream import Stream
|
42
|
+
from .tensor import MeshSliceTensor, Tensor
|
43
|
+
|
44
|
+
if TYPE_CHECKING:
|
45
|
+
from monarch.common.client import Client
|
46
|
+
|
47
|
+
logger: Logger = logging.getLogger(__name__)


class RemoteProcessGroup(Referenceable):
    """
    Client-side handle for a process group spanning ``dims`` of a device mesh.

    Constructing an instance allocates a ref from the mesh's client and sends
    a ``CreateRemoteProcessGroup`` message through the mesh.
    """

    def __init__(self, dims, device_mesh):
        logger.info(f"creating process group for {dims}")
        self.dims = dims
        self.device_mesh = device_mesh
        self.ref = self.device_mesh.client.new_ref()
        self._create_remotely()
        # Streams for which the split-comm message has already been sent.
        self._split_comm_done = set()

    def _create_remotely(self):
        """Send the message that creates this process group remotely."""
        create_msg = messages.CreateRemoteProcessGroup(
            self, self.device_mesh, self.dims
        )
        self.device_mesh._send(create_msg)

    def ensure_split_comm_remotely(self, stream):
        """
        Ask the workers, at most once per stream, to split off a communicator
        for this process group on ``stream``.
        """
        # The worker currently errors if the split-comm is requested more than
        # once, so remember which streams were handled; this lets callers
        # invoke the method lazily.
        if stream in self._split_comm_done:
            return
        self._split_comm_done.add(stream)

        split_msg = messages.SplitCommForProcessGroup(
            remote_process_group=self,
            stream=stream,
        )
        client = self.device_mesh.client
        client.send_nocoalesce(client.all_ranks, split_msg)

    def delete_ref(self, ref: int):
        """Forward deletion of ``ref`` to the client unless it has shut down."""
        client = self.device_mesh.client
        if client.has_shutdown:
            return
        client.handle_deletes(self.device_mesh.processes, [ref])

    def drop(self):
        """Release this group's ref; does nothing if it was already dropped."""
        if self.ref is not None:
            self._drop_ref()

    def size(self):
        """Return the device mesh's size over this group's dims."""
        return self.device_mesh.size(self.dims)

    def _drop_ref(self):
        """Delete the ref (if still held) and mark this group as dropped."""
        current = self.ref
        if current is None:
            return
        self.delete_ref(current)
        self.ref = None

    @property
    def dropped(self):
        """True once the ref has been released."""
        return self.ref is None
|
111
|
+
|
112
|
+
|
113
|
+
class ActivateGuard:
    """
    Guard around an iterator that is advanced once on construction and once
    more when the guard exits.

    Unlike a strict context manager, exiting does not complain if the
    iterator still has items left after the second step.
    """

    def __init__(self, iter):
        self.iter = iter
        # First step runs immediately, before any ``with`` block is entered.
        next(self.iter)

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Second step; exhaustion of the iterator is expected and ignored.
        next(self.iter, None)
|
126
|
+
|
127
|
+
|
128
|
+
class DeviceMeshStatus(Enum):
    """
    Status of a device mesh.

    Attributes:
        LIVE (str): The mesh has at least as many processes as the specified
            world size and all of them are healthy.
        UNHEALTHY (str): Either the mesh does not have enough processes or
            some of the processes are unhealthy.
        AWAITING_CREATION (str): The mesh is still being created by the
            scheduler.
    """

    LIVE = "Live"
    UNHEALTHY = "Unhealthy"
    AWAITING_CREATION = "Awaiting Creation"
|
140
|
+
|
141
|
+
|
142
|
+
@dataclass
class DeviceMeshInfo:
    """
    Information about a device mesh.

    Attributes:
        mesh_labels (Dict[str, str]): Maps mesh labels to values.
        devices_labels (List[Dict[str, str]]): One mapping of device labels
            to values per device.
    """

    # Labels attached to the mesh as a whole.
    mesh_labels: Dict[str, str]
    # Labels attached to each individual device.
    devices_labels: List[Dict[str, str]]
|
154
|
+
|
155
|
+
|
156
|
+
class DeviceMesh(Referenceable, MeshTrait):
    """
    Client-side handle for a named, N-dimensional arrangement of processes.

    The participating processes are stored as an ``NDSlice`` and each
    dimension carries a name from ``names``. The mesh is only defined on the
    workers lazily, the first time ``define_remotely`` runs.
    """

    def __init__(
        self,
        client: "Client",
        processes: "NDSlice",
        names: Dims,
        mesh_name: str = "default",
    ):
        assert isinstance(processes, NDSlice)
        self.client = client
        # One name is required per dimension of the slice.
        assert processes.ndim == len(names)
        self.names = names
        self.mesh_name = mesh_name
        # The processes participating in this device mesh, encoded as an NDSlice.
        self.processes = processes
        # No-op by default; _new_with_shape copies this hook to derived meshes.
        self.exit = lambda: None
        # Stays None until the mesh has been defined remotely.
        self.ref = None
        self._active_mesh_context = None

    def define_remotely(self):
        """Allocate a ref and send ``CreateDeviceMesh``, only on first use."""
        if self.ref is not None:
            return
        self.ref = self.client.new_ref()
        create_msg = messages.CreateDeviceMesh(self, self.names, self.processes)
        self.client.send(self.processes, create_msg)

    def process_group(self, dims: str | Dims) -> RemoteProcessGroup:
        """Create a process group over ``dims``; a single name is wrapped in a tuple."""
        self.define_remotely()
        if isinstance(dims, str):
            dims = (dims,)
        return RemoteProcessGroup(dims, self)

    def to_tensor(self):
        """Return the process ranks as a CPU int tensor shaped like the mesh."""
        # Built with no mesh active.
        with no_mesh.activate():
            flat_ranks = torch.tensor(
                list(self.processes), device="cpu", dtype=torch.int
            )
            return flat_ranks.view(self.processes.sizes)

    def to_table(self):
        """Render the mesh through ``tensor_to_table`` with each rank shown as ``host.gpu[index]``."""
        with no_mesh.activate():
            tensor = self.to_tensor()
            axis_names = list(self.names)
            # One list of stringified indices per axis.
            axis_labels = [
                [str(index) for index in range(extent)] for extent in tensor.shape
            ]
            gpus_per_host = self.client.gpu_per_host

            def format_data(x):
                return f"{x//gpus_per_host}.gpu[{x%gpus_per_host}]"

            return tensor_to_table(
                tensor,
                format_data=format_data,
                axis_names=axis_names,
                axis_labels=axis_labels,
            )

    def __repr__(self):
        return f"<DeviceMesh(names({self.names}), processes({list(self.processes)})) at {hex(id(self))}>"

    def delete_ref(self, ref: int):
        """Forward deletion of ``ref`` to the client unless it has shut down."""
        if self.client.has_shutdown:
            return
        self.client.handle_deletes(self.processes, [ref])

    def _send(self, cmd: NamedTuple):
        """Flush pending deletes, then send ``cmd`` to this mesh's processes."""
        self.client.flush_deletes()
        self.client.send(self.processes, cmd)

    def stack(self, **kwargs):
        """Not implemented."""
        raise NotImplementedError()

    @property
    def _ndslice(self) -> NDSlice:
        return self.processes

    @property
    def _labels(self) -> Tuple[str, ...]:
        return self.names

    def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
        """Build a mesh over ``shape`` that shares this mesh's client and exit hook."""
        derived = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
        derived.exit = self.exit
        return derived

    def __call__(self, **kwargs) -> "DeviceMesh":
        """
        device_mesh(batch=3) or device_mesh(batch=slice(3, None))

        Deprecated: emits a DeprecationWarning and delegates to ``slice``.
        """
        warnings.warn(
            "The use of this method is deprecated. Please use mesh.slice instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.slice(**kwargs)

    def rotate(self, **kwargs: Dict[str, int]):
        """Not implemented."""
        raise NotImplementedError()

    def rank(self, dims: Union[str, Sequence[str]]) -> torch.Tensor:
        """
        Return the rank along ``dims``.

        For a single dimension name the rank is computed by a remote call to
        ``_rank``. For a sequence of names the per-dimension ranks are
        combined, with earlier names weighted by the sizes of the later ones.

        Raises:
            KeyError: if a single dimension name is not one of ``self.names``.
        """
        self.define_remotely()
        if isinstance(dims, str):
            if dims not in self.names:
                raise KeyError(f"{self} does not have dimension {repr(dims)}")
            remote_rank = _remote(
                _rank,
                propagate=lambda _self, _dims: torch.full((), 0, dtype=torch.long),
            )
            return remote_rank(self, dims)

        combined_rank: Any = 0
        for dim in dims:
            # Augmented assignment is kept on purpose so in-place semantics match.
            combined_rank *= self.size(dim)
            combined_rank += self.rank(dim)
        return combined_rank

    @property
    def ranks(self) -> dict[str, torch.Tensor]:
        """Mapping from every dimension name to its rank."""
        return {dim: self.rank(dim) for dim in self.names}

    def process_idx(self):
        """Remotely call ``monarch.worker.worker._process_idx`` for this mesh."""
        self.define_remotely()
        remote_idx = _remote(
            "monarch.worker.worker._process_idx",
            propagate=lambda _self: torch.full((), 0, dtype=torch.long),
        )
        return remote_idx(self)

    def _process(self, coordinates: Optional[Dict[str, int]]) -> NDSlice:
        """
        Return a single-element NDSlice for the process at ``coordinates``.

        ``None`` selects the process at the mesh's offset.

        Raises:
            KeyError: if more coordinates are given than the mesh has names.
            ValueError: if a mesh dimension is missing from ``coordinates``.
        """
        if coordinates is None:
            return NDSlice(offset=self.processes.offset, sizes=[1], strides=[1])
        if len(coordinates) > len(self.names):
            extra = set(coordinates) - set(self.names)
            raise KeyError(f"{list(extra)}")
        for name in self.names:
            if name not in coordinates:
                raise ValueError(
                    f"Missing key '{name}' in shard map. Need all of {self.names}"
                )
        # Coordinates ordered by the mesh's dimension names.
        ordered = [coordinates[name] for name in self.names]
        return NDSlice(offset=self.processes.nditem(ordered), sizes=[1], strides=[1])

    def activate(self) -> AbstractContextManager:
        """Make this mesh active and return the context object; the context is remembered for ``deactivate``."""
        context = _active_mesh(self)
        self._active_mesh_context = context
        return context

    def deactivate(self):
        """Exit the context stored by the last ``activate`` call, if any."""
        context = self._active_mesh_context
        if context is None:
            return
        context.__exit__(None, None, None)
        self._active_mesh_context = None

    def get_info(self) -> DeviceMeshInfo:
        """
        Retrieves metadata about the device mesh and its constituent devices.

        Returns:
            DeviceMeshInfo: Contains mesh-level labels and per-device labels.
        """
        mesh_state = self.client.mesh_state()
        mesh_labels = mesh_state.labels
        per_device = [proc.labels for proc in mesh_state.procs.values()]
        return DeviceMeshInfo(mesh_labels=mesh_labels, devices_labels=per_device)
|
310
|
+
|
311
|
+
|
312
|
+
# The currently active device mesh, or None when no mesh is active.
_active: Optional[DeviceMesh] = None
# True while the _ActiveMesh dispatch mode is installed (see _dispatch).
_dispatch_enabled = False


def get_active_mesh():
    """
    Return the currently active device mesh.

    Raises:
        ValueError: if no device mesh is active.
    """
    mesh = _active
    if mesh is not None:
        return mesh
    raise ValueError("no device mesh is active")
|
320
|
+
|
321
|
+
|
322
|
+
class _ActiveMesh(TorchDispatchMode):
    """
    Dispatch mode that forwards torch operators to the active device mesh.

    Operators run locally when no mesh is active, when the operator is listed
    in ``ignore``, or when it is listed in ``allowed_local_accessors`` and its
    first argument is not a monarch ``Tensor``. Everything else is sent
    through ``_remote``.
    """

    # Operators (by ``str(func)``) that always run locally.
    ignore = ["profiler._record_function_exit._RecordFunction"]
    # Operators that run locally when applied to a non-monarch tensor.
    allowed_local_accessors = ["aten._local_scalar_dense.default"]

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # ``kwargs`` defaults to None, and ``**None`` raises TypeError, so
        # normalize it before any call below splats it.
        if kwargs is None:
            kwargs = {}
        if _active is None:
            return func(*args, **kwargs)
        fnstr = str(func)
        if fnstr in self.ignore:
            return func(*args, **kwargs)
        if fnstr in self.allowed_local_accessors and not isinstance(args[0], Tensor):
            return func(*args, **kwargs)
        return _remote(func, propagate=func)(*args, **kwargs)
|
335
|
+
|
336
|
+
|
337
|
+
def _rank(mesh, dim):
    """Return ``mesh.dims[dim].rank`` as a 0-dim ``torch.long`` tensor."""
    rank_value = mesh.dims[dim].rank
    return torch.full((), rank_value, dtype=torch.long)
|
339
|
+
|
340
|
+
|
341
|
+
@contextmanager
def _dispatch():
    """
    Install the ``_ActiveMesh`` dispatch mode for the duration of the context.

    Nested use is a no-op: if dispatch is already enabled the context simply
    yields, so the mode is installed only once.
    """
    global _dispatch_enabled
    if _dispatch_enabled:
        # An enclosing context already installed the mode.
        yield
        return

    _dispatch_enabled = True
    try:
        with _ActiveMesh():
            yield
    finally:
        # Reset the flag even if the body raised.
        _dispatch_enabled = False
|
353
|
+
|
354
|
+
|
355
|
+
# Callbacks invoked as ``callback(current_mesh, next_mesh)`` whenever the
# active mesh is about to change.
_on_change: List[Callable] = []


@activate_first_context_manager
def _active_mesh(mesh: Optional[DeviceMesh]):
    """
    Make ``mesh`` the active mesh until the context exits, then restore the
    previous one.

    Because of ``activate_first_context_manager`` the switch happens when
    this function is called, not when a ``with`` block is entered. Every
    callback in ``_on_change`` is notified before the switch and again
    before the previous mesh is restored.
    """
    global _active
    for listener in _on_change:
        listener(_active, mesh)
    previous = _active
    _active = mesh
    try:
        with _dispatch():
            yield
    finally:
        for listener in _on_change:
            listener(_active, previous)
        _active = previous
|
371
|
+
|
372
|
+
|
373
|
+
class _NoMesh:
    """Stand-in mesh whose activation makes no device mesh active."""

    def activate(self):
        """Return a context in which the active mesh is ``None``."""
        return _active_mesh(None)


# Shared instance; use ``with no_mesh.activate():`` to run with no mesh active.
no_mesh = _NoMesh()
|
379
|
+
|
380
|
+
|
381
|
+
def _remote(*args, **kwargs):
    """Forward all arguments to ``monarch.common.remote.remote``."""
    # device_mesh <-> tensor <-> remote are mutually recursive.
    # Importing the `remote` entrypoint locally, here and in tensor, breaks
    # the cycle so the three can live in separate files.
    from monarch.common.remote import remote as remote_entrypoint

    return remote_entrypoint(*args, **kwargs)
|
389
|
+
|
390
|
+
|
391
|
+
def to_mesh(
    tensors: Any,
    mesh: "DeviceMesh",
    stream: Optional[Stream] = None,
) -> Any:
    """
    Move every tensor in the pytree ``tensors`` to ``mesh``.

    Each leaf's ``to_mesh(mesh, stream)`` is called and the results are
    returned in the same tree structure.
    """
    return tree_map(lambda tensor: tensor.to_mesh(mesh, stream), tensors)
|
404
|
+
|
405
|
+
|
406
|
+
def slice_mesh(
    tensors: Any,
    **kwargs: Union[int, slice],
) -> Any:
    """
    Apply ``slice_mesh(**kwargs)`` to every tensor in the pytree ``tensors``.

    The results are returned in the same tree structure as the input.
    """
    return tree_map(lambda tensor: tensor.slice_mesh(**kwargs), tensors)
|
monarch/common/fake.py
ADDED
@@ -0,0 +1,55 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
9
|
+
from functools import cache
|
10
|
+
|
11
|
+
from torch._subclasses.fake_tensor import FakeTensorMode
|
12
|
+
|
13
|
+
|
14
|
+
@cache
def _fake_mode_worker():
    """Return the single-thread executor, created on first call and then reused."""
    executor = ThreadPoolExecutor(max_workers=1)
    return executor
|
17
|
+
|
18
|
+
|
19
|
+
@cache
def _fake_mode():
    """Return the ``FakeTensorMode`` instance, created on first call and then reused."""
    mode = FakeTensorMode()
    return mode
|
22
|
+
|
23
|
+
|
24
|
+
def fake_call(fn, *args, **kwargs):
    """Execute ``fn(*args, **kwargs)`` under FakeTensorMode on a ThreadPool worker.

    The work is submitted to the executor returned by ``_fake_mode_worker()``
    and this function blocks until it completes.

    First call (ThreadPoolExecutor init) will take the GIL and may block for long time!
    TODO: this will be replaced with something more performant

    Returns:
        Whatever ``fn`` returns. An exception raised by ``fn`` is re-raised
        here by ``Future.result()``.
    """
    # NOTE: an earlier variant ran ``fn`` inline under a
    # torch._C._ForceDispatchKeyGuard that excluded the Python dispatch key
    # (to keep version counter tracking); it was superseded by the worker
    # thread approach below.

    def work():
        # fake mode must be initialized in the worker thread
        # otherwise a monarch dispatch mode may be active, causing
        # FakeTensorMode to initialize wrong.
        with _fake_mode():
            return fn(*args, **kwargs)

    return _fake_mode_worker().submit(work).result()
|