torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/remote.py
ADDED
@@ -0,0 +1,297 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import functools
|
10
|
+
import logging
|
11
|
+
import warnings
|
12
|
+
|
13
|
+
from logging import Logger
|
14
|
+
from typing import (
|
15
|
+
Any,
|
16
|
+
Callable,
|
17
|
+
Dict,
|
18
|
+
Generic,
|
19
|
+
Literal,
|
20
|
+
Optional,
|
21
|
+
overload,
|
22
|
+
Protocol,
|
23
|
+
Tuple,
|
24
|
+
TYPE_CHECKING,
|
25
|
+
TypeVar,
|
26
|
+
)
|
27
|
+
|
28
|
+
import monarch.common.messages as messages
|
29
|
+
|
30
|
+
import torch
|
31
|
+
|
32
|
+
from monarch.common import _coalescing, device_mesh, messages, stream
|
33
|
+
|
34
|
+
if TYPE_CHECKING:
|
35
|
+
from monarch.common.client import Client
|
36
|
+
|
37
|
+
from monarch.common.device_mesh import RemoteProcessGroup
|
38
|
+
from monarch.common.fake import fake_call
|
39
|
+
|
40
|
+
from monarch.common.function import (
|
41
|
+
Propagator,
|
42
|
+
resolvable_function,
|
43
|
+
ResolvableFunction,
|
44
|
+
ResolvableFunctionFromPath,
|
45
|
+
)
|
46
|
+
from monarch.common.function_caching import (
|
47
|
+
hashable_tensor_flatten,
|
48
|
+
tensor_placeholder,
|
49
|
+
TensorGroup,
|
50
|
+
TensorPlaceholder,
|
51
|
+
)
|
52
|
+
from monarch.common.future import Future
|
53
|
+
from monarch.common.messages import Dims
|
54
|
+
from monarch.common.tensor import dtensor_check, dtensor_dispatch
|
55
|
+
from monarch.common.tree import flatten, tree_map
|
56
|
+
from torch import autograd, distributed as dist
|
57
|
+
from typing_extensions import ParamSpec
|
58
|
+
|
59
|
+
logger: Logger = logging.getLogger(__name__)
|
60
|
+
|
61
|
+
P = ParamSpec("P")
|
62
|
+
R = TypeVar("R")
|
63
|
+
T = TypeVar("T")
|
64
|
+
|
65
|
+
Propagator = Callable | Literal["mocked", "cached", "inspect"] | None
|
66
|
+
|
67
|
+
|
68
|
+
class Remote(Generic[P, R]):
|
69
|
+
def __init__(self, impl: Any, propagator_arg: Propagator):
|
70
|
+
self._remote_impl = impl
|
71
|
+
self._propagator_arg = propagator_arg
|
72
|
+
self._cache: Optional[dict] = None
|
73
|
+
|
74
|
+
@property
|
75
|
+
def _resolvable(self):
|
76
|
+
return resolvable_function(self._remote_impl)
|
77
|
+
|
78
|
+
def _propagate(self, args, kwargs, fake_args, fake_kwargs):
|
79
|
+
if self._propagator_arg is None or self._propagator_arg == "cached":
|
80
|
+
if self._cache is None:
|
81
|
+
self._cache = {}
|
82
|
+
return _cached_propagation(self._cache, self._resolvable, args, kwargs)
|
83
|
+
elif self._propagator_arg == "inspect":
|
84
|
+
return None
|
85
|
+
elif self._propagator_arg == "mocked":
|
86
|
+
raise NotImplementedError("mocked propagation")
|
87
|
+
else:
|
88
|
+
return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)
|
89
|
+
|
90
|
+
def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
|
91
|
+
if self._propagator_arg is None:
|
92
|
+
return # no propgator provided, so we just assume no mutations
|
93
|
+
return self._propagate(args, kwargs, fake_args, fake_kwargs)
|
94
|
+
|
95
|
+
def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
|
96
|
+
if not callable(self._propagator_arg):
|
97
|
+
raise ValueError("Must specify explicit callable for pipe")
|
98
|
+
return self._propagate(args, kwargs, fake_args, fake_kwargs)
|
99
|
+
|
100
|
+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
101
|
+
return dtensor_dispatch(
|
102
|
+
self._resolvable,
|
103
|
+
self._propagate,
|
104
|
+
args,
|
105
|
+
kwargs,
|
106
|
+
device_mesh._active,
|
107
|
+
stream._active,
|
108
|
+
)
|
109
|
+
|
110
|
+
def call_on_shard_and_fetch(
|
111
|
+
self, *args, shard: Dict[str, int] | None = None, **kwargs
|
112
|
+
) -> Future[R]:
|
113
|
+
return _call_on_shard_and_fetch(
|
114
|
+
self._resolvable, self._fetch_propagate, *args, shard=shard, **kwargs
|
115
|
+
)
|
116
|
+
|
117
|
+
|
118
|
+
class RemoteIfy(Protocol):
    """Type of the decorator returned by ``remote(propagate=...)``.

    Written as a Protocol rather than a plain ``Callable`` because otherwise
    the type variables bound by the argument could not be used in the return
    annotation.
    """

    def __call__(self, function: Callable[P, R]) -> Remote[P, R]: ...
