torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/tensor.py
ADDED
@@ -0,0 +1,814 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import itertools
|
9
|
+
import traceback
|
10
|
+
import typing
|
11
|
+
import warnings
|
12
|
+
from collections import defaultdict
|
13
|
+
from typing import (
|
14
|
+
Any,
|
15
|
+
Callable,
|
16
|
+
cast,
|
17
|
+
Dict,
|
18
|
+
Iterable,
|
19
|
+
List,
|
20
|
+
Literal,
|
21
|
+
NamedTuple,
|
22
|
+
Optional,
|
23
|
+
runtime_checkable,
|
24
|
+
Sequence,
|
25
|
+
TYPE_CHECKING,
|
26
|
+
TypeVar,
|
27
|
+
Union,
|
28
|
+
)
|
29
|
+
|
30
|
+
import torch
|
31
|
+
import torch._ops
|
32
|
+
from monarch.common.function import ResolvableFunctionFromPath
|
33
|
+
from torch._subclasses.fake_tensor import FakeTensor
|
34
|
+
from torch.utils._pytree import tree_map
|
35
|
+
|
36
|
+
from . import messages, stream
|
37
|
+
from .base_tensor import BaseTensor
|
38
|
+
from .borrows import StorageAliases
|
39
|
+
|
40
|
+
if TYPE_CHECKING:
|
41
|
+
from monarch.common.device_mesh import DeviceMesh
|
42
|
+
|
43
|
+
from .fake import fake_call
|
44
|
+
from .function import Propagator, ResolvableFunction
|
45
|
+
from .invocation import Invocation
|
46
|
+
from .messages import Dims
|
47
|
+
from .reference import Referenceable
|
48
|
+
from .shape import NDSlice
|
49
|
+
from .stream import Stream
|
50
|
+
from .tree import flatten
|
51
|
+
|
52
|
+
_valid_reduce = Literal[
|
53
|
+
"stack", "sum", "avg", "product", "min", "max", "band", "bor", "bxor"
|
54
|
+
]
|
55
|
+
|
56
|
+
T = TypeVar("T")
|
57
|
+
|
58
|
+
|
59
|
+
@runtime_checkable
class HasDeviceMesh(typing.Protocol):
    """Structural type for any object that exposes a ``_device_mesh`` property.

    ``Tensor.to_mesh`` and ``MeshSliceTensor.to_mesh`` accept such objects in
    place of a ``DeviceMesh`` and unwrap them through ``_device_mesh``.
    Because the protocol is runtime-checkable, ``isinstance`` only verifies
    that the attribute exists.
    """

    @property
    def _device_mesh(self) -> "DeviceMesh":
        """Return the ``DeviceMesh`` this object is associated with."""
        ...
|
63
|
+
|
64
|
+
|
65
|
+
class DropLocation(NamedTuple):
    """Where a tensor ref was dropped, for use in error reports.

    ``Tensor._drop_ref`` records one of these (the ref id plus
    ``traceback.extract_stack()``). ``InputChecker.check_permission`` includes
    its string form when a dropped tensor is used again.
    """

    tensor_id: int
    traceback: List[traceback.FrameSummary]

    def __repr__(self) -> str:
        # Inside methods, ``traceback`` resolves to the module, not the field.
        rendered_frames = traceback.format_list(self.traceback)
        header = f"tensor {self.tensor_id} is dropped at: \n"
        return header + "".join(rendered_frames)
