torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/common/tensor.py (new file)
@@ -0,0 +1,793 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import itertools
9
+ import traceback
10
+ import warnings
11
+ from collections import defaultdict
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ Dict,
16
+ Iterable,
17
+ List,
18
+ Literal,
19
+ NamedTuple,
20
+ Optional,
21
+ Sequence,
22
+ TYPE_CHECKING,
23
+ TypeVar,
24
+ Union,
25
+ )
26
+
27
+ import torch
28
+ import torch._ops
29
+ from monarch.common.function import ResolvableFunctionFromPath
30
+ from torch._subclasses.fake_tensor import FakeTensor
31
+ from torch.utils._pytree import tree_map
32
+
33
+ from . import messages, stream
34
+ from .base_tensor import BaseTensor
35
+ from .borrows import StorageAliases
36
+
37
+ if TYPE_CHECKING:
38
+ from .device_mesh import DeviceMesh
39
+ from .fake import fake_call
40
+ from .function import Propagator, ResolvableFunction
41
+ from .invocation import Invocation
42
+ from .messages import Dims
43
+ from .reference import Referenceable
44
+ from .shape import NDSlice
45
+ from .stream import Stream
46
+ from .tree import flatten
47
+
48
+ _valid_reduce = Literal[
49
+ "stack", "sum", "avg", "product", "min", "max", "band", "bor", "bxor"
50
+ ]
51
+
52
+ T = TypeVar("T")
53
+
54
+
55
+ class DropLocation(NamedTuple):
56
+ tensor_id: int
57
+ traceback: List[traceback.FrameSummary]
58
+
59
+ def __repr__(self) -> str:
60
+ return f"tensor {self.tensor_id} is dropped at: \n" + "".join(
61
+ traceback.format_list(self.traceback)
62
+ )
63
+
64
+
65
+ class Tensor(Referenceable, BaseTensor):
66
+ # pyre-fixme[13]: Attribute `stream` is never initialized.
67
+ stream: Stream
68
+ # pyre-fixme[13]: Attribute `mesh` is never initialized.
69
+ mesh: "DeviceMesh"
70
+ ref: Optional[int]
71
+ # pyre-fixme[13]: Attribute `_invocation` is never initialized.
72
+ _invocation: Optional[Invocation]
73
+ # pyre-fixme[13]: Attribute `_fake` is never initialized.
74
+ _fake: torch.Tensor
75
+ # pyre-fixme[13]: Attribute `_aliases` is never initialized.
76
+ _aliases: StorageAliases
77
+ # pyre-fixme[13]: Attribute `_on_first_use` is never initialized.
78
+ _on_first_use: Optional[Callable]
79
+ # pyre-fixme[13]: Attribute `_drop_location` is never initialized.
80
+ _drop_location: Optional[DropLocation]
81
+ # _seq represents the sequence number of the concrete invocation that
82
+ # created this tensor, or the most recent invocation that mutated it.
83
+ # Unlike the _invocation field, this will be set for both the rust and
84
+ # python backends.
85
+ # pyre-fixme[13]: Attribute `_seq` is never initialized.
86
+ _seq: Optional[int]
87
+
88
+ def __new__(cls, fake: torch.Tensor, mesh: "DeviceMesh", stream: "Stream"):
89
+ # pyre-ignore[16]
90
+ r = torch.Tensor._make_wrapper_subclass(
91
+ cls,
92
+ fake.size(),
93
+ strides=fake.stride(),
94
+ storage_offset=fake.storage_offset(),
95
+ device=fake.device, # This is the device of either the input tensor or the first tensor of a list
96
+ dtype=fake.dtype,
97
+ layout=fake.layout,
98
+ requires_grad=fake.requires_grad,
99
+ )
100
+ assert isinstance(fake, FakeTensor)
101
+ r._fake = fake
102
+ client = mesh.client
103
+ r.ref = client.new_ref()
104
+ r.mesh = mesh
105
+ r.stream = stream
106
+
107
+ storage = fake.untyped_storage()
108
+ client = mesh.client
109
+ if storage not in client.aliases:
110
+ client.aliases[storage] = StorageAliases()
111
+ r._aliases = client.aliases[storage]
112
+ r._aliases.register(r)
113
+ r._invocation = None
114
+ r._on_first_use = None
115
+ r._drop_location = None
116
+ r._seq = None
117
+ return r
118
+
119
+ @classmethod
120
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
121
+ from monarch.common.remote import remote
122
+
123
+ # device_mesh <-> tensor <-> remote are mutually recursive
124
+ # we break the dependency to allow for separate files by
125
+ # having device_mesh and tensor locally import the `remote`
126
+ # entrypoint
127
+ return remote(func, propagate=func)(*args, **kwargs)
128
+
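# A hedged illustration, not part of the package file: with a DeviceMesh active,
# ordinary torch ops on monarch Tensors are re-issued through `remote`, with the
# aten op itself reused as the propagator. `mesh`, `x` and `y` are assumed to
# exist already; how the mesh is constructed is out of scope here.
import torch

with mesh.activate():
    z = x + y          # routed via remote(torch.ops.aten.add.Tensor, propagate=...)(x, y)
    w = torch.relu(z)  # any other torch op on a monarch Tensor takes the same path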
129
+ def __init__(
130
+ self,
131
+ fake: Optional[torch.Tensor] = None,
132
+ mesh: Optional["DeviceMesh"] = None,
133
+ stream: Optional[Stream] = None,
134
+ ):
135
+ pass
136
+
137
+ def __repr__(self, *, tensor_contents=None):
138
+ return f"monarch.Tensor(mesh={self.mesh}, stream={self.stream}, fake={repr(self._fake)})"
139
+
140
+ def drop(self):
141
+ if self.ref is None:
142
+ return
143
+
144
+ for alias in self._aliases.aliases:
145
+ alias._drop_ref()
146
+
147
+ # we should be in the tensors list as well
148
+ assert self.ref is None
149
+
150
+ @property
151
+ def dropped(self):
152
+ return self.ref is None
153
+
154
+ def _drop_ref(self):
155
+ if self.ref is None:
156
+ return
157
+ self.delete_ref(self.ref)
158
+ self._drop_location = DropLocation(self.ref, traceback.extract_stack())
159
+ self.ref = None
160
+
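# Hedged sketch of the drop semantics above (`t` is an assumed monarch Tensor):
# drop() releases the reference of the tensor and of every storage alias, so a
# later attempt to use it in a remote op raises a TypeError that includes the
# DROPPED explanation plus the traceback recorded in DropLocation.
t.drop()
assert t.dropped  # ref is now None; t._drop_location records where this happened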
161
+ @property
162
+ def _access_permissions(self):
163
+ return self._aliases.access
164
+
165
+ def _use(self):
166
+ if self._on_first_use:
167
+ self._on_first_use(self)
168
+ self._on_first_use = None
169
+
170
+ def to_mesh(self, mesh: "DeviceMesh", stream: Optional["Stream"] = None):
171
+ """
172
+ Move data from one device mesh to another. Sizes of named dimensions must match.
173
+ If mesh has dimensions that self.mesh does not, it will broadcast to those dimensions.
174
+
175
+
176
+ broadcast:
177
+ t.slice_mesh(batch=0).to_mesh(t.mesh)
178
+
179
+ """
180
+ return MeshSliceTensor(self, self.mesh).to_mesh(mesh, stream)
181
+
182
+ def reduce_(
183
+ self,
184
+ dims: Dims | str,
185
+ reduction: _valid_reduce = "sum",
186
+ scatter=False,
187
+ mesh=None,
188
+ ):
189
+ return self.reduce(dims, reduction, scatter, mesh, _inplace=True)
190
+
191
+ def reduce(
192
+ self,
193
+ dims: Dims | str,
194
+ reduction: _valid_reduce = "sum",
195
+ scatter: bool = False,
196
+ mesh: Optional["DeviceMesh"] = None,
197
+ _inplace: bool = False,
198
+ out: Optional["Tensor"] = None,
199
+ ):
200
+ """
201
+ Perform a reduction operation along dim, and move the data to mesh. If mesh=None, then mesh=self.mesh
202
+ 'stack' (gather) will concatenate the values along dim, and produce a local result tensor with an additional outer dimension of len(dim).
203
+ If scatter=True, the local result tensor will be evenly split across dim.
204
+
205
+ allreduce:
206
+ t.reduce(dims='gpu', reduction='sum')
207
+
208
+ First reduces dim 'gpu' creating a local tensor with the 'gpu' dimension, then because output_mesh=input_mesh, and it still has dim 'gpu',
209
+ we broadcast the result reduced tensor to all members of gpu.
210
+
211
+ reducescatter:
212
+ t.reduce(dims='gpu', reduction='sum', scatter=True)
213
+
214
+ Same as above except that scatter=True introduces a new 'gpu' dimension that is the result of splitting the local tensor across 'gpu'
215
+
216
+ allgather:
217
+ t.reduce(dims='gpu', reduction='stack')
218
+
219
+ First reduces dim 'gpu' creating a bigger local tensor, then because output_mesh=input_mesh, and it still has dim 'gpu',
220
+ broadcasts the result concatenated tensor to all members of gpu.
221
+
222
+ alltoall:
223
+ t.reduce(dims='gpu', reduction='stack', scatter=True)
224
+
225
+
226
+ First reduces dim 'gpu' creating a bigger local tensor, then introduces a new 'gpu' dimension that is the result of splitting this
227
+ (bigger) tensor across 'gpu'. The result is the same dimension as the original tensor, but with each rank sending to all other ranks.
228
+
229
+
230
+ gather (to dim 0):
231
+ t.reduce(dims='gpu', reduction='stack', mesh=device_mesh(gpu=0))
232
+
233
+ First gathers dim 'gpu' and then places it on the first rank. t.mesh.gpu[0] doesn't have a 'gpu' dimension, but this is
234
+ ok because we eliminated the 'gpu' dim via reduction.
235
+
236
+ reduce:
237
+ t.reduce(dims='gpu', reduction='sum', mesh=device_mesh(gpu=0))
238
+
239
+ First reduces dim 'gpu' and then places it on the first rank. t.mesh.gpu[0] doesn't have a 'gpu' dimension, but this is
240
+ ok because we eliminated the 'gpu' dim via reduction.
241
+
242
+
243
+ Args:
244
+ dims (Dims | str): The dimensions along which to perform the reduction.
245
+ reduction (_valid_reduce): The type of reduction to perform. Defaults to "sum".
246
+ scatter (bool): If True, the local result tensor will be evenly split across dimensions.
247
+ Defaults to False.
248
+ mesh (Optional["DeviceMesh"], optional): The target mesh to move the data to.
249
+ If None, uses self.mesh. Defaults to None.
250
+ _inplace (bool): If True, performs the operation in-place. Defaults to False.
251
+ Note that not all the reduction operations support in-place.
252
+ out (Optional["Tensor"]): The output tensor to store the result. If None, a new tensor
253
+ will be created on the stream where the reduce operation executes. Defaults to None.
254
+
255
+ Returns:
256
+ Tensor: The result of the reduction operation.
257
+ """
258
+ if mesh is not None:
259
+ raise NotImplementedError()
260
+ if isinstance(dims, str):
261
+ dims = (dims,)
262
+ for d in dims:
263
+ if d not in self.mesh.names:
264
+ raise KeyError(f"dim {d} not found in {self.mesh}")
265
+ if len(dims) == 0:
266
+ dims = self.mesh.names
267
+ if len(set(dims)) != len(dims):
268
+ raise ValueError(f"reducing the same dimension twice: {dims}")
269
+ if len(dims) > 1:
270
+ if reduction == "stack" or scatter:
271
+ raise ValueError(
272
+ f"reduction {reduction} or scatter = {scatter} is not valid for multiple dimensions"
273
+ )
274
+ if reduction not in _valid_reduce.__args__:
275
+ raise ValueError(
276
+ f"reduction {reduction} not supported, reductions are {_valid_reduce.__args__}"
277
+ )
278
+
279
+ if mesh is None:
280
+ mesh = self.mesh
281
+
282
+ ts: List[torch.Tensor] = [self]
283
+ if out is not None:
284
+ ts.append(out)
285
+ with InputChecker(
286
+ ts,
287
+ lambda ts: (
288
+ f"reduce({next(ts)}, {dims}, reduction={reduction}, out={next(ts, None)})"
289
+ ),
290
+ ) as checker:
291
+ checker.check_no_requires_grad()
292
+ checker.check_cuda()
293
+ checker.check_mesh_stream_local(self.mesh, stream._active)
294
+ checker.check_permission((out,) if out is not None else ())
295
+
296
+ if _inplace:
297
+ if out is not None:
298
+ raise ValueError("`out` cannot be used with inplace reduce.")
299
+ inplace_valid = (reduction == "gather" and scatter) or not scatter
300
+ if not inplace_valid:
301
+ raise ValueError(
302
+ f"reduction {reduction} is not valid for in-place operation because "
303
+ "the output size will not match the input size."
304
+ )
305
+ fake_output = self._fake
306
+ else:
307
+ N = (
308
+ self.mesh.processes.sizes[self.mesh.names.index(dims[0])]
309
+ if reduction == "stack" or scatter
310
+ else -1
311
+ )
312
+
313
+ fake_output = fake_call(
314
+ _fake_reduce, self._fake, self.mesh, N, reduction, scatter
315
+ )
316
+ if out is not None:
317
+ if out.shape != fake_output.shape:
318
+ raise ValueError(
319
+ f"The given output shape, {out.shape}, is incorrect. "
320
+ f"Reduce expects the shape to be {fake_output.shape}."
321
+ )
322
+ fake_output = out._fake
323
+
324
+ r = Tensor(fake_output, self.mesh, self.stream)
325
+ assert r.ref is not None
326
+ self.mesh.define_remotely()
327
+ defines = (r,) if out is None else (r, out)
328
+ self.mesh.client.new_node(defines, (self,))
329
+ self.mesh.client.backend_network_init()
330
+ self.mesh.client.split_comm(dims, self.mesh, self.stream._to_ref(mesh.client))
331
+ self.mesh._send(
332
+ messages.Reduce(
333
+ r,
334
+ self,
335
+ self._factory(),
336
+ self.mesh,
337
+ self.stream._to_ref(mesh.client),
338
+ dims,
339
+ reduction,
340
+ scatter,
341
+ _inplace,
342
+ out,
343
+ )
344
+ )
345
+ return r
346
+
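# Hedged usage sketch mirroring the docstring above (`t` is an assumed monarch
# Tensor on a mesh with a 'gpu' dimension; note that passing mesh= explicitly
# currently raises NotImplementedError in this version):
allreduced = t.reduce(dims="gpu", reduction="sum")                       # allreduce
reduce_scattered = t.reduce(dims="gpu", reduction="sum", scatter=True)   # reducescatter
allgathered = t.reduce(dims="gpu", reduction="stack")                    # allgather
all_to_all = t.reduce(dims="gpu", reduction="stack", scatter=True)       # alltoall
t.reduce_(dims="gpu", reduction="sum")                                   # in-place allreduce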
347
+ def slice_mesh(self, **kwargs: Dict[str, Union[int, slice]]) -> "MeshSliceTensor":
348
+ # technically a slice of a device mesh and a device mesh are not the same thing
349
+ # because a device mesh also has caches for doing collectives.
350
+ # but this is an easy way to create a MeshSliceTensor until we optimize
351
+ # how we represent mesh slices.
352
+ slicing = self.mesh.slice(**kwargs)
353
+ return MeshSliceTensor(self, slicing)
354
+
355
+ def delete_ref(self, ref: int):
356
+ mesh = self.mesh
357
+ if not mesh.client.has_shutdown:
358
+ self._aliases.unregister(self)
359
+ mesh.client.delete_ref(mesh, ref)
360
+
361
+ def _factory(self):
362
+ return messages.TensorFactory.from_tensor(self._fake)
363
+
364
+
365
+ class MeshSliceTensor:
366
+ def __init__(self, tensor: "Tensor", slicing: "DeviceMesh"):
367
+ self.tensor = tensor
368
+ self.slicing = slicing
369
+
370
+ def to_mesh(
371
+ self, mesh: "DeviceMesh", stream: Optional["Stream"] = None
372
+ ) -> "Tensor":
373
+ if stream is None:
374
+ stream = self.tensor.stream
375
+
376
+ with InputChecker(
377
+ [self.tensor], lambda ts: f"{next(ts)}.to_mesh({mesh})"
378
+ ) as checker:
379
+ checker.check_no_requires_grad()
380
+ checker.check_cuda()
381
+ checker.check_permission(mutated_tensors=())
382
+
383
+ sizes = []
384
+ strides = []
385
+ broadcast_dims = []
386
+ for name, size in zip(mesh.names, mesh.processes.sizes):
387
+ if name not in self.slicing.names:
388
+ broadcast_dims.append(name)
389
+ warnings.warn(
390
+ f"to_mesh is broadcasting along {name} dimension."
391
+ "This is implemented inefficiently and should only be used for initialization before it is fixed.",
392
+ stacklevel=2,
393
+ )
394
+ continue
395
+ index = self.slicing.names.index(name)
396
+ if self.slicing.processes.sizes[index] != size:
397
+ raise ValueError(
398
+ f"dimension {name} of destination device_mesh has a different length than the source tensor"
399
+ )
400
+ sizes.append(size)
401
+ strides.append(self.slicing.processes.strides[index])
402
+
403
+ if len(sizes) != len(self.slicing.names):
404
+ missing = set(self.slicing.names) - set(mesh.names)
405
+ raise ValueError(f"destination mesh does not have dimensions {missing}")
406
+
407
+ # Optimized algorithm where:
408
+ # 1. We can represent submeshes as NDSlice(offset, sizes, strides) on rank.
409
+ # 2. A message can be efficiently broadcast to List[NDSlice] ranks by a smart tree based algorithm that can
410
+ # figure out which subtrees need the message.
411
+ # 3. The message itself will use List[NDSlice] objects to express the send/recv set and so it is very small
412
+
413
+ # so basically both the way the message is broadcast and its size will be compressed but the
414
+ # send pattern and the meaning of the message will be the same as this inefficient form
415
+
416
+ from_ranks = NDSlice(
417
+ offset=self.slicing.processes.offset, sizes=sizes, strides=strides
418
+ )
419
+ r = Tensor(fake_call(self.tensor._fake.clone), mesh, stream)
420
+ assert r.ref is not None
421
+ client = self.tensor.mesh.client
422
+ from_stream_ref = self.tensor.stream._to_ref(client)
423
+ to_stream_ref = stream._to_ref(client)
424
+ client.backend_network_init()
425
+ client.backend_network_point_to_point_init(from_stream_ref, to_stream_ref)
426
+ client.new_node((r,), (self.tensor,))
427
+
428
+ if broadcast_dims:
429
+ mesh_sizes = mesh.sizes
430
+ dim_sequences = [
431
+ zip(itertools.repeat(dim), range(mesh_sizes[dim]))
432
+ for dim in broadcast_dims
433
+ ]
434
+ destinations = [
435
+ mesh.slice(**dict(dim_settings)).processes
436
+ for dim_settings in itertools.product(*dim_sequences)
437
+ ]
438
+ else:
439
+ destinations = [mesh.processes]
440
+
441
+ for to_ranks in destinations:
442
+ client.send(
443
+ [from_ranks, to_ranks],
444
+ messages.SendTensor(
445
+ r,
446
+ from_ranks,
447
+ to_ranks,
448
+ self.tensor,
449
+ self.tensor._factory(),
450
+ from_stream_ref,
451
+ to_stream_ref,
452
+ ),
453
+ )
454
+ return r
455
+
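# Hedged usage sketch (the dimension names 'host'/'gpu' and `other_mesh` are
# assumptions, purely illustrative): copy a slice of a tensor's mesh onto
# another mesh whose matching named dimensions have the same sizes.
part = t.slice_mesh(host=0)        # MeshSliceTensor over the host=0 sub-mesh
moved = part.to_mesh(other_mesh)   # SendTensor messages copy the data across
# Dimensions of `other_mesh` that the slice lacks are broadcast to, with the
# inefficiency warning emitted above.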
456
+
457
+ def _fake_reduce(
458
+ tensor, source_mesh: "DeviceMesh", group_size: int, reduction, scatter: bool
459
+ ):
460
+ if scatter:
461
+ if tensor.ndim == 0 or tensor.size(0) != group_size:
462
+ raise TypeError(
463
+ f"When scattering results the outer most dimension of tensor with sizes ({list(tensor.size())} must match the size ({group_size})"
464
+ )
465
+ if reduction == "stack":
466
+ # scatter removes a dimension of mesh size
467
+ # but gather adds the dimension back
468
+ return tensor
469
+ return tensor.sum(dim=0)
470
+ else:
471
+ if reduction == "stack":
472
+ return torch.empty(
473
+ [group_size, *tensor.shape],
474
+ dtype=tensor.dtype,
475
+ device=tensor.device,
476
+ layout=tensor.layout,
477
+ )
478
+ return tensor.add(tensor)
479
+
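# Shape behaviour of _fake_reduce, checked with plain tensors (a sketch that
# mirrors only the size arithmetic above; it is not package code):
import torch

group_size, t = 4, torch.empty(4, 8)
assert t.add(t).shape == (4, 8)                              # e.g. "sum", scatter=False
assert torch.empty(group_size, *t.shape).shape == (4, 4, 8)  # "stack",  scatter=False
assert t.sum(dim=0).shape == (8,)                            # "sum",    scatter=True
assert t.shape == (4, 8)                                     # "stack",  scatter=True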
480
+
481
+ _explain = """\
482
+ LOCAL_TENSOR
483
+ This tensor is a local (non-distributed) tensor being used while a device_mesh is active.
484
+ If you want to do local tensor compute use `with no_mesh.activate():`
485
+
486
+ WRONG_MESH
487
+ This tensor is on a device mesh that is not the current device_mesh.
488
+ Use `with m.activate():` to switch the active mesh, or move the tensor to the correct device mesh with `to_mesh`/`on_mesh`.
489
+
490
+ WRONG_STREAM
491
+ This tensor is on a stream that is not the current active stream. Use `with stream.activate():` to switch streams, or
492
+ move the tensor to the correct stream with `.borrow`.
493
+
494
+ DROPPED
495
+ This tensor, or a view of it, was explicitly deleted with the t.drop() function and is no longer usable.
496
+
497
+ BORROWED
498
+ This tensor cannot be read because it is being used mutably in another stream.
499
+
500
+ MUTATING_BORROW
501
+ This tensor would be mutated by this operator but it is read only because it is being borrowed.
502
+
503
+ REQUIRES_GRAD
504
+ This tensor requires gradients but this operation does not work with autograd.
505
+
506
+ CROSS_DEVICE_REQUIRES_CUDA
507
+ Operations that send tensors across devices currently require CUDA tensors.
508
+ """
509
+
510
+ explain = {}
511
+ for entry in _explain.split("\n\n"):
512
+ lines = entry.split("\n")
513
+ explain[lines[0]] = "".join(f" {l}\n" for l in lines)
514
+
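# Hedged sketch of how these explanations surface (values elided): a bad input
# to a remote op raises a TypeError formatted by InputChecker.raise_current_errors,
# roughly of the form:
#
#   Incorrect arguments to monarch operation:
#
#     aten.add.Tensor(ERROR_0, .)
#
#   active_mesh = ...
#   active_stream = ...
#   ERROR_0:
#     WRONG_MESH
#     This tensor is on a device mesh that is not the current device_mesh.
#     ...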
515
+
516
+ def handle_lift_fresh_dispatch(
517
+ propagate, rfunction, args, kwargs, ambient_mesh, stream
518
+ ):
519
+ assert ambient_mesh is not None
520
+ fake_result = fake_call(
521
+ torch.zeros, args[0].shape, device=args[0].device, dtype=args[0].dtype
522
+ )
523
+ return fake_result, (), (), ambient_mesh
524
+
525
+
526
+ special_ops_handler = {"torch.ops.aten.lift_fresh.default": handle_lift_fresh_dispatch}
527
+
528
+
529
+ class _Symbol(NamedTuple):
530
+ name: str
531
+
532
+ def __repr__(self):
533
+ return self.name
534
+
535
+
536
+ class InputChecker:
537
+ @staticmethod
538
+ def from_flat_args(func: Any, tensors: Sequence[torch.Tensor], unflatten: Callable):
539
+ def format(tensor_values: Iterable[str]):
540
+ args, kwargs = unflatten(tensor_values)
541
+ actuals = ", ".join(
542
+ itertools.chain(
543
+ map(repr, args),
544
+ (f"{key}={repr(value)}" for key, value in kwargs.items()),
545
+ )
546
+ )
547
+ return f"{func}({actuals})"
548
+
549
+ return InputChecker(tensors, format)
550
+
551
+ def __init__(
552
+ self, tensors: Sequence[torch.Tensor], format: Callable[[Iterable[Any]], str]
553
+ ):
554
+ self.tensors = tensors
555
+ self.format = format
556
+ self.errors: Dict[torch.Tensor, List[str]] = defaultdict(list)
557
+ self.overall_errors = []
558
+ # we set this here just so we have stream to report as the current
559
+ # stream in errors where the stream does not matter.
560
+ # If the stream matters for this call, we
561
+ # get the right stream in `check_stream`.
562
+ self.stream = stream._active
563
+ self._mesh = None
564
+
565
+ def check_mesh_stream_local(
566
+ self, ambient_mesh: Optional["DeviceMesh"], stream: "Stream"
567
+ ):
568
+ self.stream = stream
569
+ for t in self.tensors:
570
+ if isinstance(t, Tensor):
571
+ self._mesh = t.mesh
572
+ break
573
+ if self._mesh is None:
574
+ self._mesh = ambient_mesh
575
+ if self._mesh is None:
576
+ self.overall_errors.append(
577
+ "Remote functions require an active device mesh, use `with mesh.activate():`"
578
+ )
579
+
580
+ for t in self.tensors:
581
+ if isinstance(t, Tensor):
582
+ if t.mesh is not self._mesh:
583
+ self.errors[t].append(explain["WRONG_MESH"])
584
+ if t.stream is not self.stream:
585
+ self.errors[t].append(explain["WRONG_STREAM"])
586
+ else:
587
+ self.errors[t].append(explain["LOCAL_TENSOR"])
588
+
589
+ @property
590
+ def mesh(self) -> "DeviceMesh":
591
+ assert self._mesh is not None
592
+ return self._mesh
593
+
594
+ def raise_current_errors(self):
595
+ if not self.errors and not self.overall_errors:
596
+ return
597
+ error_info: List[str] = [
598
+ f"active_mesh = {self._mesh}\n",
599
+ f"active_stream = {self.stream}\n",
600
+ *self.overall_errors,
601
+ ]
602
+ error_names: Dict["Tensor", "str"] = {}
603
+ for i, (t, errors) in enumerate(self.errors.items()):
604
+ name = f"ERROR_{i}"
605
+ error_names[t] = name
606
+ error_info.append(f"{name}:\n")
607
+ error_info.extend(errors)
608
+
609
+ call = self.format(_Symbol(error_names.get(t, ".")) for t in self.tensors)
610
+ msg = f"Incorrect arguments to monarch operation:\n\n {call}\n\n{''.join(error_info)}"
611
+ raise TypeError(msg)
612
+
613
+ def _borrow_tracebacks(self, t: Tensor):
614
+ lines = []
615
+ for b in t._aliases.live_borrows:
616
+ lines.append(" Traceback of borrow (most recent frame last):\n")
617
+ lines.extend(f" {line}\n" for line in b.traceback_string.split("\n"))
618
+ return lines
619
+
620
+ def check_permission(self, mutated_tensors: Sequence["Tensor"]):
621
+ for t in self.tensors:
622
+ if not isinstance(t, Tensor):
623
+ continue
624
+ if "r" not in t._access_permissions:
625
+ errors = self.errors[t]
626
+ errors.append(explain["BORROWED"])
627
+ errors.extend(self._borrow_tracebacks(t))
628
+ if t.dropped:
629
+ self.errors[t].append(explain["DROPPED"])
630
+ if t._drop_location:
631
+ self.errors[t].append(str(t._drop_location))
632
+
633
+ for t in mutated_tensors:
634
+ if "w" not in t._access_permissions:
635
+ errors = self.errors[t]
636
+ errors.append(explain["MUTATING_BORROW"])
637
+ errors.extend(self._borrow_tracebacks(t))
638
+
639
+ def check_no_requires_grad(self):
640
+ for t in self.tensors:
641
+ if torch.is_grad_enabled() and t.requires_grad:
642
+ self.errors[t].append(explain["REQUIRES_GRAD"])
643
+
644
+ def check_cuda(self):
645
+ for t in self.tensors:
646
+ if not t.is_cuda:
647
+ self.errors[t].append(explain["CROSS_DEVICE_REQUIRES_CUDA"])
648
+
649
+ def __enter__(self) -> "InputChecker":
650
+ return self
651
+
652
+ def __exit__(self, exc_type, exc_value, traceback):
653
+ if exc_type is not None:
654
+ return
655
+ self.raise_current_errors()
656
+
657
+
658
+ def dtensor_check(
659
+ propagate: "Propagator",
660
+ rfunc: "ResolvableFunction",
661
+ args,
662
+ kwargs,
663
+ ambient_mesh: Optional["DeviceMesh"],
664
+ stream: Stream,
665
+ ):
666
+ dtensors, unflatten = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
667
+ with InputChecker.from_flat_args(rfunc, dtensors, unflatten) as checker:
668
+ checker.check_mesh_stream_local(ambient_mesh, stream)
669
+
670
+ # ensure tensors are correct enough to do propagation with them.
671
+ checker.raise_current_errors()
672
+
673
+ # the distinction is we only check permissions on the first level mutates
674
+ # but have to record error-tracking dependency edges for all parent borrows.
675
+
676
+ # future diff will change how we track this and then simplify this code.
677
+
678
+ mutates = []
679
+ fake_input_tensors = [d._fake for d in dtensors]
680
+ before_versions = [f._version for f in fake_input_tensors]
681
+ fake_args, fake_kwargs = unflatten(fake_input_tensors)
682
+ result = propagate(args, kwargs, fake_args, fake_kwargs)
683
+ for i in range(len(dtensors)):
684
+ if before_versions[i] < fake_input_tensors[i]._version:
685
+ mutates.extend(dtensors[i]._aliases.aliases)
686
+ checker.check_permission(mutates)
687
+
688
+ return result, dtensors, tuple(mutates), checker.mesh
689
+
690
+
691
+ def dtensor_dispatch(
692
+ rfunction: ResolvableFunction,
693
+ propagate: Propagator,
694
+ args,
695
+ kwargs,
696
+ ambient_mesh: Optional["DeviceMesh"],
697
+ stream: Stream,
698
+ ):
699
+ from .device_mesh import RemoteProcessGroup
700
+
701
+ op_handler = dtensor_check
702
+ if isinstance(rfunction, ResolvableFunctionFromPath):
703
+ op_handler = special_ops_handler.get(rfunction.path, dtensor_check)
704
+
705
+ fake_result, dtensors, mutates, device_mesh = op_handler(
706
+ propagate, rfunction, args, kwargs, ambient_mesh, stream
707
+ )
708
+ assert device_mesh is not None
709
+
710
+ fake_result_dtensors, unflatten_result = flatten(
711
+ fake_result, lambda x: isinstance(x, torch.Tensor)
712
+ )
713
+ result_dtensors = tuple(
714
+ Tensor(fake, device_mesh, stream) for fake in fake_result_dtensors
715
+ )
716
+ seq = device_mesh.client.new_node(result_dtensors + mutates, dtensors)
717
+ assert all(t.ref is not None for t in result_dtensors)
718
+ assert all(t.ref is not None for t in mutates)
719
+ result = result_msg = unflatten_result(result_dtensors)
720
+ if len(result_dtensors) == 0:
721
+ result_msg = None
722
+
723
+ # note the device mesh has to be defined regardless so the remote functions
724
+ # can invoke device_mesh.rank("...")
725
+ device_mesh.define_remotely()
726
+
727
+ # if there's a process group anywhere in the args, kwargs we need to initialize the backend network
728
+ # if it hasn't already been done.
729
+ process_groups, _ = flatten(
730
+ (args, kwargs), lambda x: isinstance(x, RemoteProcessGroup)
731
+ )
732
+ if len(process_groups) > 0:
733
+ device_mesh.client.backend_network_init()
734
+ for pg in process_groups:
735
+ assert not pg.dropped
736
+ pg.ensure_split_comm_remotely(stream._to_ref(device_mesh.client))
737
+
738
+ device_mesh._send(
739
+ messages.CallFunction(
740
+ seq,
741
+ result_msg,
742
+ tuple(mutates),
743
+ rfunction,
744
+ args,
745
+ kwargs,
746
+ stream._to_ref(device_mesh.client),
747
+ device_mesh,
748
+ process_groups,
749
+ )
750
+ )
751
+ # XXX - realistically this would be done on a non-python thread, keeping our messages up to date
752
+ # but we can approximate it by checking for all ready messages whenever we schedule new work
753
+ while device_mesh.client.handle_next_message(0):
754
+ pass
755
+ return result
756
+
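# Descriptive summary of the dispatch pipeline above (comments only, added for
# readability; they restate the code and are not part of the package file):
#   1. dtensor_check (or a special-op handler) validates mesh/stream/permissions
#      and runs the propagator on the fake tensors to get fake results and the
#      set of mutated aliases.
#   2. Each fake result is wrapped in a fresh monarch Tensor on the same mesh
#      and stream, and client.new_node records the dependency edge.
#   3. A CallFunction message is broadcast to the mesh, after initializing the
#      backend network if RemoteProcessGroups are involved.
#   4. Ready responses are drained opportunistically via handle_next_message(0).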
757
+
758
+ def reduce(
759
+ tensors: T,
760
+ dims: Dims | str,
761
+ reduction: _valid_reduce = "sum",
762
+ scatter: bool = False,
763
+ mesh: Optional["DeviceMesh"] = None,
764
+ _inplace: bool = False,
765
+ ) -> T:
766
+ """
767
+ Performs the tensor reduction operation for each tensor in tensors.
768
+ Args:
769
+ tensors (pytree["Tensor"]): The pytree of input tensors to reduce.
770
+ dims (Dims | str): The dimensions along which to perform the reduction.
771
+ reduction (_valid_reduce): The type of reduction to perform. Defaults to "sum".
772
+ scatter (bool): If True, the local result tensor will be evenly split across dimensions.
773
+ Defaults to False.
774
+ mesh (Optional["DeviceMesh"], optional): The target mesh to move the data to.
775
+ If None, uses self.mesh. Defaults to None.
776
+ _inplace (bool): If True, performs the operation in-place. Defaults to False.
777
+ Note that not all the reduction operations support in-place.
778
+ """
779
+
780
+ def _reduce(tensor: "Tensor") -> "Tensor":
781
+ return tensor.reduce(dims, reduction, scatter, mesh, _inplace)
782
+
783
+ return tree_map(_reduce, tensors)
784
+
785
+
786
+ def reduce_(
787
+ tensors: T,
788
+ dims: Dims | str,
789
+ reduction: _valid_reduce = "sum",
790
+ scatter: bool = False,
791
+ mesh: Optional["DeviceMesh"] = None,
792
+ ) -> T:
793
+ return reduce(tensors, dims, reduction, scatter, mesh, _inplace=True)
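# Hedged usage sketch of the pytree-level helpers (`grads` is an assumed dict of
# monarch Tensors on a mesh that has a 'gpu' dimension):
averaged = reduce(grads, dims="gpu", reduction="avg")  # new tensors, same pytree structure
reduce_(grads, dims="gpu", reduction="sum")            # in-place variant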