|
122
|
+
|
123
|
+
|
124
|
+
@overload
def remote(
    function: Callable[P, R], *, propagate: Propagator = None
) -> "Remote[P, R]": ...


@overload
def remote(
    function: str, *, propagate: Literal["mocked", "cached", "inspect"] | None = None
) -> "Remote": ...


@overload
def remote(function: str, *, propagate: Callable[P, R]) -> Remote[P, R]: ...


@overload
def remote(*, propagate: Propagator = None) -> RemoteIfy: ...  # type: ignore


# The last overload carries `# type: ignore` because otherwise the checker
# claims the implementation below does not accept the overloads' argument
# lists.


def remote(function: Any = None, *, propagate: Propagator = None) -> Any:
    """Wrap ``function`` (a callable or a string path) in a :class:`Remote`.

    Works both bare (``@remote``) and with options
    (``@remote(propagate=...)``). In the second form ``function`` is None and
    a decorator carrying the same ``propagate`` option is returned.
    """
    if function is not None:
        return Remote(function, propagate)
    return functools.partial(remote, propagate=propagate)
|
152
|
+
|
153
|
+
|
154
|
+
def _call_on_shard_and_fetch(
    rfunction: ResolvableFunction | None,
    propagator: Any,
    /,
    *args: object,
    shard: dict[str, int] | None = None,
    **kwargs: object,
) -> Future:
    """
    Call `rfunction` at the coordinates `shard` of the current device mesh, and retrieve the result as a Future.

    rfunction - the remote function to call. If None, the function
        `ResolvableFunctionFromPath("ident")` is used and the preprocess message
        handed to `client.fetch` is None.
    propagator - forwarded to `dtensor_check` together with the function and arguments
    *args/**kwargs - arguments to the function
    shard - a dictionary from mesh dimension name to coordinate of the shard
        If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)

    Raises NotImplementedError when called while a coalescing block is active.
    """
    ambient_mesh = device_mesh._active

    # A None rfunction means "fetch the arguments themselves": use the identity
    # function and send no preprocess message.
    preprocess_message = rfunction
    if rfunction is None:
        rfunction = ResolvableFunctionFromPath("ident")

    _, dtensors, mutates, mesh = dtensor_check(
        propagator, rfunction, args, kwargs, ambient_mesh, stream._active
    )

    client: "Client" = mesh.client
    if _coalescing.is_active(client):
        raise NotImplementedError("NYI: fetching results during a coalescing block")

    stream_ref = stream._active._to_ref(client)
    return client.fetch(
        mesh, stream_ref, shard, preprocess_message, args, kwargs, mutates, dtensors
    )