|
73
|
+
|
74
|
+
|
75
|
+
class Tensor(Referenceable, BaseTensor):
    """Client-side handle for a tensor that lives on the workers of a ``DeviceMesh``.

    Size, strides, dtype, device and layout are taken from ``_fake`` (a
    ``FakeTensor``). ``ref`` is an id obtained from ``mesh.client.new_ref()``
    and becomes ``None`` once the tensor is dropped. Handles whose fakes share
    an untyped storage share a single ``StorageAliases`` record (``_aliases``),
    which holds the access permissions and the set of aliasing handles.
    """

    # pyre-fixme[13]: Attribute `stream` is never initialized.
    stream: Stream
    # pyre-fixme[13]: Attribute `mesh` is never initialized.
    mesh: "DeviceMesh"
    ref: Optional[int]
    # pyre-fixme[13]: Attribute `_invocation` is never initialized.
    _invocation: Optional[Invocation]
    # pyre-fixme[13]: Attribute `_fake` is never initialized.
    _fake: torch.Tensor
    # pyre-fixme[13]: Attribute `_aliases` is never initialized.
    _aliases: StorageAliases
    # pyre-fixme[13]: Attribute `_on_first_use` is never initialized.
    _on_first_use: Optional[Callable]
    # pyre-fixme[13]: Attribute `_drop_location` is never initialized.
    _drop_location: Optional[DropLocation]
    # _seq represents the sequence number of the concrete invocation that
    # created this tensor, or the most recent invocation that mutated it.
    # Unlike the _invocation field, this will be set for both the rust and
    # python backends.
    # pyre-fixme[13]: Attribute `_seq` is never initialized.
    _seq: Optional[int]

    def __new__(cls, fake: torch.Tensor, mesh: "DeviceMesh", stream: "Stream"):
        """Create a handle whose metadata mirrors ``fake``.

        ``fake`` must be a ``FakeTensor``. A fresh ref is allocated from
        ``mesh.client``. The handle is then registered with the
        ``StorageAliases`` entry for ``fake``'s untyped storage, which is
        created the first time that storage is seen.
        """
        # pyre-ignore[16]
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            fake.size(),
            strides=fake.stride(),
            storage_offset=fake.storage_offset(),
            device=fake.device,  # This is the device of either the input tensor or the first tensor of a list
            dtype=fake.dtype,
            layout=fake.layout,
            requires_grad=fake.requires_grad,
        )
        assert isinstance(fake, FakeTensor)
        r._fake = fake
        client = mesh.client
        r.ref = client.new_ref()
        r.mesh = mesh
        r.stream = stream

        # All handles whose fakes share a storage share one alias record.
        storage = fake.untyped_storage()
        if storage not in client.aliases:
            client.aliases[storage] = StorageAliases()
        r._aliases = client.aliases[storage]
        r._aliases.register(r)
        r._invocation = None
        r._on_first_use = None
        r._drop_location = None
        r._seq = None
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Route an aten op on a ``Tensor`` through ``remote``, using ``func`` to propagate."""
        from monarch.common.remote import remote

        # device_mesh <-> tensor <-> remote are mutually recursive;
        # we break the dependency to allow for separate files by
        # having device_mesh and tensor locally import the `remote`
        # entrypoint
        return remote(func, propagate=func)(*args, **(kwargs or {}))

    def __init__(
        self,
        fake: Optional[torch.Tensor] = None,
        mesh: Optional["DeviceMesh"] = None,
        stream: Optional[Stream] = None,
    ):
        # All initialization happens in __new__.
        pass

    def __repr__(self, *, tensor_contents=None):
        return f"monarch.Tensor(mesh={self.mesh}, stream={self.stream}, fake={repr(self._fake)})"

    def drop(self):
        """Delete this tensor and every alias of its storage on the workers.

        Each alias's ref is deleted and its drop location is recorded. Any later
        use of one of them is reported as DROPPED by
        ``InputChecker.check_permission``. This is a no-op if this tensor was
        already dropped.
        """
        if self.ref is None:
            return

        # Iterate over a snapshot: `_drop_ref` -> `delete_ref` unregisters the
        # alias from `self._aliases` while we are looping over it.
        for alias in list(self._aliases.aliases):
            alias._drop_ref()

        # we should be in the aliases list as well
        assert self.ref is None

    @property
    def dropped(self):
        """True once this handle's ref has been released."""
        return self.ref is None

    def _drop_ref(self):
        """Release this handle's ref (if any) and remember where that happened."""
        if self.ref is None:
            return
        self.delete_ref(self.ref)
        self._drop_location = DropLocation(self.ref, traceback.extract_stack())
        self.ref = None

    @property
    def _access_permissions(self):
        """Access string shared by all aliases ("r" = readable, "w" = writable)."""
        return self._aliases.access

    def _use(self):
        """Run the one-shot ``_on_first_use`` callback, if one is set."""
        if self._on_first_use:
            self._on_first_use(self)
            self._on_first_use = None

    def to_mesh(
        self,
        mesh: Union["DeviceMesh", "HasDeviceMesh"],
        stream: Optional["Stream"] = None,
    ):
        """
        Move data between one device mesh and another. Sizes of named dimensions must match.
        If mesh has dimensions that self.mesh does not, it will broadcast to those dimensions.


        broadcast:
            t.slice_mesh(batch=0).to_mesh(t.mesh)

        Args:
            mesh: the destination mesh, or any object exposing ``_device_mesh``.
            stream: the destination stream; defaults to this tensor's stream.
        """
        if isinstance(mesh, HasDeviceMesh):
            mesh = mesh._device_mesh
        return MeshSliceTensor(self, self.mesh).to_mesh(mesh, stream)

    def reduce_(
        self,
        dims: Dims | str,
        reduction: _valid_reduce = "sum",
        scatter: bool = False,
        mesh: Optional["DeviceMesh"] = None,
    ):
        """In-place variant of :meth:`reduce` (calls it with ``_inplace=True``)."""
        return self.reduce(dims, reduction, scatter, mesh, _inplace=True)

    def reduce(
        self,
        dims: Dims | str,
        reduction: _valid_reduce = "sum",
        scatter: bool = False,
        mesh: Optional["DeviceMesh"] = None,
        _inplace: bool = False,
        out: Optional["Tensor"] = None,
    ):
        """
        Perform a reduction operation along dim, and move the data to mesh. If mesh=None, then mesh=self.mesh
        'stack' (gather) will concat the values along dim, and produce a local result tensor with an additional outer dimension of len(dim).
        If scatter=True, the local result tensor will be evenly split across dim.

        NOTE: moving to a different mesh is not implemented yet; passing a non-None
        `mesh` currently raises NotImplementedError. The `mesh=` examples below
        describe the intended semantics.

        allreduce:
            t.reduce(dims='gpu', reduction='sum')

            First reduces dim 'gpu' creating a local tensor with the 'gpu' dimension, then because output_mesh=input_mesh, and it still has dim 'gpu',
            we broadcast the result reduced tensor to all members of gpu.

        reducescatter:
            t.reduce(dims='gpu', reduction='sum', scatter=True)

            Same as above except that scatter=True introduces a new 'gpu' dimension that is the result of splitting the local tensor across 'gpu'

        allgather:
            t.reduce(dims='gpu', reduction='stack')

            First reduces dim 'gpu' creating a bigger local tensor, then because output_mesh=input_mesh, and it still has dim 'gpu',
            broadcasts the result concatenated tensor to all members of gpu.

        alltoall:
            t.reduce(dims='gpu', reduction='stack', scatter=True)

            First reduces dim 'gpu' creating a bigger local tensor, then introduces a new 'gpu' dimension that is the result of splitting this
            (bigger) tensor across 'gpu'. The result is the same dimension as the original tensor, but with each rank sending to all other ranks.

        gather (to dim 0):
            t.reduce(dims='gpu', reduction='stack', mesh=device_mesh(gpu=0))

            First gathers dim 'gpu' and then places it on the first rank. t.mesh.gpu[0] doesn't have a 'gpu' dimension, but this is
            ok because we eliminated the 'gpu' dim via reduction.

        reduce:
            t.reduce(dims='gpu', reduction='sum', mesh=device_mesh(gpu=0))

            First reduces dim 'gpu' and then places it on the first rank. t.mesh.gpu[0] doesn't have a 'gpu' dimension, but this is
            ok because we eliminated the 'gpu' dim via reduction.

        Args:
            dims (Dims | str): The dimensions along which to perform the reduction.
                An empty sequence means all dimensions of self.mesh.
            reduction (_valid_reduce): The type of reduction to perform. Defaults to "sum".
            scatter (bool): If True, the local result tensor will be evenly split across dimensions.
                Defaults to False.
            mesh (Optional["DeviceMesh"], optional): The target mesh to move the data to.
                If None, uses self.mesh. Defaults to None. Non-None values are not implemented yet.
            _inplace (bool): If True, performs the operation in-place. Defaults to False.
                Only valid when the result has the input's shape: all-reduce
                (non-"stack" reduction without scatter) or all-to-all ("stack" with scatter).
            out (Optional["Tensor"]): The output tensor to store the result. If None, a new tensor
                will be created on the stream where the reduce operation executes. Defaults to None.

        Returns:
            Tensor: The result of the reduction operation.

        Raises:
            NotImplementedError: if `mesh` is not None.
            KeyError: if a dimension is not in self.mesh.
            ValueError: on repeated dims, an unsupported reduction, stack/scatter
                over multiple dims, an invalid in-place request, or a wrongly
                shaped `out`.
            TypeError: if the inputs fail InputChecker's checks.
        """
        if mesh is not None:
            raise NotImplementedError()
        if isinstance(dims, str):
            dims = (dims,)
        for d in dims:
            if d not in self.mesh.names:
                raise KeyError(f"dim {d} not found in {self.mesh}")
        if len(dims) == 0:
            dims = self.mesh.names
        if len(set(dims)) != len(dims):
            raise ValueError(f"reducing the same dimension twice: {dims}")
        if len(dims) > 1:
            if reduction == "stack" or scatter:
                raise ValueError(
                    f"reduction {reduction} or scatter = {scatter} is not valid for multiple dimensions"
                )
        # Validate the reduction name for every call, including single-dim reductions.
        if reduction not in _valid_reduce.__args__:
            raise ValueError(
                f"reduction {reduction} not supported, reductions are {_valid_reduce.__args__}"
            )

        if mesh is None:
            mesh = self.mesh

        ts: List[torch.Tensor] = [self]
        if out is not None:
            ts.append(out)
        with InputChecker(
            ts,
            lambda ts: (
                f"reduce({next(ts)}, {dims}, reduction={reduction}, out={next(ts, None)})"
            ),
        ) as checker:
            checker.check_no_requires_grad()
            checker.check_cuda()
            checker.check_mesh_stream_local(self.mesh, stream._active)
            # `out` is written; so is `self` when reducing in place, which must
            # not be allowed on a read-only (borrowed) tensor.
            mutated: List["Tensor"] = []
            if out is not None:
                mutated.append(out)
            if _inplace:
                mutated.append(self)
            checker.check_permission(mutated)

        if _inplace:
            if out is not None:
                raise ValueError("`out` cannot be used with inplace reduce.")
            # In-place only works when the result keeps the input's shape:
            # "stack"+scatter (all-to-all) and non-"stack" without scatter
            # (all-reduce). "stack" without scatter adds an outer dimension and
            # non-"stack" with scatter removes one (see `_fake_reduce`).
            inplace_valid = (reduction == "stack") == bool(scatter)
            if not inplace_valid:
                raise ValueError(
                    f"reduction {reduction} is not valid for in-place operation because "
                    "the output size will not match the input size."
                )
            fake_output = self._fake
        else:
            # Group size along the reduced dim; only needed when it affects the shape.
            N = (
                self.mesh.processes.sizes[self.mesh.names.index(dims[0])]
                if reduction == "stack" or scatter
                else -1
            )

            fake_output = fake_call(
                _fake_reduce, self._fake, self.mesh, N, reduction, scatter
            )
            if out is not None:
                if out.shape != fake_output.shape:
                    raise ValueError(
                        f"The given output shape, {out.shape}, is incorrect. "
                        f"Reduce expects the shape to be {fake_output.shape}."
                    )
                fake_output = out._fake

        r = Tensor(fake_output, self.mesh, self.stream)
        assert r.ref is not None
        self.mesh.define_remotely()
        defines = (r,) if out is None else (r, out)
        self.mesh.client.new_node(defines, (self,))
        self.mesh.client.backend_network_init()
        self.mesh.client.split_comm(dims, self.mesh, self.stream._to_ref(mesh.client))
        self.mesh._send(
            messages.Reduce(
                r,
                self,
                self._factory(),
                self.mesh,
                self.stream._to_ref(mesh.client),
                dims,
                reduction,
                scatter,
                _inplace,
                out,
            )
        )
        return r

    def slice_mesh(self, **kwargs: Union[int, slice]) -> "MeshSliceTensor":
        """Return a view of this tensor restricted to ``self.mesh.slice(**kwargs)``."""
        # technically a slice of a device mesh and a device mesh are not same thing
        # because a device mesh also has caches for doing collectives.
        # but this is an easy way to create a MeshSliceTensor until we optimize
        # how we represent mesh slices.
        slicing = self.mesh.slice(**kwargs)
        return MeshSliceTensor(self, slicing)

    def delete_ref(self, ref: int):
        """Unregister from the alias set and ask the client to delete ``ref``.

        Does nothing once the client has shut down.
        """
        mesh = self.mesh
        if not mesh.client.has_shutdown:
            self._aliases.unregister(self)
            mesh.client.delete_ref(mesh, ref)

    def _factory(self):
        """Describe this tensor's metadata as a ``messages.TensorFactory``."""
        return messages.TensorFactory.from_tensor(self._fake)
