torchmonarch-nightly 2025.6.4 (cp310-cp310-manylinux2014_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/common/remote.py
@@ -0,0 +1,304 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import functools
+ import logging
+ import warnings
+
+ from logging import Logger
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Generic,
+     Literal,
+     Optional,
+     overload,
+     Protocol,
+     Tuple,
+     TypeVar,
+ )
+
+ import monarch.common.messages as messages
+
+ import torch
+
+ from monarch.common import _coalescing, device_mesh, messages, stream
+
+ from monarch.common.device_mesh import RemoteProcessGroup
+ from monarch.common.fake import fake_call
+
+ from monarch.common.function import (
+     Propagator,
+     resolvable_function,
+     ResolvableFunction,
+     ResolvableFunctionFromPath,
+ )
+ from monarch.common.function_caching import (
+     hashable_tensor_flatten,
+     tensor_placeholder,
+     TensorGroup,
+     TensorPlaceholder,
+ )
+ from monarch.common.future import Future
+ from monarch.common.messages import Dims
+ from monarch.common.tensor import dtensor_check, dtensor_dispatch
+ from monarch.common.tree import flatten, tree_map
+ from torch import autograd, distributed as dist
+ from typing_extensions import ParamSpec
+
+ logger: Logger = logging.getLogger(__name__)
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+ T = TypeVar("T")
+
+ Propagator = Callable | Literal["mocked", "cached", "inspect"] | None
+
+
+ class Remote(Generic[P, R]):
+     def __init__(self, impl: Any, propagator_arg: Propagator):
+         self._remote_impl = impl
+         self._propagator_arg = propagator_arg
+         self._cache: Optional[dict] = None
+
+     @property
+     def _resolvable(self):
+         return resolvable_function(self._remote_impl)
+
+     def _propagate(self, args, kwargs, fake_args, fake_kwargs):
+         if self._propagator_arg is None or self._propagator_arg == "cached":
+             if self._cache is None:
+                 self._cache = {}
+             return _cached_propagation(self._cache, self._resolvable, args, kwargs)
+         elif self._propagator_arg == "inspect":
+             return None
+         elif self._propagator_arg == "mocked":
+             raise NotImplementedError("mocked propagation")
+         else:
+             return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)
+
+     def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
+         if self._propagator_arg is None:
+             return  # no propagator provided, so we just assume no mutations
+         return self._propagate(args, kwargs, fake_args, fake_kwargs)
+
+     def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
+         if not callable(self._propagator_arg):
+             raise ValueError("Must specify explicit callable for pipe")
+         return self._propagate(args, kwargs, fake_args, fake_kwargs)
+
+     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+         return dtensor_dispatch(
+             self._resolvable,
+             self._propagate,
+             args,
+             kwargs,
+             device_mesh._active,
+             stream._active,
+         )
+
+     def call_on_shard_and_fetch(
+         self, *args, shard: Dict[str, int] | None = None, **kwargs
+     ) -> Future[R]:
+         return _call_on_shard_and_fetch(
+             self._resolvable, self._fetch_propagate, *args, shard=shard, **kwargs
+         )
+
+
+ # This can't just be Callable because otherwise we are not
+ # allowed to use type arguments in the return value.
+ class RemoteIfy(Protocol):
+     def __call__(self, function: Callable[P, R]) -> Remote[P, R]: ...
+
+
+ @overload
+ def remote(
+     function: Callable[P, R], *, propagate: Propagator = None
+ ) -> "Remote[P, R]": ...
+
+
+ @overload
+ def remote(
+     function: str, *, propagate: Literal["mocked", "cached", "inspect"] | None = None
+ ) -> "Remote": ...
+
+
+ @overload
+ def remote(function: str, *, propagate: Callable[P, R]) -> Remote[P, R]: ...
+
+
+ @overload
+ def remote(*, propagate: Propagator = None) -> RemoteIfy: ...  # type: ignore
+
+
+ # ignore because otherwise it claims that the actual implementation doesn't
+ # accept the above list of arguments
+
+
+ def remote(function: Any = None, *, propagate: Propagator = None) -> Any:
+     if function is None:
+         return functools.partial(remote, propagate=propagate)
+     return Remote(function, propagate)
+
+
+ def _call_on_shard_and_fetch(
+     rfunction: ResolvableFunction | None,
+     propagator: Any,
+     /,
+     *args: object,
+     shard: dict[str, int] | None = None,
+     **kwargs: object,
+ ) -> Future:
+     """
+     Call `function` at the coordinates `shard` of the current device mesh, and retrieve the result as a Future.
+         function - the remote function to call
+         *args/**kwargs - arguments to the function
+         shard - a dictionary from mesh dimension name to coordinate of the shard
+             If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)
+     """
+     ambient_mesh = device_mesh._active
+
+     if rfunction is None:
+         preprocess_message = None
+         rfunction = ResolvableFunctionFromPath("ident")
+     else:
+         preprocess_message = rfunction
+     _, dtensors, mutates, mesh = dtensor_check(
+         propagator, rfunction, args, kwargs, ambient_mesh, stream._active
+     )
+
+     client = mesh.client
+     if _coalescing.is_active(client):
+         raise NotImplementedError("NYI: fetching results during a coalescing block")
+     fut = Future(client)
+     ident = client.new_node(mutates, dtensors, fut)
+     process = mesh._process(shard)
+     client.send(
+         process,
+         messages.SendValue(
+             ident,
+             None,
+             mutates,
+             preprocess_message,
+             args,
+             kwargs,
+             stream._active._to_ref(client),
+         ),
+     )
+     # we have to ask for status updates
+     # from workers to be sure they have finished
+     # enough work to count this future as finished,
+     # and all potential errors have been reported
+     client._request_status()
+     return fut
+
+
+ @remote
+ def _propagate(
+     function: ResolvableFunction, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+ ):
+     """
+     RF preprocess function
+     """
+     fn = function.resolve()
+
+     # XXX - in addition to the functional properties,
+     # and info about if any of the input tensors got mutated.
+     arg_tensors, _ = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
+     input_group = TensorGroup(arg_tensors)
+     result = fn(*args, **kwargs)
+     result_tensors, unflatten_result = flatten(
+         result, lambda x: isinstance(x, torch.Tensor)
+     )
+
+     output_group = TensorGroup(result_tensors, parent=input_group)
+
+     the_result = unflatten_result([tensor_placeholder for _ in result_tensors])
+     return (
+         the_result,
+         output_group.pattern,
+     )
+
+
+ class DummyProcessGroup(dist.ProcessGroup):
+     def __init__(self, dims: Dims, world_size: int):
+         # pyre-ignore
+         super().__init__(0, world_size)
+         self.dims = dims
+         self.world_size = world_size
+
+     def allreduce(self, tensor, op=dist.ReduceOp.SUM, async_op=False):
+         class DummyWork:
+             def wait(self):
+                 return tensor
+
+         return DummyWork()
+
+     def _allgather_base(self, output_tensor, input_tensor, opts):
+         class DummyWork:
+             def wait(self):
+                 return output_tensor
+
+         return DummyWork()
+
+     def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
+         class DummyWork:
+             def wait(self):
+                 return output_tensor
+
+         return DummyWork()
+
+     def __getstate__(self):
+         return {"dims": self.dims, "world_size": self.world_size}
+
+     def __setstate__(self, state):
+         self.__init__(state["dims"], state["world_size"])
+
+
+ def _mock_pgs(x):
+     if isinstance(x, autograd.function.FunctionCtx):
+         for attr in dir(x):
+             if not attr.startswith("__") and isinstance(attr, RemoteProcessGroup):
+                 setattr(x, attr, DummyProcessGroup(attr.dims, attr.size()))
+         return x
+     if isinstance(x, RemoteProcessGroup):
+         return DummyProcessGroup(x.dims, x.size())
+     return x
+
+
+ # for testing
+ _miss = 0
+ _hit = 0
+
+
+ def _cached_propagation(_cache, rfunction, args, kwargs):
+     tensors, shape_key = hashable_tensor_flatten(args, kwargs)
+     inputs_group = TensorGroup([t._fake for t in tensors])
+     requires_grads = tuple(t.requires_grad for t in tensors)
+     key = (shape_key, inputs_group.pattern, requires_grads)
+
+     global _miss, _hit
+     if key not in _cache:
+         _miss += 1
+         args_no_pg, kwargs_no_pg = tree_map(_mock_pgs, (args, kwargs))
+         result_with_placeholders, output_pattern = _propagate.call_on_shard_and_fetch(
+             function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
+         ).result()
+
+         _, unflatten_result = flatten(
+             result_with_placeholders, lambda x: isinstance(x, TensorPlaceholder)
+         )
+         _cache[key] = (unflatten_result, output_pattern)
+     else:
+         _hit += 1
+     # return fresh fake result every time to avoid spurious aliasing
+     unflatten_result, output_pattern = _cache[key]
+
+     output_tensors = fake_call(output_pattern.empty, [inputs_group.tensors])
+     return unflatten_result(output_tensors)
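
For orientation, the block below is a hedged usage sketch of the `remote` decorator and `call_on_shard_and_fetch` defined in this file; it is not part of the wheel. Mesh construction, `mesh.activate()`, and the "host"/"gpu" dimension names are assumptions, and the propagator shown is an ordinary callable that is run on fake tensors to predict output shapes.

# Hedged usage sketch (not part of the wheel). `mesh` is an already-constructed
# DeviceMesh; mesh.activate() and the "host"/"gpu" dimension names are assumptions.
import torch

from monarch.common.remote import remote

@remote(propagate=lambda x, y: x + y)  # propagator: run on fake tensors to predict output shape/dtype
def remote_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

def example(mesh):
    with mesh.activate():
        x = torch.ones(4)  # created on every rank of the active mesh
        y = torch.ones(4)
        z = remote_add(x, y)  # dispatched through dtensor_dispatch; z is a monarch tensor reference
        # Fetch the value held by one shard as a Future and block on it locally.
        fut = remote_add.call_on_shard_and_fetch(x, y, shard={"host": 0, "gpu": 0})
        return z, fut.result()
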
monarch/common/selection.py
@@ -0,0 +1,9 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from monarch._rust_bindings.monarch_hyperactor.selection import Selection
+
+ __all__ = ["Selection"]
monarch/common/shape.py
@@ -0,0 +1,204 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import operator
+ from abc import ABC, abstractmethod
+
+ from typing import Dict, Generator, Sequence, Tuple
+
+ from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
+
+ from typing_extensions import Self
+
+ NDSlice = Slice
+
+ Slices = Slice | list[Slice]
+
+
+ def iter_ranks(ranks: Slices) -> Generator[int, None, None]:
+     if isinstance(ranks, list):
+         seen = set()
+         for slice_ in ranks:
+             for rank in slice_:
+                 if rank not in seen:
+                     seen.add(rank)
+                     yield rank
+     else:
+         yield from ranks
+
+
+ class MeshTrait(ABC):
+     """
+     Mesh interface. Implemented via Shape.
+     """
+
+     @property
+     @abstractmethod
+     def _ndslice(self) -> NDSlice: ...
+
+     @property
+     @abstractmethod
+     def _labels(self) -> Tuple[str, ...]: ...
+
+     @abstractmethod
+     def _new_with_shape(self, shape: Shape) -> Self: ...
+
+     def slice(self, **kwargs) -> Self:
+         """
+         mesh.slice(batch=3) or mesh.slice(batch=slice(3, None))
+         """
+         ndslice = self._ndslice
+         labels = self._labels
+         offset = ndslice.offset
+         names = []
+         sizes = []
+         strides = []
+         for name, size, stride in zip(labels, ndslice.sizes, ndslice.strides):
+             if name in kwargs:
+                 e = kwargs.pop(name)
+                 if isinstance(e, slice):
+                     start, stop, slice_stride = e.indices(size)
+                     offset += start * stride
+                     names.append(name)
+                     sizes.append((stop - start) // slice_stride)
+                     strides.append(slice_stride * stride)
+                 else:
+                     if e >= size or e < 0:
+                         raise IndexError("index out of range")
+                     offset += e * stride
+             else:
+                 names.append(name)
+                 sizes.append(size)
+                 strides.append(stride)
+
+         if kwargs:
+             raise TypeError(
+                 f"{self} does not have dimension(s) named {tuple(kwargs.keys())}"
+             )
+
+         new_ndslice = NDSlice(offset=offset, sizes=sizes, strides=strides)
+         return self._new_with_shape(Shape(names, new_ndslice))
+
+     def split(self, **kwargs) -> Self:
+         """
+         Returns a new device mesh with some dimensions of this mesh split.
+         For instance, this call splits the host dimension into dp and pp dimensions;
+         the size of 'pp' is specified and the size of 'dp' is derived from it:
+
+             new_mesh = mesh.split(host=('dp', 'pp'), gpu=('tp','cp'), pp=16, cp=2)
+
+         Dimensions not specified will remain unchanged.
+         """
+         splits: Dict[str, Sequence[str]] = {}
+         size_constraints: Dict[str, int] = {}
+         for key, value in kwargs.items():
+             if key in self._labels:
+                 if isinstance(value, str):
+                     raise ValueError(
+                         f"expected a sequence of dimensions, but got '{value}'"
+                     )
+                 splits[key] = value
+             else:
+                 if not isinstance(value, int):
+                     raise ValueError(
+                         f"'{key}' is not an existing dim. Expected an integer size constraint on a new dim."
+                     )
+                 size_constraints[key] = value
+
+         names = []
+         sizes = []
+         strides = []
+         ndslice = self._ndslice
+         for name, size, stride in zip(self._labels, ndslice.sizes, ndslice.strides):
+             to_names = splits.get(name, (name,))
+             total_size = 1
+             unknown_size_name = None
+             for to_name in to_names:
+                 if to_name in size_constraints:
+                     total_size *= size_constraints[to_name]
+                 elif unknown_size_name is None:
+                     unknown_size_name = to_name
+                 else:
+                     raise ValueError(
+                         f"Cannot infer size of {to_names} because both {to_name} and {unknown_size_name} have unknown size. Specify at least one as argument, e.g. {to_name}=4"
+                     )
+             if unknown_size_name is not None:
+                 inferred_size, m = divmod(size, total_size)
+                 if m != 0:
+                     to_sizes = tuple(
+                         (
+                             size_constraints[to_name]
+                             if to_name in size_constraints
+                             else "?"
+                         )
+                         for to_name in to_names
+                     )
+                     raise ValueError(
+                         f"Dimension '{name}' of size {size} is not evenly divided by {to_names!r} with sizes {to_sizes!r}"
+                     )
+                 size_constraints[unknown_size_name] = inferred_size
+             elif total_size != size:
+                 to_sizes = tuple(size_constraints[to_name] for to_name in to_names)
+                 raise ValueError(
+                     f"Dimension '{name}' of size {size} is not evenly divided by {to_names!r} with sizes {to_sizes!r}"
+                 )
+             new_sizes = [size_constraints.pop(to_name) for to_name in to_names]
+             new_strides_reversed = tuple(
+                 itertools.accumulate(reversed(new_sizes), operator.mul, initial=stride)
+             )
+             sizes.extend(new_sizes)
+             strides.extend(reversed(new_strides_reversed[:-1]))
+             for name in to_names:
+                 if name in names:
+                     raise ValueError(f"Duplicate dimension name '{name}'")
+             names.extend(to_names)
+         if size_constraints:
+             raise ValueError(
+                 f"unused size constraints: {tuple(size_constraints.keys())}"
+             )
+         return self._new_with_shape(
+             Shape(names, NDSlice(offset=ndslice.offset, sizes=sizes, strides=strides))
+         )
+
+     def flatten(self, name: str) -> Self:
+         """
+         Returns a new device mesh with all dimensions flattened into a single dimension
+         with the given name.
+
+         Currently this supports only dense meshes: that is, all ranks must be contiguous
+         in the mesh.
+         """
+         ndslice = self._ndslice
+         dense_strides = tuple(
+             itertools.accumulate(reversed(ndslice.sizes), operator.mul, initial=1)
+         )
+         dense_strides, total_size = (
+             list(reversed(dense_strides[:-1])),
+             dense_strides[-1],
+         )
+         if dense_strides != ndslice.strides:
+             raise ValueError(
+                 "cannot flatten sparse mesh: " f"{ndslice.strides=} != {dense_strides=}"
+             )
+
+         return self._new_with_shape(
+             Shape(
+                 [name], NDSlice(offset=ndslice.offset, sizes=[total_size], strides=[1])
+             )
+         )
+
+     def rename(self, **kwargs) -> Self:
+         """
+         Returns a new device mesh with some of its dimensions renamed.
+         Dimensions not mentioned are retained:
+
+             new_mesh = mesh.rename(host='dp', gpu='tp')
+         """
+         return self.split(**{k: (v,) for k, v in kwargs.items()})
+
+
+ __all__ = ["NDSlice", "Shape", "MeshTrait"]
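
The MeshTrait docstrings above describe slice, split, flatten, and rename in terms of an abstract Shape. Below is a hedged, self-contained sketch that exercises them through a toy implementation; ToyMesh and the `.labels`/`.ndslice` accessors on Shape are assumptions used only for illustration, not part of the package.

# Hedged sketch (not part of the wheel): a toy MeshTrait implementation over a
# 2 x 4 ("host" x "gpu") shape, used only to illustrate the methods above.
from monarch.common.shape import MeshTrait, NDSlice, Shape

class ToyMesh(MeshTrait):
    def __init__(self, shape: Shape) -> None:
        self._shape = shape

    @property
    def _ndslice(self) -> NDSlice:
        return self._shape.ndslice  # assumed accessor on Shape

    @property
    def _labels(self):
        return tuple(self._shape.labels)  # assumed accessor on Shape

    def _new_with_shape(self, shape: Shape) -> "ToyMesh":
        return ToyMesh(shape)

mesh = ToyMesh(Shape(["host", "gpu"], NDSlice(offset=0, sizes=[2, 4], strides=[4, 1])))
row = mesh.slice(host=0)                    # drop the host dim, keep all 4 gpus of host 0
tp_cp = mesh.split(gpu=("tp", "cp"), cp=2)  # gpu=4 -> tp=2 x cp=2; tp's size is inferred
flat = mesh.flatten("rank")                 # dense 2x4 mesh -> one "rank" dim of size 8
renamed = mesh.rename(host="dp")            # pure rename, implemented via split()
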
monarch/common/stream.py
@@ -0,0 +1,111 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from typing import Callable, List, Tuple, TYPE_CHECKING
+ from weakref import ref, WeakKeyDictionary
+
+ from . import messages
+ from .borrows import Borrow
+ from .context_manager import activate_first_context_manager
+ from .fake import fake_call
+ from .reference import Referenceable
+
+ if TYPE_CHECKING:
+     from monarch.common.client import Client  # @manual
+
+     from .tensor import Tensor
+
+
+ class Stream:
+     def __init__(self, name: str, _default=False):
+         self.name = name
+         self.default: bool = _default
+         self.clients: WeakKeyDictionary["Client", "StreamRef"] = WeakKeyDictionary()
+
+     def __repr__(self):
+         return f"<Stream({repr(self.name)}) at {hex(id(self))}>"
+
+     def __str__(self):
+         return f"stream {repr(self.name)}"
+
+     def activate(self):
+         return _active_stream(self)
+
+     def _to_ref(self, client: "Client"):
+         if client not in self.clients:
+             self.clients[client] = StreamRef(client, self.name, self.default)
+         return self.clients[client]
+
+     def borrow(self, t: "Tensor", mutable: bool = False) -> Tuple["Tensor", "Borrow"]:
+         """
+         borrowed_tensor, borrow = self.borrow(t)
+
+         Borrows tensor 't' for use on this stream.
+         The memory of t will stay alive until borrow.drop() is called, which will free t and
+         any of its aliases on stream `self` and will cause t.stream to wait on self at that point so
+         that the memory of t can be reused.
+
+         If `mutable`, then self can write to the storage of `t`, but t.stream cannot read or write `t` until
+         the borrow is returned (becomes free and a wait_for has been issued).
+
+         If not `mutable`, both `self` and `t.stream` can read from t's storage but neither can write to it.
+         """
+         client = t.mesh.client
+         aliases = t._aliases
+         r = type(t)(fake_call(t._fake.clone), t.mesh, self)
+         client.new_node((r,), (t,))
+         borrow = r._aliases.borrow_from(client.new_ref(), t.mesh, aliases, mutable)
+         client.new_borrow(borrow)
+         assert r.ref is not None
+         t.mesh._send(
+             messages.BorrowCreate(
+                 r, borrow._id, t, t.stream._to_ref(client), self._to_ref(client)
+             )
+         )
+         r._on_first_use = lambda t: borrow._use()
+
+         return r, borrow
+
+
+ class StreamRef(Referenceable):
+     def __init__(self, client: "Client", name: str, default: bool):
+         self.ref = client.new_ref()
+         self.client = ref(client)
+         self.name = name
+         self.default = default
+         client.send(
+             client.all_ranks,
+             messages.CreateStream(self, self.default),
+         )
+
+     def delete_ref(self, ref):
+         client = self.client()
+         if client is not None and not client._shutdown:
+             client.handle_deletes(client.all_ranks, [ref])
+
+
+ _active = Stream("main", _default=True)
+ _on_change: List[Callable] = []
+
+
+ def get_active_stream():
+     return _active
+
+
+ @activate_first_context_manager
+ def _active_stream(stream: Stream):
+     global _active
+     for on_change in _on_change:
+         on_change(_active, stream)
+
+     _active, old = stream, _active
+     try:
+         yield
+     finally:
+         for on_change in _on_change:
+             on_change(_active, old)
+         _active = old
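
A hedged sketch of how Stream.borrow and Stream.activate from this file might be used to run work on a second stream. The activated mesh and `mesh.activate()` are assumptions not shown in this hunk; `borrow.drop()` is referenced by the borrow() docstring above but defined in monarch/common/borrows.py.

# Hedged usage sketch (not part of the wheel). Assumes an already-constructed
# DeviceMesh `mesh`; mesh.activate() is an assumption from device_mesh.py.
import torch

from monarch.common.stream import Stream

comm = Stream("comm")  # a second, non-default stream

def reduce_on_comm_stream(mesh):
    with mesh.activate():
        t = torch.randn(1024)
        # Read-only alias of `t` usable on the `comm` stream; `t` stays alive
        # until the borrow is dropped (see the borrow() docstring above).
        borrowed, borrow = comm.borrow(t, mutable=False)
        with comm.activate():
            s = borrowed.sum()  # runs on the comm stream
        borrow.drop()  # frees the alias and makes t.stream wait on comm
        return s
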