|
193
|
+
|
194
|
+
|
195
|
+
@remote
def _propagate(
    function: ResolvableFunction, args: Tuple[Any, ...], kwargs: Dict[str, Any]
):
    """
    RF preprocess function.

    Resolves and runs ``function(*args, **kwargs)``. It returns a pair
    ``(result, pattern)``:

    - ``result`` has the structure of the return value, with every tensor
      replaced by ``tensor_placeholder``.
    - ``pattern`` is the ``pattern`` of a ``TensorGroup`` built from the output
      tensors, with the input tensors' group as its parent.
    """

    def is_tensor(obj):
        return isinstance(obj, torch.Tensor)

    fn = function.resolve()

    # XXX: besides the functional properties, this presumably should also
    # report whether any of the input tensors were mutated (not done yet).
    input_tensors, _ = flatten((args, kwargs), is_tensor)
    inputs = TensorGroup(input_tensors)

    output = fn(*args, **kwargs)
    output_tensors, rebuild = flatten(output, is_tensor)
    outputs = TensorGroup(output_tensors, parent=inputs)

    placeholders = [tensor_placeholder for _ in output_tensors]
    return rebuild(placeholders), outputs.pattern
|
220
|
+
|
221
|
+
|
222
|
+
class DummyProcessGroup(dist.ProcessGroup):
    """Process group whose collectives do no communication.

    Each collective returns a work object whose ``wait()`` hands back a tensor
    that was passed in. Pickling keeps only ``dims`` and ``world_size``;
    unpickling re-runs ``__init__`` with them.
    """

    def __init__(self, dims: Dims, world_size: int):
        # pyre-ignore
        super().__init__(0, world_size)
        self.dims = dims
        self.world_size = world_size

    @staticmethod
    def _completed_work(result):
        """Return a work handle whose ``wait()`` returns ``result`` immediately."""

        class DummyWork:
            def wait(self):
                return result

        return DummyWork()

    def allreduce(self, tensor, op=dist.ReduceOp.SUM, async_op=False):
        return self._completed_work(tensor)

    def _allgather_base(self, output_tensor, input_tensor, opts):
        return self._completed_work(output_tensor)

    def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
        return self._completed_work(output_tensor)

    def __getstate__(self):
        return dict(dims=self.dims, world_size=self.world_size)

    def __setstate__(self, state):
        self.__init__(state["dims"], state["world_size"])
|
255
|
+
|
256
|
+
|
257
|
+
def _mock_pgs(x):
    """Swap remote process groups for local ``DummyProcessGroup`` stand-ins.

    Applied through ``tree_map`` to a call's arguments before they are sent
    for shape propagation (see ``_cached_propagation``).

    - A ``RemoteProcessGroup`` becomes a new ``DummyProcessGroup`` with the
      same ``dims`` and ``size()``.
    - An autograd ``FunctionCtx`` is returned as the same object, with each
      instance attribute that holds a ``RemoteProcessGroup`` replaced in place.
    - Anything else is returned unchanged.
    """
    if isinstance(x, autograd.function.FunctionCtx):
        # Test the attribute *values*: the attribute names are strings and can
        # never be RemoteProcessGroup instances. Only instance attributes are
        # inspected, so properties defined on the ctx class are not triggered.
        # list() snapshots the items because setattr writes back into vars(x).
        for attr, value in list(vars(x).items()):
            if not attr.startswith("__") and isinstance(value, RemoteProcessGroup):
                # NOTE(review): this mutates the caller's ctx object in place;
                # confirm nothing later needs the original RemoteProcessGroup.
                setattr(x, attr, DummyProcessGroup(value.dims, value.size()))
        return x
    if isinstance(x, RemoteProcessGroup):
        return DummyProcessGroup(x.dims, x.size())
    return x