|
379
|
+
|
380
|
+
|
381
|
+
class MeshSliceTensor:
    """A ``Tensor`` paired with the mesh slice its data is read from.

    Created by ``Tensor.slice_mesh`` (or by ``Tensor.to_mesh`` with the full
    mesh). Its only operation is ``to_mesh``, which sends the data to another
    mesh.
    """

    def __init__(self, tensor: "Tensor", slicing: "DeviceMesh"):
        self.tensor = tensor
        self.slicing = slicing

    def to_mesh(
        self,
        mesh: Union["DeviceMesh", "HasDeviceMesh"],
        stream: Optional["Stream"] = None,
    ) -> "Tensor":
        """Copy the sliced data onto ``mesh`` (and ``stream``), returning a new Tensor.

        Every named dimension of the slice must exist in ``mesh`` with the same
        size. Dimensions of ``mesh`` absent from the slice are broadcast to,
        with a warning, by sending once per coordinate of those dimensions.

        Args:
            mesh: destination mesh, or any object exposing ``_device_mesh``.
            stream: destination stream; defaults to the source tensor's stream.

        Raises:
            TypeError: if the source tensor fails InputChecker's checks.
            ValueError: if a shared dimension differs in size, or ``mesh`` lacks
                a dimension of the slice.
        """
        if isinstance(mesh, HasDeviceMesh):
            mesh = mesh._device_mesh

        if stream is None:
            stream = self.tensor.stream

        with InputChecker(
            [self.tensor], lambda ts: f"{next(ts)}.to_mesh({mesh})"
        ) as checker:
            checker.check_no_requires_grad()
            checker.check_cuda()
            checker.check_permission(mutated_tensors=())

        # Match each destination dimension against the source slice. Build the
        # source rank set from the slice's strides, and collect dimensions that
        # must be broadcast.
        sizes = []
        strides = []
        broadcast_dims = []
        for name, size in zip(mesh.names, mesh.processes.sizes):
            if name not in self.slicing.names:
                broadcast_dims.append(name)
                warnings.warn(
                    f"to_mesh is broadcasting along {name} dimension. "
                    "This is implemented inefficiently and should only be used for initialization before it is fixed.",
                    stacklevel=2,
                )
                continue
            index = self.slicing.names.index(name)
            if self.slicing.processes.sizes[index] != size:
                raise ValueError(
                    f"dimension {name} of destination device_mesh has a different length than the source tensor"
                )
            sizes.append(size)
            strides.append(self.slicing.processes.strides[index])

        if len(sizes) != len(self.slicing.names):
            missing = set(self.slicing.names) - set(mesh.names)
            raise ValueError(f"destination mesh does not have dimensions {missing}")

        # Optimized algorithm where:
        # 1. We can represent submeshes as NDSlice(offset, sizes, strides) on rank.
        # 2. A message can be efficiently broadcast to List[NDSlice] ranks by a smart tree based algorithm that can
        #    figure out which subtrees need the message.
        # 3. The message itself will use List[NDSlice] objects to express the send/recv set and so it is very small
        #
        # so basically both the way the message is broadcast and its size will be compressed but the
        # send pattern and the meaning of the message will be the same as this inefficient form

        from_ranks = NDSlice(
            offset=self.slicing.processes.offset, sizes=sizes, strides=strides
        )
        r = Tensor(fake_call(self.tensor._fake.clone), mesh, stream)
        assert r.ref is not None
        client = self.tensor.mesh.client
        from_stream_ref = self.tensor.stream._to_ref(client)
        to_stream_ref = stream._to_ref(client)
        client.backend_network_init()
        client.backend_network_point_to_point_init(from_stream_ref, to_stream_ref)
        client.new_node((r,), (self.tensor,))

        if broadcast_dims:
            # One destination slice per combination of coordinates along the
            # broadcast dimensions.
            mesh_sizes = mesh.sizes
            dim_sequences = [
                zip(itertools.repeat(dim), range(mesh_sizes[dim]))
                for dim in broadcast_dims
            ]
            destinations = [
                mesh.slice(**dict(dim_settings)).processes
                for dim_settings in itertools.product(*dim_sequences)
            ]
        else:
            destinations = [mesh.processes]

        for to_ranks in destinations:
            client.send(
                [from_ranks, to_ranks],
                messages.SendTensor(
                    r,
                    from_ranks,
                    to_ranks,
                    self.tensor,
                    self.tensor._factory(),
                    from_stream_ref,
                    to_stream_ref,
                ),
            )
        return r
