torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
--- /dev/null
+++ monarch/common/tensor.py
@@ -0,0 +1,814 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import itertools
+ import traceback
+ import typing
+ import warnings
+ from collections import defaultdict
+ from typing import (
+     Any,
+     Callable,
+     cast,
+     Dict,
+     Iterable,
+     List,
+     Literal,
+     NamedTuple,
+     Optional,
+     runtime_checkable,
+     Sequence,
+     TYPE_CHECKING,
+     TypeVar,
+     Union,
+ )
+
+ import torch
+ import torch._ops
+ from monarch.common.function import ResolvableFunctionFromPath
+ from torch._subclasses.fake_tensor import FakeTensor
+ from torch.utils._pytree import tree_map
+
+ from . import messages, stream
+ from .base_tensor import BaseTensor
+ from .borrows import StorageAliases
+
+ if TYPE_CHECKING:
+     from monarch.common.device_mesh import DeviceMesh
+
+ from .fake import fake_call
+ from .function import Propagator, ResolvableFunction
+ from .invocation import Invocation
+ from .messages import Dims
+ from .reference import Referenceable
+ from .shape import NDSlice
+ from .stream import Stream
+ from .tree import flatten
+
+ _valid_reduce = Literal[
+     "stack", "sum", "avg", "product", "min", "max", "band", "bor", "bxor"
+ ]
+
+ T = TypeVar("T")
+
+
+ @runtime_checkable
+ class HasDeviceMesh(typing.Protocol):
+     @property
+     def _device_mesh(self) -> "DeviceMesh": ...
+
+
+ class DropLocation(NamedTuple):
+     tensor_id: int
+     traceback: List[traceback.FrameSummary]
+
+     def __repr__(self) -> str:
+         return f"tensor {self.tensor_id} is dropped at: \n" + "".join(
+             traceback.format_list(self.traceback)
+         )
+
+
+ class Tensor(Referenceable, BaseTensor):
+     # pyre-fixme[13]: Attribute `stream` is never initialized.
+     stream: Stream
+     # pyre-fixme[13]: Attribute `mesh` is never initialized.
+     mesh: "DeviceMesh"
+     ref: Optional[int]
+     # pyre-fixme[13]: Attribute `_invocation` is never initialized.
+     _invocation: Optional[Invocation]
+     # pyre-fixme[13]: Attribute `_fake` is never initialized.
+     _fake: torch.Tensor
+     # pyre-fixme[13]: Attribute `_aliases` is never initialized.
+     _aliases: StorageAliases
+     # pyre-fixme[13]: Attribute `_on_first_use` is never initialized.
+     _on_first_use: Optional[Callable]
+     # pyre-fixme[13]: Attribute `_drop_location` is never initialized.
+     _drop_location: Optional[DropLocation]
+     # _seq represents the sequence number of the concrete invocation that
+     # created this tensor, or the most recent invocation that mutated it.
+     # Unlike the _invocation field, this will be set for both the rust and
+     # python backends.
+     # pyre-fixme[13]: Attribute `_seq` is never initialized.
+     _seq: Optional[int]
+
+     def __new__(cls, fake: torch.Tensor, mesh: "DeviceMesh", stream: "Stream"):
+         # pyre-ignore[16]
+         r = torch.Tensor._make_wrapper_subclass(
+             cls,
+             fake.size(),
+             strides=fake.stride(),
+             storage_offset=fake.storage_offset(),
+             device=fake.device,  # This is the device of either the input tensor or the first tensor of a list
+             dtype=fake.dtype,
+             layout=fake.layout,
+             requires_grad=fake.requires_grad,
+         )
+         assert isinstance(fake, FakeTensor)
+         r._fake = fake
+         client = mesh.client
+         r.ref = client.new_ref()
+         r.mesh = mesh
+         r.stream = stream
+
+         storage = fake.untyped_storage()
+         client = mesh.client
+         if storage not in client.aliases:
+             client.aliases[storage] = StorageAliases()
+         r._aliases = client.aliases[storage]
+         r._aliases.register(r)
+         r._invocation = None
+         r._on_first_use = None
+         r._drop_location = None
+         r._seq = None
+         return r
+
+     @classmethod
+     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+         from monarch.common.remote import remote
+
+         # device_mesh <-> tensor <-> remote are mutually recursive
+         # we break the dependency to allow for separate files by
+         # having device_mesh and tensor locally import the `remote`
+         # entrypoint
+         return remote(func, propagate=func)(*args, **kwargs)
+
+     def __init__(
+         self,
+         fake: Optional[torch.Tensor] = None,
+         mesh: Optional["DeviceMesh"] = None,
+         stream: Optional[Stream] = None,
+     ):
+         pass
+
+     def __repr__(self, *, tensor_contents=None):
+         return f"monarch.Tensor(mesh={self.mesh}, stream={self.stream}, fake={repr(self._fake)})"
+
+     def drop(self):
+         if self.ref is None:
+             return
+
+         for alias in self._aliases.aliases:
+             alias._drop_ref()
+
+         # we should be in the tensors list as well
+         assert self.ref is None
+
+     @property
+     def dropped(self):
+         return self.ref is None
+
+     def _drop_ref(self):
+         if self.ref is None:
+             return
+         self.delete_ref(self.ref)
+         self._drop_location = DropLocation(self.ref, traceback.extract_stack())
+         self.ref = None
+
+     @property
+     def _access_permissions(self):
+         return self._aliases.access
+
+     def _use(self):
+         if self._on_first_use:
+             self._on_first_use(self)
+             self._on_first_use = None
+
+     def to_mesh(
+         self,
+         mesh: Union["DeviceMesh", "HasDeviceMesh"],
+         stream: Optional["Stream"] = None,
+     ):
+         """
+         Move data between one device mesh and another. Sizes of named dimensions must match.
+         If mesh has dimensions that self.mesh does not, it will broadcast to those dimensions.
+
+
+         broadcast:
+             t.slice_mesh(batch=0).to_mesh(t.mesh)
+
+         """
+         if isinstance(mesh, HasDeviceMesh):
+             mesh = mesh._device_mesh
+         return MeshSliceTensor(self, self.mesh).to_mesh(mesh, stream)
+
+     def reduce_(
+         self,
+         dims: Dims | str,
+         reduction: _valid_reduce = "sum",
+         scatter=False,
+         mesh=None,
+     ):
+         return self.reduce(dims, reduction, scatter, mesh, _inplace=True)
+
+     def reduce(
+         self,
+         dims: Dims | str,
+         reduction: _valid_reduce = "sum",
+         scatter: bool = False,
+         mesh: Optional["DeviceMesh"] = None,
+         _inplace: bool = False,
+         out: Optional["Tensor"] = None,
+     ):
+         """
+         Perform a reduction operation along dim, and move the data to mesh. If mesh=None, then mesh=self.mesh.
+         'stack' (gather) will concat the values along dim, and produce a local result tensor with an additional outer dimension of len(dim).
+         If scatter=True, the local result tensor will be evenly split across dim.
+
+         allreduce:
+             t.reduce(dims='gpu', reduction='sum')
+
+             First reduces dim 'gpu' creating a local tensor with the 'gpu' dimension, then because output_mesh=input_mesh, and it still has dim 'gpu',
+             we broadcast the result reduced tensor to all members of gpu.
+
+         reducescatter:
+             t.reduce(dims='gpu', reduction='sum', scatter=True)
+
+             Same as above except that scatter=True introduces a new 'gpu' dimension that is the result of splitting the local tensor across 'gpu'.
+
+         allgather:
+             t.reduce(dims='gpu', reduction='stack')
+
+             First reduces dim 'gpu' creating a bigger local tensor, then because output_mesh=input_mesh, and it still has dim 'gpu',
+             broadcasts the result concatenated tensor to all members of gpu.
+
+         alltoall:
+             t.reduce(dims='gpu', reduction='stack', scatter=True)
+
+
+             First reduces dim 'gpu' creating a bigger local tensor, then introduces a new 'gpu' dimension that is the result of splitting this
+             (bigger) tensor across 'gpu'. The result is the same dimension as the original tensor, but with each rank sending to all other ranks.
+
+
+         gather (to dim 0):
+             t.reduce(dims='gpu', reduction='stack', mesh=device_mesh(gpu=0))
+
+             First gathers dim 'gpu' and then places it on the first rank. t.mesh.gpu[0] doesn't have a 'gpu' dimension, but this is
+             ok because we eliminated the 'gpu' dim via reduction.
+
+         reduce:
+             t.reduce(dims='gpu', reduction='sum', mesh=device_mesh(gpu=0))
+
+             First reduces dim 'gpu' and then places it on the first rank. t.mesh.gpu[0] doesn't have a 'gpu' dimension, but this is
+             ok because we eliminated the 'gpu' dim via reduction.
+
+
+         Args:
+             dims (Dims | str): The dimensions along which to perform the reduction.
+             reduction (_valid_reduce): The type of reduction to perform. Defaults to "sum".
+             scatter (bool): If True, the local result tensor will be evenly split across dimensions.
+                 Defaults to False.
+             mesh (Optional["DeviceMesh"], optional): The target mesh to move the data to.
+                 If None, uses self.mesh. Defaults to None.
+             _inplace (bool): If True, performs the operation in-place. Defaults to False.
+                 Note that not all the reduction operations support in-place.
+             out (Optional["Tensor"]): The output tensor to store the result. If None, a new tensor
+                 will be created on the stream where the reduce operation executes. Defaults to None.
+
+         Returns:
+             Tensor: The result of the reduction operation.
+         """
+         if mesh is not None:
+             raise NotImplementedError()
+         if isinstance(dims, str):
+             dims = (dims,)
+         for d in dims:
+             if d not in self.mesh.names:
+                 raise KeyError(f"dim {d} not found in {self.mesh}")
+         if len(dims) == 0:
+             dims = self.mesh.names
+         if len(set(dims)) != len(dims):
+             raise ValueError(f"reducing the same dimension twice: {dims}")
+         if len(dims) > 1:
+             if reduction == "stack" or scatter:
+                 raise ValueError(
+                     f"reduction {reduction} or scatter = {scatter} is not valid for multiple dimensions"
+                 )
+         if reduction not in _valid_reduce.__args__:
+             raise ValueError(
+                 f"reduction {reduction} not supported, reductions are {_valid_reduce.__args__}"
+             )
+
+         if mesh is None:
+             mesh = self.mesh
+
+         ts: List[torch.Tensor] = [self]
+         if out is not None:
+             ts.append(out)
+         with InputChecker(
+             ts,
+             lambda ts: (
+                 f"reduce({next(ts)}, {dims}, reduction={reduction}, out={next(ts, None)})"
+             ),
+         ) as checker:
+             checker.check_no_requires_grad()
+             checker.check_cuda()
+             checker.check_mesh_stream_local(self.mesh, stream._active)
+             checker.check_permission((out,) if out is not None else ())
+
+         if _inplace:
+             if out is not None:
+                 raise ValueError("`out` cannot be used with inplace reduce.")
+             inplace_valid = (reduction == "gather" and scatter) or not scatter
+             if not inplace_valid:
+                 raise ValueError(
+                     f"reduction {reduction} is not valid for in-place operation because "
+                     "the output size will not match the input size."
+                 )
+             fake_output = self._fake
+         else:
+             N = (
+                 self.mesh.processes.sizes[self.mesh.names.index(dims[0])]
+                 if reduction == "stack" or scatter
+                 else -1
+             )
+
+             fake_output = fake_call(
+                 _fake_reduce, self._fake, self.mesh, N, reduction, scatter
+             )
+             if out is not None:
+                 if out.shape != fake_output.shape:
+                     raise ValueError(
+                         f"The given output shape, {out.shape}, is incorrect. "
+                         f"Reduce expects the shape to be {fake_output.shape}."
+                     )
+                 fake_output = out._fake
+
+         r = Tensor(fake_output, self.mesh, self.stream)
+         assert r.ref is not None
+         self.mesh.define_remotely()
+         defines = (r,) if out is None else (r, out)
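+         # Register the output tensor(s) and this input with the client, initialize the
+         # backend collective network if needed, and split out a communicator for `dims`
+         # before sending the Reduce message to the workers.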
+         self.mesh.client.new_node(defines, (self,))
+         self.mesh.client.backend_network_init()
+         self.mesh.client.split_comm(dims, self.mesh, self.stream._to_ref(mesh.client))
+         self.mesh._send(
+             messages.Reduce(
+                 r,
+                 self,
+                 self._factory(),
+                 self.mesh,
+                 self.stream._to_ref(mesh.client),
+                 dims,
+                 reduction,
+                 scatter,
+                 _inplace,
+                 out,
+             )
+         )
+         return r
+
+     def slice_mesh(self, **kwargs: Union[int, slice]) -> "MeshSliceTensor":
+         # technically a slice of a device mesh and a device mesh are not the same thing
+         # because a device mesh also has caches for doing collectives.
+         # but this is an easy way to create a MeshSliceTensor until we optimize
+         # how we represent mesh slices.
+         slicing = self.mesh.slice(**kwargs)
+         return MeshSliceTensor(self, slicing)
+
+     def delete_ref(self, ref: int):
+         mesh = self.mesh
+         if not mesh.client.has_shutdown:
+             self._aliases.unregister(self)
+         mesh.client.delete_ref(mesh, ref)
+
+     def _factory(self):
+         return messages.TensorFactory.from_tensor(self._fake)
+
+
+ class MeshSliceTensor:
+     def __init__(self, tensor: "Tensor", slicing: "DeviceMesh"):
+         self.tensor = tensor
+         self.slicing = slicing
+
+     def to_mesh(
+         self,
+         mesh: Union["DeviceMesh", "HasDeviceMesh"],
+         stream: Optional["Stream"] = None,
+     ) -> "Tensor":
+         if isinstance(mesh, HasDeviceMesh):
+             mesh = mesh._device_mesh
+
+         if stream is None:
+             stream = self.tensor.stream
+
+         with InputChecker(
+             [self.tensor], lambda ts: f"{next(ts)}.to_mesh({mesh})"
+         ) as checker:
+             checker.check_no_requires_grad()
+             checker.check_cuda()
+             checker.check_permission(mutated_tensors=())
+
+         sizes = []
+         strides = []
+         broadcast_dims = []
+         for name, size in zip(mesh.names, mesh.processes.sizes):
+             if name not in self.slicing.names:
+                 broadcast_dims.append(name)
+                 warnings.warn(
+                     f"to_mesh is broadcasting along {name} dimension. "
+                     "This is implemented inefficiently and should only be used for initialization before it is fixed.",
+                     stacklevel=2,
+                 )
+                 continue
+             index = self.slicing.names.index(name)
+             if self.slicing.processes.sizes[index] != size:
+                 raise ValueError(
+                     f"dimension {name} of destination device_mesh has a different length than the source tensor"
+                 )
+             sizes.append(size)
+             strides.append(self.slicing.processes.strides[index])
+
+         if len(sizes) != len(self.slicing.names):
+             missing = set(self.slicing.names) - set(mesh.names)
+             raise ValueError(f"destination mesh does not have dimensions {missing}")
+
+         # Optimized algorithm where:
+         # 1. We can represent submeshes as NDSlice(offset, sizes, strides) on rank.
+         # 2. A message can be efficiently broadcast to List[NDSlice] ranks by a smart tree based algorithm that can
+         #    figure out which subtrees need the message.
+         # 3. The message itself will use List[NDSlice] objects to express the send/recv set and so it is very small
+
+         # so basically both the way the message is broadcast and its size will be compressed but the
+         # send pattern and the meaning of the message will be the same as this inefficient form
+
+         from_ranks = NDSlice(
+             offset=self.slicing.processes.offset, sizes=sizes, strides=strides
+         )
+         r = Tensor(fake_call(self.tensor._fake.clone), mesh, stream)
+         assert r.ref is not None
+         client = self.tensor.mesh.client
+         from_stream_ref = self.tensor.stream._to_ref(client)
+         to_stream_ref = stream._to_ref(client)
+         client.backend_network_init()
+         client.backend_network_point_to_point_init(from_stream_ref, to_stream_ref)
+         client.new_node((r,), (self.tensor,))
+
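+         # When broadcasting, enumerate one destination process slice per coordinate of the
+         # broadcast dimensions; otherwise a single slice covers the whole destination mesh.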
+         if broadcast_dims:
+             mesh_sizes = mesh.sizes
+             dim_sequences = [
+                 zip(itertools.repeat(dim), range(mesh_sizes[dim]))
+                 for dim in broadcast_dims
+             ]
+             destinations = [
+                 mesh.slice(**dict(dim_settings)).processes
+                 for dim_settings in itertools.product(*dim_sequences)
+             ]
+         else:
+             destinations = [mesh.processes]
+
+         for to_ranks in destinations:
+             client.send(
+                 [from_ranks, to_ranks],
+                 messages.SendTensor(
+                     r,
+                     from_ranks,
+                     to_ranks,
+                     self.tensor,
+                     self.tensor._factory(),
+                     from_stream_ref,
+                     to_stream_ref,
+                 ),
+             )
+         return r
+
+
+ def _fake_reduce(
+     tensor, source_mesh: "DeviceMesh", group_size: int, reduction, scatter: bool
+ ):
+     if scatter:
+         if tensor.ndim == 0 or tensor.size(0) != group_size:
+             raise TypeError(
+                 f"When scattering results, the outermost dimension of the tensor with sizes ({list(tensor.size())}) must match the group size ({group_size})"
+             )
+         if reduction == "stack":
+             # scatter removes a dimension of mesh size
+             # but gather adds the dimension back
+             return tensor
+         return tensor.sum(dim=0)
+     else:
+         if reduction == "stack":
+             return torch.empty(
+                 [group_size, *tensor.shape],
+                 dtype=tensor.dtype,
+                 device=tensor.device,
+                 layout=tensor.layout,
+             )
+         return tensor.add(tensor)
+
+
+ _explain = """\
+ LOCAL_TENSOR
+ This tensor is a local (non-distributed) tensor being used while a device_mesh is active.
+ If you want to do local tensor compute use `with no_mesh.activate():`
+
+ WRONG_MESH
+ This tensor is on a device mesh that is not the current device_mesh.
+ Use `with m.activate():` to switch the active mesh, or move the tensor to the correct device mesh with `to_mesh`/`on_mesh`.
+
+ WRONG_STREAM
+ This tensor is on a stream that is not the current active stream. Use `with stream.activate():` to switch streams, or
+ move the tensor to the correct stream with `.borrow`.
+
+ DROPPED
+ This tensor, or a view of it, was explicitly deleted with the t.drop() function and is no longer usable.
+
+ BORROWED
+ This tensor cannot be read because it is being used mutably in another stream.
+
+ MUTATING_BORROW
+ This tensor would be mutated by this operator but it is read only because it is being borrowed.
+
+ REQUIRES_GRAD
+ This tensor requires gradients but this operation does not work with autograd.
+
+ CROSS_DEVICE_REQUIRES_CUDA
+ Operations that send tensors across devices currently require CUDA tensors.
+ """
+
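+ # Build a lookup from each error code (the first line of every block above) to its
+ # indented explanation text; InputChecker uses these entries when reporting bad arguments.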
+ explain = {}
+ for entry in _explain.split("\n\n"):
+     lines = entry.split("\n")
+     explain[lines[0]] = "".join(f" {l}\n" for l in lines)
+
+
+ def handle_lift_fresh_dispatch(
+     propagate, rfunction, args, kwargs, ambient_mesh, stream
+ ):
+     assert ambient_mesh is not None
+     fake_result = fake_call(
+         torch.zeros, args[0].shape, device=args[0].device, dtype=args[0].dtype
+     )
+     return fake_result, (), (), ambient_mesh
+
+
+ special_ops_handler = {"torch.ops.aten.lift_fresh.default": handle_lift_fresh_dispatch}
+
+
+ class _Symbol(NamedTuple):
+     name: str
+
+     def __repr__(self):
+         return self.name
+
+
+ class InputChecker:
+     @staticmethod
+     def from_flat_args(func: Any, tensors: Sequence[torch.Tensor], unflatten: Callable):
+         def format(tensor_values: Iterable[str]):
+             args, kwargs = unflatten(tensor_values)
+             actuals = ", ".join(
+                 itertools.chain(
+                     map(repr, args),
+                     (f"{key}={repr(value)}" for key, value in kwargs.items()),
+                 )
+             )
+             return f"{func}({actuals})"
+
+         return InputChecker(tensors, format)
+
+     def __init__(
+         self, tensors: Sequence[torch.Tensor], format: Callable[[Iterable[Any]], str]
+     ):
+         self.tensors = tensors
+         self.format = format
+         self.errors: Dict[torch.Tensor, List[str]] = defaultdict(list)
+         self.overall_errors = []
+         # we set this here just so we have a stream to report as the current
+         # stream in errors where the stream does not matter.
+         # If the stream matters for this call, we
+         # get the right stream in `check_mesh_stream_local`.
+         self.stream = stream._active
+         self._mesh = None
+
+     def check_mesh_stream_local(
+         self, ambient_mesh: Optional["DeviceMesh"], stream: "Stream"
+     ):
+         self.stream = stream
+         for t in self.tensors:
+             if isinstance(t, Tensor):
+                 self._mesh = t.mesh
+                 break
+         if self._mesh is None:
+             self._mesh = ambient_mesh
+         if self._mesh is None:
+             self.overall_errors.append(
+                 "Remote functions require an active device mesh, use `with mesh.activate():`"
+             )
+
+         for t in self.tensors:
+             if isinstance(t, Tensor):
+                 if t.mesh is not self._mesh:
+                     self.errors[t].append(explain["WRONG_MESH"])
+                 if t.stream is not self.stream:
+                     self.errors[t].append(explain["WRONG_STREAM"])
+             else:
+                 self.errors[t].append(explain["LOCAL_TENSOR"])
+
+     @property
+     def mesh(self) -> "DeviceMesh":
+         assert self._mesh is not None
+         return self._mesh
+
+     def raise_current_errors(self):
+         if not self.errors and not self.overall_errors:
+             return
+         error_info: List[str] = [
+             f"active_mesh = {self._mesh}\n",
+             f"active_stream = {self.stream}\n",
+             *self.overall_errors,
+         ]
+         error_names: Dict["Tensor", "str"] = {}
+         for i, (t, errors) in enumerate(self.errors.items()):
+             name = f"ERROR_{i}"
+             error_names[t] = name
+             error_info.append(f"{name}:\n")
+             error_info.extend(errors)
+
+         call = self.format(_Symbol(error_names.get(t, ".")) for t in self.tensors)
+         msg = f"Incorrect arguments to monarch operation:\n\n {call}\n\n{''.join(error_info)}"
+         raise TypeError(msg)
+
+     def _borrow_tracebacks(self, t: Tensor):
+         lines = []
+         for b in t._aliases.live_borrows:
+             lines.append(" Traceback of borrow (most recent frame last):\n")
+             lines.extend(f" {line}\n" for line in b.traceback_string.split("\n"))
+         return lines
+
+     def check_permission(self, mutated_tensors: Sequence["Tensor"]):
+         for t in self.tensors:
+             if not isinstance(t, Tensor):
+                 continue
+             if "r" not in t._access_permissions:
+                 errors = self.errors[t]
+                 errors.append(explain["BORROWED"])
+                 errors.extend(self._borrow_tracebacks(t))
+             if t.dropped:
+                 self.errors[t].append(explain["DROPPED"])
+                 if t._drop_location:
+                     self.errors[t].append(str(t._drop_location))
+
+         for t in mutated_tensors:
+             if "w" not in t._access_permissions:
+                 errors = self.errors[t]
+                 errors.append(explain["MUTATING_BORROW"])
+                 errors.extend(self._borrow_tracebacks(t))
+
+     def check_no_requires_grad(self):
+         for t in self.tensors:
+             if torch.is_grad_enabled() and t.requires_grad:
+                 self.errors[t].append(explain["REQUIRES_GRAD"])
+
+     def check_cuda(self):
+         for t in self.tensors:
+             if not t.is_cuda:
+                 self.errors[t].append(explain["CROSS_DEVICE_REQUIRES_CUDA"])
+
+     def __enter__(self) -> "InputChecker":
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         if exc_type is not None:
+             return
+         self.raise_current_errors()
+
+
+ def dtensor_check(
+     propagate: "Propagator",
+     rfunc: "ResolvableFunction",
+     args,
+     kwargs,
+     ambient_mesh: Optional["DeviceMesh"],
+     stream: Stream,
+ ):
+     dtensors, unflatten = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
+     with InputChecker.from_flat_args(rfunc, dtensors, unflatten) as checker:
+         checker.check_mesh_stream_local(ambient_mesh, stream)
+
+         # ensure tensors are correct enough to do propagation with them.
+         checker.raise_current_errors()
+
+         # the distinction is we only check permissions on the first level mutates
+         # but have to record error-tracking dependency edges for all parent borrows.
+
+         # future diff will change how we track this and then simplify this code.
+
+         mutates = []
+         fake_input_tensors = [d._fake for d in dtensors]
+         before_versions = [f._version for f in fake_input_tensors]
+         fake_args, fake_kwargs = unflatten(fake_input_tensors)
+         result = propagate(args, kwargs, fake_args, fake_kwargs)
+         for i in range(len(dtensors)):
+             if before_versions[i] < fake_input_tensors[i]._version:
+                 mutates.extend(dtensors[i]._aliases.aliases)
+         checker.check_permission(mutates)
+
+     return result, dtensors, tuple(mutates), checker.mesh
+
+
+ def dtensor_dispatch(
+     rfunction: ResolvableFunction,
+     propagate: Propagator,
+     args,
+     kwargs,
+     ambient_mesh: Optional["DeviceMesh"],
+     stream: Stream,
+ ):
+     from .device_mesh import RemoteProcessGroup
+
+     op_handler = dtensor_check
+     if isinstance(rfunction, ResolvableFunctionFromPath):
+         op_handler = special_ops_handler.get(rfunction.path, dtensor_check)
+
+     fake_result, dtensors, mutates, device_mesh = op_handler(
+         propagate, rfunction, args, kwargs, ambient_mesh, stream
+     )
+     assert device_mesh is not None
+
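+     # Wrap each fake result in a monarch Tensor on the target mesh and stream, then register
+     # the invocation with the client so the new references and mutations are tracked.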
+     fake_result_dtensors, unflatten_result = flatten(
+         fake_result, lambda x: isinstance(x, torch.Tensor)
+     )
+     result_dtensors = tuple(
+         Tensor(fake, device_mesh, stream) for fake in fake_result_dtensors
+     )
+     seq = device_mesh.client.new_node(result_dtensors + mutates, dtensors)
+     assert all(t.ref is not None for t in result_dtensors)
+     assert all(t.ref is not None for t in mutates)
+     result = result_msg = unflatten_result(result_dtensors)
+     if len(result_dtensors) == 0:
+         result_msg = None
+
+     # note the device mesh has to be defined regardless so the remote functions
+     # can invoke device_mesh.rank("...")
+     device_mesh.define_remotely()
+
+     # if there's a process group anywhere in the args, kwargs we need to initialize the backend network
+     # if it hasn't already been done.
+     process_groups, _ = flatten(
+         (args, kwargs), lambda x: isinstance(x, RemoteProcessGroup)
+     )
+     if len(process_groups) > 0:
+         device_mesh.client.backend_network_init()
+         for pg in process_groups:
+             assert not pg.dropped
+             pg.ensure_split_comm_remotely(stream._to_ref(device_mesh.client))
+
+     device_mesh._send(
+         messages.CallFunction(
+             seq,
+             result_msg,
+             tuple(mutates),
+             rfunction,
+             args,
+             kwargs,
+             stream._to_ref(device_mesh.client),
+             device_mesh,
+             process_groups,
+         )
+     )
+     # XXX - realistically this would be done on a non-python thread, keeping our messages up to date
+     # but we can approximate it by checking for all ready messages whenever we schedule new work
+     while device_mesh.client.handle_next_message(0):
+         pass
+     return result
+
+
+ def reduce(
+     tensors: T,
+     dims: Dims | str,
+     reduction: _valid_reduce = "sum",
+     scatter: bool = False,
+     mesh: Optional["DeviceMesh"] = None,
+     _inplace: bool = False,
+ ) -> T:
+     """
+     Performs the tensor reduction operation for each tensor in tensors.
+     Args:
+         tensors (pytree["Tensor"]): The pytree of input tensors to reduce.
+         dims (Dims | str): The dimensions along which to perform the reduction.
+         reduction (_valid_reduce): The type of reduction to perform. Defaults to "sum".
+         scatter (bool): If True, the local result tensor will be evenly split across dimensions.
+             Defaults to False.
+         mesh (Optional["DeviceMesh"], optional): The target mesh to move the data to.
+             If None, uses self.mesh. Defaults to None.
+         _inplace (bool): If True, performs the operation in-place. Defaults to False.
+             Note that not all the reduction operations support in-place.
+     """
+
+     def _reduce(tensor: "Tensor") -> "Tensor":
+         return tensor.reduce(dims, reduction, scatter, mesh, _inplace)
+
+     return tree_map(_reduce, tensors)
+
+
+ def reduce_(
+     tensors: T,
+     dims: Dims | str,
+     reduction: _valid_reduce = "sum",
+     scatter: bool = False,
+     mesh: Optional["DeviceMesh"] = None,
+ ) -> T:
+     return reduce(tensors, dims, reduction, scatter, mesh, _inplace=True)
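
Usage sketch: a minimal illustration of how the API in monarch/common/tensor.py might be used, based only on the docstrings above. The DeviceMesh named `mesh`, its `activate()` context, and the behavior of tensor factory calls under an active mesh are assumed to be provided elsewhere in the package (for example via monarch.python_local_mesh, not shown in this file), so the setup lines are illustrative rather than authoritative.

import torch

from monarch.common.tensor import reduce

# `mesh` is assumed to be a DeviceMesh with a 'gpu' dimension obtained elsewhere.
with mesh.activate():
    # With a mesh active, tensor creation is expected to yield a distributed monarch Tensor.
    t = torch.rand(4, 4, device="cuda")

    summed = t.reduce(dims="gpu", reduction="sum")                 # allreduce
    shards = t.reduce(dims="gpu", reduction="sum", scatter=True)   # reducescatter
    stacked = t.reduce(dims="gpu", reduction="stack")              # allgather

    # Broadcast one slice of the mesh back onto the full mesh (see Tensor.to_mesh).
    copy = t.slice_mesh(gpu=0).to_mesh(t.mesh)

    # The free function reduce() applies the same operation over a pytree of tensors.
    results = reduce({"a": t, "b": copy}, dims="gpu", reduction="sum")

    # Explicitly release the remote storage when finished.
    t.drop()
    assert t.dropped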