|
266
|
+
|
267
|
+
|
268
|
+
# Miss/hit counters updated by `_cached_propagation`; kept for tests.
_miss = _hit = 0
|
271
|
+
|
272
|
+
|
273
|
+
def _cached_propagation(_cache, rfunction, args, kwargs):
    """Return a fake result for ``rfunction(*args, **kwargs)``, memoized in ``_cache``.

    The cache key combines three things:

    - the key produced by ``hashable_tensor_flatten``;
    - the ``TensorGroup`` pattern of the input tensors' fakes;
    - the inputs' ``requires_grad`` flags.

    On a miss, ``_propagate`` is run on a shard with process groups replaced
    by dummies (``_mock_pgs``). The resulting unflatten function and output
    pattern are then stored. Each call builds fresh fake outputs.
    """
    global _miss, _hit

    tensors, shape_key = hashable_tensor_flatten(args, kwargs)
    inputs_group = TensorGroup([t._fake for t in tensors])
    grad_flags = tuple(t.requires_grad for t in tensors)
    key = (shape_key, inputs_group.pattern, grad_flags)

    if key in _cache:
        _hit += 1
        unflatten_result, output_pattern = _cache[key]
    else:
        _miss += 1
        safe_args, safe_kwargs = tree_map(_mock_pgs, (args, kwargs))
        placeholder_result, output_pattern = _propagate.call_on_shard_and_fetch(
            function=rfunction, args=safe_args, kwargs=safe_kwargs
        ).result()
        _, unflatten_result = flatten(
            placeholder_result, lambda x: isinstance(x, TensorPlaceholder)
        )
        _cache[key] = (unflatten_result, output_pattern)

    # return fresh fake result every time to avoid spurious aliasing
    output_tensors = fake_call(output_pattern.empty, [inputs_group.tensors])
    return unflatten_result(output_tensors)
|
monarch/common/selection.py
ADDED
@@ -0,0 +1,9 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
|
8
|
+
|
9
|
+
__all__ = ["Selection"]
|
monarch/common/shape.py
ADDED
@@ -0,0 +1,229 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import itertools
|
8
|
+
import operator
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
|
11
|
+
from typing import Dict, Generator, Sequence, Tuple, Union
|
12
|
+
|
13
|
+
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
|
14
|
+
|
15
|
+
from typing_extensions import Self
|
16
|
+
|
17
|
+
# Alias for the N-dimensional slice type provided by the Rust bindings.
NDSlice = Slice

# Either a single slice or a list of slices (accepted by `iter_ranks`).
Slices = NDSlice | list[NDSlice]
|
20
|
+
|
21
|
+
|
22
|
+
def iter_ranks(ranks: Slices) -> Generator[int, None, None]:
    """Lazily yield every rank in ``ranks``.

    A single slice (anything that is not a ``list``) is yielded as-is. For a
    list of slices, each rank is yielded once, in first-encountered order,
    even if it appears in several slices.
    """
    if not isinstance(ranks, list):
        yield from ranks
        return

    already_yielded = set()
    for rank in (r for slice_ in ranks for r in slice_):
        if rank in already_yielded:
            continue
        already_yielded.add(rank)
        yield rank