|
476
|
+
|
477
|
+
|
478
|
+
def _fake_reduce(
|
479
|
+
tensor, source_mesh: "DeviceMesh", group_size: int, reduction, scatter: bool
|
480
|
+
):
|
481
|
+
if scatter:
|
482
|
+
if tensor.ndim == 0 or tensor.size(0) != group_size:
|
483
|
+
raise TypeError(
|
484
|
+
f"When scattering results the outer most dimension of tensor with sizes ({list(tensor.size())} must match the size ({group_size})"
|
485
|
+
)
|
486
|
+
if reduction == "stack":
|
487
|
+
# scatter removes a dimension of mesh size
|
488
|
+
# but gather adds the dimension back
|
489
|
+
return tensor
|
490
|
+
return tensor.sum(dim=0)
|
491
|
+
else:
|
492
|
+
if reduction == "stack":
|
493
|
+
return torch.empty(
|
494
|
+
[group_size, *tensor.shape],
|
495
|
+
dtype=tensor.dtype,
|
496
|
+
device=tensor.device,
|
497
|
+
layout=tensor.layout,
|
498
|
+
)
|
499
|
+
return tensor.add(tensor)
|
500
|
+
|
501
|
+
|
502
|
+
_explain = """\
|
503
|
+
LOCAL_TENSOR
|
504
|
+
This tensor is a local (non-distributed) tensor being used while a device_mesh is active.
|
505
|
+
If you want to do local tensor compute use `with no_mesh.activate():`
|
506
|
+
|
507
|
+
WRONG_MESH
|
508
|
+
This tensor is on a device mesh that is not the current device_mesh.
|
509
|
+
Use `with m.activate():` to switch the active mesh, or move the tensor to the correct device mesh with `to_mesh`/`on_mesh`.
|
510
|
+
|
511
|
+
WRONG_STREAM
|
512
|
+
This tensor is on a stream that is not the current active stream. Use with `stream.activate()` to switch streams, or
|
513
|
+
move the tensor to the correct stream with `.borrow`.
|
514
|
+
|
515
|
+
DROPPED
|
516
|
+
This tensor, or a view of it, was explicitly deleted with the t.drop() function and is no longer usable.
|
517
|
+
|
518
|
+
BORROWED
|
519
|
+
This tensor cannot be read because it is being used mutably in another stream.
|
520
|
+
|
521
|
+
MUTATING_BORROW
|
522
|
+
This tensor would be mutated by this operator but it is read only because it is being borrowed.
|
523
|
+
|
524
|
+
REQUIRES_GRAD
|
525
|
+
This tensor requires gradients but this operation does not work with autograd.
|
526
|
+
|
527
|
+
CROSS_DEVICE_REQUIRES_CUDA
|
528
|
+
Operations that send tensors across devices currently require CUDA tensors.
|
529
|
+
"""
|
530
|
+
|
531
|
+
explain = {}
|
532
|
+
for entry in _explain.split("\n\n"):
|
533
|
+
lines = entry.split("\n")
|
534
|
+
explain[lines[0]] = "".join(f" {l}\n" for l in lines)
|
535
|
+
|
536
|
+
|
537
|
+
def handle_lift_fresh_dispatch(
    propagate, rfunction, args, kwargs, ambient_mesh, stream
):
    """Special-case handler for ``aten.lift_fresh``.

    Requires an ambient mesh. Builds, via ``fake_call``, a zero tensor with the
    shape, device and dtype of ``args[0]``. Returns
    ``(fake_result, (), (), ambient_mesh)``; the two empty tuples are
    interpreted by the caller, which is not visible here. ``propagate``,
    ``rfunction``, ``kwargs`` and ``stream`` are unused by this handler.
    """
    assert ambient_mesh is not None
    source = args[0]
    fake_result = fake_call(
        torch.zeros, source.shape, device=source.device, dtype=source.dtype
    )
    empty = ()
    return fake_result, empty, empty, ambient_mesh


# Ops (keyed by their string name) that get a dedicated handler.
special_ops_handler = {
    "torch.ops.aten.lift_fresh.default": handle_lift_fresh_dispatch,
}
|
548
|
+
|
549
|
+
|
550
|
+
class _Symbol(NamedTuple):
|
551
|
+
name: str
|
552
|
+
|
553
|
+
def __repr__(self):
|
554
|
+
return self.name
|
555
|
+
|
556
|
+
|
557
|
+
class InputChecker:
|
558
|
+
@staticmethod
|
559
|
+
def from_flat_args(func: Any, tensors: Sequence[torch.Tensor], unflatten: Callable):
|
560
|
+
def format(tensor_values: Iterable[str]):
|
561
|
+
args, kwargs = unflatten(tensor_values)
|
562
|
+
actuals = ", ".join(
|
563
|
+
itertools.chain(
|
564
|
+
map(repr, args),
|
565
|
+
(f"{key}={repr(value)}" for key, value in kwargs.items()),
|
566
|
+
)
|
567
|
+
)
|
568
|
+
return f"{func}({actuals})"
|
569
|
+
|
570
|
+
return InputChecker(tensors, format)
|
571
|
+
|
572
|
+
def __init__(
|
573
|
+
self, tensors: Sequence[torch.Tensor], format: Callable[[Iterable[Any]], str]
|
574
|
+
):
|
575
|
+
self.tensors = tensors
|
576
|
+
self.format = format
|
577
|
+
self.errors: Dict[torch.Tensor, List[str]] = defaultdict(list)
|
578
|
+
self.overall_errors = []
|
579
|
+
# we set this here just so we have stream to report as the current
|
580
|
+
# stream in errors where the stream does not matter.
|
581
|
+
# If the stream matters for this call, we
|
582
|
+
# get the right stream in `check_stream`.
|
583
|
+
self.stream = stream._active
|
584
|
+
self._mesh = None
|
585
|
+
|
586
|
+
def check_mesh_stream_local(
|
587
|
+
self, ambient_mesh: Optional["DeviceMesh"], stream: "Stream"
|
588
|
+
):
|
589
|
+
self.stream = stream
|
590
|
+
for t in self.tensors:
|
591
|
+
if isinstance(t, Tensor):
|
592
|
+
self._mesh = t.mesh
|
593
|
+
break
|
594
|
+
if self._mesh is None:
|
595
|
+
self._mesh = ambient_mesh
|
596
|
+
if self._mesh is None:
|
597
|
+
self.overall_errors.append(
|
598
|
+
"Remote functions require an active device mesh, use `with mesh.activate():`"
|
599
|
+
)
|
600
|
+
|
601
|
+
for t in self.tensors:
|
602
|
+
if isinstance(t, Tensor):
|
603
|
+
if t.mesh is not self._mesh:
|
604
|
+
self.errors[t].append(explain["WRONG_MESH"])
|
605
|
+
if t.stream is not self.stream:
|
606
|
+
self.errors[t].append(explain["WRONG_STREAM"])
|
607
|
+
else:
|
608
|
+
self.errors[t].append(explain["LOCAL_TENSOR"])
|
609
|
+
|
610
|
+
@property
|
611
|
+
def mesh(self) -> "DeviceMesh":
|
612
|
+
assert self._mesh is not None
|
613
|
+
return self._mesh
|
614
|
+
|
615
|
+
def raise_current_errors(self):
|
616
|
+
if not self.errors and not self.overall_errors:
|
617
|
+
return
|
618
|
+
error_info: List[str] = [
|
619
|
+
f"active_mesh = {self._mesh}\n",
|
620
|
+
f"active_stream = {self.stream}\n",
|
621
|
+
*self.overall_errors,
|
622
|
+
]
|
623
|
+
error_names: Dict["Tensor", "str"] = {}
|
624
|
+
for i, (t, errors) in enumerate(self.errors.items()):
|
625
|
+
name = f"ERROR_{i}"
|
626
|
+
error_names[t] = name
|
627
|
+
error_info.append(f"{name}:\n")
|
628
|
+
error_info.extend(errors)
|
629
|
+
|
630
|
+
call = self.format(_Symbol(error_names.get(t, ".")) for t in self.tensors)
|
631
|
+
msg = f"Incorrect arguments to monarch operation:\n\n {call}\n\n{''.join(error_info)}"
|
632
|
+
raise TypeError(msg)
|
633
|
+
|
634
|
+
def _borrow_tracebacks(self, t: Tensor):
|
635
|
+
lines = []
|
636
|
+
for b in t._aliases.live_borrows:
|
637
|
+
lines.append(" Traceback of borrow (most recent frame last):\n")
|
638
|
+
lines.extend(f" {line}\n" for line in b.traceback_string.split("\n"))
|
639
|
+
return lines
|
640
|
+
|
641
|
+
def check_permission(self, mutated_tensors: Sequence["Tensor"]):
|
642
|
+
for t in self.tensors:
|
643
|
+
if not isinstance(t, Tensor):
|
644
|
+
continue
|
645
|
+
if "r" not in t._access_permissions:
|
646
|
+
errors = self.errors[t]
|
647
|
+
errors.append(explain["BORROWED"])
|
648
|
+
errors.extend(self._borrow_tracebacks(t))
|
649
|
+
if t.dropped:
|
650
|
+
self.errors[t].append(explain["DROPPED"])
|
651
|
+
if t._drop_location:
|
652
|
+
self.errors[t].append(str(t._drop_location))
|
653
|
+
|
654
|
+
for t in mutated_tensors:
|
655
|
+
if "w" not in t._access_permissions:
|
656
|
+
errors = self.errors[t]
|
657
|
+
errors.append(explain["MUTATING_BORROW"])
|
658
|
+
errors.extend(self._borrow_tracebacks(t))
|
659
|
+
|
660
|
+
def check_no_requires_grad(self):
    """Flag tensor arguments that require grad while grad mode is enabled.

    When grad mode is off (for example inside ``torch.no_grad()``), nothing
    is recorded.
    """
    if not torch.is_grad_enabled():
        return
    for tensor in self.tensors:
        if tensor.requires_grad:
            self.errors[tensor].append(explain["REQUIRES_GRAD"])