|
32
|
+
|
33
|
+
|
34
|
+
class MeshTrait(ABC):
    """
    Mesh interface. Implemented via Shape.

    Subclasses supply ``_ndslice`` (the slice of ranks), ``_labels`` (one name
    per dimension, in order) and ``_new_with_shape``. Slicing, splitting,
    flattening, renaming and size queries are derived from those.
    """

    @property
    @abstractmethod
    def _ndslice(self) -> NDSlice: ...

    @property
    @abstractmethod
    def _labels(self) -> Tuple[str, ...]: ...

    # mesh trait guarantees that its own calls to _new_with_shape
    # will only ever select a shape that is a subspace of the
    # current _ndslice.
    @abstractmethod
    def _new_with_shape(self, shape: Shape) -> Self: ...

    def slice(self, **kwargs) -> Self:
        """
        mesh.slice(batch=3) or mesh.slice(batch=slice(3, None))

        An integer selects a single coordinate and removes that dimension. A
        ``slice`` keeps the dimension, restricted to the selected coordinates.
        Dimensions not mentioned are kept unchanged.

        Raises:
            IndexError: if an integer index is negative or not less than the
                dimension's size.
            TypeError: if a keyword does not name a dimension of this mesh.
        """
        ndslice = self._ndslice
        labels = self._labels
        offset = ndslice.offset
        names = []
        sizes = []
        strides = []
        for name, size, stride in zip(labels, ndslice.sizes, ndslice.strides):
            if name in kwargs:
                e = kwargs.pop(name)
                if isinstance(e, slice):
                    start, stop, slice_stride = e.indices(size)
                    offset += start * stride
                    names.append(name)
                    # len(range(...)) is the exact number of selected
                    # coordinates. Unlike (stop - start) // step, it rounds up
                    # for non-unit steps and is never negative for empty slices.
                    sizes.append(len(range(start, stop, slice_stride)))
                    strides.append(slice_stride * stride)
                else:
                    if e >= size or e < 0:
                        raise IndexError("index out of range")
                    offset += e * stride
            else:
                names.append(name)
                sizes.append(size)
                strides.append(stride)

        if kwargs:
            raise TypeError(
                f"{self} does not have dimension(s) named {tuple(kwargs.keys())}"
            )

        new_ndslice = NDSlice(offset=offset, sizes=sizes, strides=strides)
        return self._new_with_shape(Shape(names, new_ndslice))

    def split(self, **kwargs) -> Self:
        """
        Returns a new device mesh with some dimensions of this mesh split.
        For instance, this call splits the host dimension into dp and pp, and
        the gpu dimension into tp and cp. The sizes of 'pp' and 'cp' are
        specified, and the sizes of 'dp' and 'tp' are derived from them:

            new_mesh = mesh.split(host=('dp', 'pp'), gpu=('tp','cp'), pp=16, cp=2)

        Dimensions not specified will remain unchanged.

        Raises:
            ValueError: on a malformed split, a size that cannot be inferred
                or does not divide evenly, a duplicate dimension name, or an
                unused size constraint.
        """
        splits: Dict[str, Sequence[str]] = {}
        size_constraints: Dict[str, int] = {}
        for key, value in kwargs.items():
            if key in self._labels:
                if isinstance(value, str):
                    raise ValueError(
                        f"expected a sequence of dimensions, but got '{value}'"
                    )
                splits[key] = value
            else:
                if not isinstance(value, int):
                    raise ValueError(
                        f"'{key}' is not an existing dim. Expected an integer size constraint on a new dim."
                    )
                size_constraints[key] = value

        names = []
        sizes = []
        strides = []
        ndslice = self._ndslice
        for name, size, stride in zip(self._labels, ndslice.sizes, ndslice.strides):
            to_names = splits.get(name, (name,))
            # Reject duplicate output names before size inference. This covers
            # names produced by earlier dimensions and repeats within this
            # split. Otherwise a repeated name surfaces as a confusing size
            # error, or a KeyError from size_constraints.pop below.
            seen = set(names)
            for to_name in to_names:
                if to_name in seen:
                    raise ValueError(f"Duplicate dimension name '{to_name}'")
                seen.add(to_name)

            total_size = 1
            unknown_size_name = None
            for to_name in to_names:
                if to_name in size_constraints:
                    total_size *= size_constraints[to_name]
                elif unknown_size_name is None:
                    unknown_size_name = to_name
                else:
                    raise ValueError(
                        f"Cannot infer size of {to_names} because both {to_name} and {unknown_size_name} have unknown size. Specify at least one as argument, e.g. {to_name}=4"
                    )
            if unknown_size_name is not None:
                inferred_size, m = divmod(size, total_size)
                if m != 0:
                    to_sizes = tuple(
                        (
                            size_constraints[to_name]
                            if to_name in size_constraints
                            else "?"
                        )
                        for to_name in to_names
                    )
                    raise ValueError(
                        f"Dimension '{name}' of size {size} is not evenly divided by {to_names!r} with sizes {to_sizes!r}"
                    )
                size_constraints[unknown_size_name] = inferred_size
            elif total_size != size:
                to_sizes = tuple(size_constraints[to_name] for to_name in to_names)
                raise ValueError(
                    f"Dimension '{name}' of size {size} is not evenly divided by {to_names!r} with sizes {to_sizes!r}"
                )
            new_sizes = [size_constraints.pop(to_name) for to_name in to_names]
            # Row-major strides for the new sub-dimensions, scaled by the
            # original dimension's stride.
            new_strides_reversed = tuple(
                itertools.accumulate(reversed(new_sizes), operator.mul, initial=stride)
            )
            sizes.extend(new_sizes)
            strides.extend(reversed(new_strides_reversed[:-1]))
            names.extend(to_names)
        if size_constraints:
            raise ValueError(
                f"unused size constraints: {tuple(size_constraints.keys())}"
            )
        return self._new_with_shape(
            Shape(names, NDSlice(offset=ndslice.offset, sizes=sizes, strides=strides))
        )

    def flatten(self, name: str) -> Self:
        """
        Returns a new device mesh with all dimensions flattened into a single dimension
        with the given name.

        Currently this supports only dense meshes: that is, all ranks must be contiguous
        in the mesh.

        Raises:
            ValueError: if the mesh's strides are not the dense row-major strides.
        """
        ndslice = self._ndslice
        dense_strides = tuple(
            itertools.accumulate(reversed(ndslice.sizes), operator.mul, initial=1)
        )
        dense_strides, total_size = (
            list(reversed(dense_strides[:-1])),
            dense_strides[-1],
        )
        # Normalize to a list before comparing, in case the binding returns a
        # different sequence type (a tuple would never compare equal to a list).
        if dense_strides != list(ndslice.strides):
            raise ValueError(
                "cannot flatten sparse mesh: " f"{ndslice.strides=} != {dense_strides=}"
            )

        return self._new_with_shape(
            Shape(
                [name], NDSlice(offset=ndslice.offset, sizes=[total_size], strides=[1])
            )
        )

    def rename(self, **kwargs) -> Self:
        """
        Returns a new device mesh with some of dimensions renamed.
        Dimensions not mentioned are retained:

            new_mesh = mesh.rename(host='dp', gpu='tp')
        """
        return self.split(**{k: (v,) for k, v in kwargs.items()})

    def size(self, dim: Union[None, str, Sequence[str]] = None) -> int:
        """
        Returns the number of elements (total) of the subset of mesh asked for.
        If dim is None, returns the total number of devices in the mesh.

        Raises:
            KeyError: if a requested dimension name is not in this mesh.
        """
        if dim is None:
            dim = self._labels
        if isinstance(dim, str):
            if dim not in self._labels:
                raise KeyError(f"{self} does not have dimension {repr(dim)}")
            return self._ndslice.sizes[self._labels.index(dim)]
        p = 1
        for d in dim:
            p *= self.size(d)
        return p

    @property
    def sizes(self) -> dict[str, int]:
        """Mapping from each dimension name to its size, in dimension order."""
        return dict(zip(self._labels, self._ndslice.sizes))