|
664
|
+
|
665
|
+
def check_cuda(self):
    """Record an error for every argument whose ``is_cuda`` is false.

    The recorded explanation is ``CROSS_DEVICE_REQUIRES_CUDA``; presumably
    this check guards cross-device operations (confirm at call sites).
    """
    host_tensors = [tensor for tensor in self.tensors if not tensor.is_cuda]
    for tensor in host_tensors:
        self.errors[tensor].append(explain["CROSS_DEVICE_REQUIRES_CUDA"])
|
669
|
+
|
670
|
+
def __enter__(self) -> "InputChecker":
    """Begin a checking scope; the checker itself is bound by ``as``."""
    checker = self
    return checker
|
672
|
+
|
673
|
+
def __exit__(self, exc_type, exc_value, traceback):
    """End the checking scope.

    If the body raised, nothing else happens and that exception propagates
    (``None`` is returned, so it is not suppressed). Otherwise every error
    accumulated during the scope is raised through ``raise_current_errors``.
    """
    if exc_type is None:
        self.raise_current_errors()
|
677
|
+
|
678
|
+
|
679
|
+
def dtensor_check(
    propagate: "Propagator",
    rfunc: "ResolvableFunction",
    args,
    kwargs,
    ambient_mesh: Optional["DeviceMesh"],
    stream: Stream,
):
    """Validate a call's tensor arguments and propagate it on fake tensors.

    Every ``torch.Tensor`` leaf of ``(args, kwargs)`` is checked by an
    ``InputChecker`` against ``ambient_mesh`` and ``stream``. Those errors are
    raised before propagation, so ``propagate`` only sees usable inputs.
    ``propagate`` is then run on the inputs' fake counterparts (``._fake``).
    Any input whose fake's ``_version`` advanced during propagation counts as
    mutated, and all of its aliases must be writable.

    Returns:
        A 4-tuple ``(result, dtensors, mutates, mesh)``:
          * ``result``: whatever ``propagate`` returned;
          * ``dtensors``: the flattened tensor leaves of ``(args, kwargs)``;
          * ``mutates``: tuple of the aliases of every mutated input, each
            listed once, in first-seen order;
          * ``mesh``: the checker's resolved mesh.

    Raises:
        TypeError: from the checker, if any argument fails validation.
    """
    dtensors, unflatten = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
    with InputChecker.from_flat_args(rfunc, dtensors, unflatten) as checker:
        checker.check_mesh_stream_local(ambient_mesh, stream)

        # ensure tensors are correct enough to do propagation with them.
        checker.raise_current_errors()

        # the distinction is we only check permissions on the first level mutates
        # but have to record error-tracking dependency edges for all parent borrows.

        # future diff will change how we track this and then simplify this code.

        fake_input_tensors = [d._fake for d in dtensors]
        # Snapshot the version counters so writes performed by `propagate`
        # can be detected afterwards.
        before_versions = [f._version for f in fake_input_tensors]
        fake_args, fake_kwargs = unflatten(fake_input_tensors)
        result = propagate(args, kwargs, fake_args, fake_kwargs)

        # Dict keys act as an insertion-ordered set. When several mutated
        # arguments share alias sets (e.g. the same tensor passed twice),
        # each alias is recorded once instead of once per argument. This
        # avoids duplicate permission errors and duplicate mutate entries.
        mutated = {}
        for dtensor, fake, before in zip(dtensors, fake_input_tensors, before_versions):
            if before < fake._version:
                mutated.update(dict.fromkeys(dtensor._aliases.aliases))
        mutates = list(mutated)
        checker.check_permission(mutates)

    return result, dtensors, tuple(mutates), checker.mesh
|
710
|
+
|
711
|
+
|
712
|
+
def dtensor_dispatch(
    rfunction: ResolvableFunction,
    propagate: Propagator,
    args,
    kwargs,
    ambient_mesh: Optional["DeviceMesh"],
    stream: Stream,
):
    """Check, record, and remotely launch one call of ``rfunction``.

    Steps:
      1. Choose a handler: the entry in ``special_ops_handler`` for the
         function's path when ``rfunction`` is a
         ``ResolvableFunctionFromPath`` and has one, else ``dtensor_check``.
      2. Run the handler to get the fake result, input tensors, mutated
         tensors, and target mesh.
      3. Wrap every fake output tensor in a new ``Tensor`` on that mesh and
         stream, and register a node on the mesh's client whose outputs are
         the new tensors plus the mutated ones.
      4. Define the mesh remotely, set up any ``RemoteProcessGroup``
         arguments, and send a ``CallFunction`` message.
      5. Drain incoming messages that are already ready.

    Returns the new output tensors arranged in the structure of the fake
    result.
    """
    from .device_mesh import RemoteProcessGroup

    handler = (
        special_ops_handler.get(rfunction.path, dtensor_check)
        if isinstance(rfunction, ResolvableFunctionFromPath)
        else dtensor_check
    )
    fake_result, dtensors, mutates, device_mesh = handler(
        propagate, rfunction, args, kwargs, ambient_mesh, stream
    )
    assert device_mesh is not None

    fake_outputs, unflatten_result = flatten(
        fake_result, lambda x: isinstance(x, torch.Tensor)
    )
    outputs = tuple(Tensor(fake, device_mesh, stream) for fake in fake_outputs)
    seq = device_mesh.client.new_node(outputs + mutates, dtensors)
    assert all(t.ref is not None for t in outputs)
    assert all(t.ref is not None for t in mutates)

    result = unflatten_result(outputs)
    # A call that produces no tensors sends no result structure.
    result_msg = result if outputs else None

    # The mesh has to be defined remotely regardless, so that remote
    # functions can invoke device_mesh.rank("...").
    device_mesh.define_remotely()

    # If a process group appears anywhere in args/kwargs, the backend network
    # must be initialized (if it hasn't been already) before it is used.
    process_groups, _ = flatten(
        (args, kwargs), lambda x: isinstance(x, RemoteProcessGroup)
    )
    if process_groups:
        device_mesh.client.backend_network_init()
        for group in process_groups:
            assert not group.dropped
            group.ensure_split_comm_remotely(stream._to_ref(device_mesh.client))

    call = messages.CallFunction(
        seq,
        result_msg,
        tuple(mutates),
        rfunction,
        args,
        kwargs,
        stream._to_ref(device_mesh.client),
        device_mesh,
        process_groups,
    )
    device_mesh._send(call)

    # XXX: ideally a non-Python thread would keep our messages up to date.
    # Handling every message that is already ready whenever new work is
    # scheduled approximates that.
    while device_mesh.client.handle_next_message(0):
        pass
    return result