|
227
|
+
|
228
|
+
|
229
|
+
__all__ = ["NDSlice", "Shape", "MeshTrait"]
|
monarch/common/stream.py
ADDED
@@ -0,0 +1,114 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from typing import Callable, List, Tuple, TYPE_CHECKING
|
9
|
+
from weakref import ref, WeakKeyDictionary
|
10
|
+
|
11
|
+
from . import messages
|
12
|
+
from .borrows import Borrow
|
13
|
+
from .context_manager import activate_first_context_manager
|
14
|
+
from .fake import fake_call
|
15
|
+
from .reference import Referenceable
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from monarch.common.client import Client # @manual
|
19
|
+
|
20
|
+
from .tensor import Tensor
|
21
|
+
|
22
|
+
|
23
|
+
class Stream:
    """A named stream of execution.

    A ``StreamRef`` is created lazily for each client that uses the stream.
    """

    def __init__(self, name: str, _default=False):
        self.name = name
        self.default: bool = _default
        # Weakly keyed: a client's entry disappears once the client is
        # garbage-collected.
        self.clients: WeakKeyDictionary["Client", "StreamRef"] = WeakKeyDictionary()

    def __repr__(self):
        return f"<Stream({self.name!r}) at {hex(id(self))}>"

    def __str__(self):
        return f"stream {self.name!r}"

    def activate(self):
        """Return a context manager that makes this the active stream."""
        return _active_stream(self)

    def _to_ref(self, client: "Client"):
        """Return this stream's ``StreamRef`` for ``client``, creating it on first use."""
        if client in self.clients:
            return self.clients[client]
        created = StreamRef(client, self.name, self.default)
        self.clients[client] = created
        return created

    def borrow(self, t: "Tensor", mutable: bool = False) -> Tuple["Tensor", "Borrow"]:
        """
        borrowed_tensor, borrow = self.borrow(t)

        Borrows tensor 't' for use on this stream.
        The memory of t will stay alive until borrow.drop() is called, which will free t and
        any of its aliases on stream `self` and will cause t.stream to wait on self at that point so
        that the memory of t can be reused.

        If `mutable` then self can write to the storage of `t`, but t.stream cannot read or write `t` until
        the borrow is returned (becomes free and a wait_for has been issued).

        If not `mutable` both `self` and `t.stream` can read from t's storage but neither can write to it.
        """
        client = t.mesh.client
        source_aliases = t._aliases
        borrowed = type(t)(fake_call(t._fake.clone), t.mesh, self)
        client.new_node((borrowed,), (t,))
        borrow = borrowed._aliases.borrow_from(
            client.new_ref(), t.mesh, source_aliases, mutable
        )
        client.new_borrow(borrow)
        assert borrowed.ref is not None
        create_message = messages.BorrowCreate(
            borrowed, borrow._id, t, t.stream._to_ref(client), self._to_ref(client)
        )
        t.mesh._send(create_message)
        borrowed._on_first_use = lambda t: borrow._use()

        return borrowed, borrow