|
777
|
+
|
778
|
+
|
779
|
+
def reduce(
    tensors: T,
    dims: Dims | str,
    reduction: _valid_reduce = "sum",
    scatter: bool = False,
    mesh: Optional["DeviceMesh"] = None,
    _inplace: bool = False,
) -> T:
    """
    Reduce every tensor in a pytree over the given dimensions.

    Each leaf of ``tensors`` is passed to ``Tensor.reduce`` with the remaining
    arguments forwarded unchanged, and the results are returned in the same
    pytree structure.

    Args:
        tensors (pytree["Tensor"]): Pytree whose leaves are the tensors to reduce.
        dims (Dims | str): The dimension(s) along which to reduce.
        reduction (_valid_reduce): The kind of reduction. Defaults to "sum".
        scatter (bool): If True, the local result tensor is evenly split
            across the reduced dimensions. Defaults to False.
        mesh (Optional["DeviceMesh"]): The target mesh to move the data to.
            Forwarded as-is to ``Tensor.reduce``, so ``None`` selects that
            method's default (presumably each tensor's own mesh; confirm).
            Defaults to None.
        _inplace (bool): If True, reduce in place. Not every reduction
            supports in-place operation. Defaults to False.

    Returns:
        T: The reduced tensors, in the same pytree structure as ``tensors``.
    """
    return tree_map(
        lambda tensor: tensor.reduce(dims, reduction, scatter, mesh, _inplace),
        tensors,
    )
|
805
|
+
|
806
|
+
|
807
|
+
def reduce_(
    tensors: T,
    dims: Dims | str,
    reduction: _valid_reduce = "sum",
    scatter: bool = False,
    mesh: Optional["DeviceMesh"] = None,
) -> T:
    """In-place variant of ``reduce``: the same arguments, with ``_inplace=True``.

    Not every reduction supports in-place operation.
    """
    return reduce(
        tensors,
        dims,
        reduction=reduction,
        scatter=scatter,
        mesh=mesh,
        _inplace=True,
    )
|