|
72
|
+
|
73
|
+
|
74
|
+
class StreamRef(Referenceable):
    """Per-client handle for a stream.

    Constructing one sends ``messages.CreateStream`` to all of the client's
    ranks.
    """

    def __init__(self, client: "Client", name: str, default: bool):
        self.ref = client.new_ref()
        # Hold the client weakly so this handle does not keep it alive.
        self.client = ref(client)
        self.name = name
        self.default = default
        client.send(client.all_ranks, messages.CreateStream(self, self.default))

    def __repr__(self):
        return f"<StreamRef {self.name!r} {self.ref}>"

    def delete_ref(self, ref):
        """Ask all ranks to delete ``ref``, unless the client is gone or shut down."""
        owner = self.client()
        if owner is None or owner._shutdown:
            return
        owner.handle_deletes(owner.all_ranks, [ref])
|
92
|
+
|
93
|
+
|
94
|
+
# The currently active stream. It starts as the default "main" stream and is
# swapped temporarily by `_active_stream` (e.g. via `Stream.activate()`).
_active = Stream("main", _default=True)

# Callbacks invoked as on_change(current, new) whenever the active stream changes.
_on_change: List[Callable] = []


def get_active_stream():
    """Return the currently active :class:`Stream`."""
    return _active
|
100
|
+
|
101
|
+
|
102
|
+
@activate_first_context_manager
def _active_stream(stream: Stream):
    """Make ``stream`` the active stream for the duration of the context.

    Each registered ``_on_change`` callback is called as
    ``on_change(current, new)`` twice: before switching to ``stream`` on
    entry, and before switching back to the previous stream on exit.
    """
    global _active
    for on_change in _on_change:
        on_change(_active, stream)

    _active, old = stream, _active
    try:
        yield
    finally:
        try:
            for on_change in _on_change:
                on_change(_active, old)
        finally:
            # Restore the previous stream even if a callback raised.
            # Otherwise a failing callback would leave `stream` active
            # after the context has exited.
            _active = old
|