torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/common/remote.py
@@ -0,0 +1,297 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+
+ import functools
+ import logging
+ import warnings
+
+ from logging import Logger
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Generic,
+     Literal,
+     Optional,
+     overload,
+     Protocol,
+     Tuple,
+     TYPE_CHECKING,
+     TypeVar,
+ )
+
+ import monarch.common.messages as messages
+
+ import torch
+
+ from monarch.common import _coalescing, device_mesh, messages, stream
+
+ if TYPE_CHECKING:
+     from monarch.common.client import Client
+
+ from monarch.common.device_mesh import RemoteProcessGroup
+ from monarch.common.fake import fake_call
+
+ from monarch.common.function import (
+     Propagator,
+     resolvable_function,
+     ResolvableFunction,
+     ResolvableFunctionFromPath,
+ )
+ from monarch.common.function_caching import (
+     hashable_tensor_flatten,
+     tensor_placeholder,
+     TensorGroup,
+     TensorPlaceholder,
+ )
+ from monarch.common.future import Future
+ from monarch.common.messages import Dims
+ from monarch.common.tensor import dtensor_check, dtensor_dispatch
+ from monarch.common.tree import flatten, tree_map
+ from torch import autograd, distributed as dist
+ from typing_extensions import ParamSpec
+
+ logger: Logger = logging.getLogger(__name__)
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+ T = TypeVar("T")
+
+ Propagator = Callable | Literal["mocked", "cached", "inspect"] | None
+
+
+ class Remote(Generic[P, R]):
+     def __init__(self, impl: Any, propagator_arg: Propagator):
+         self._remote_impl = impl
+         self._propagator_arg = propagator_arg
+         self._cache: Optional[dict] = None
+
+     @property
+     def _resolvable(self):
+         return resolvable_function(self._remote_impl)
+
+     def _propagate(self, args, kwargs, fake_args, fake_kwargs):
+         if self._propagator_arg is None or self._propagator_arg == "cached":
+             if self._cache is None:
+                 self._cache = {}
+             return _cached_propagation(self._cache, self._resolvable, args, kwargs)
+         elif self._propagator_arg == "inspect":
+             return None
+         elif self._propagator_arg == "mocked":
+             raise NotImplementedError("mocked propagation")
+         else:
+             return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)
+
+     def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
+         if self._propagator_arg is None:
+             return  # no propagator provided, so we just assume no mutations
+         return self._propagate(args, kwargs, fake_args, fake_kwargs)
+
+     def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
+         if not callable(self._propagator_arg):
+             raise ValueError("Must specify explicit callable for pipe")
+         return self._propagate(args, kwargs, fake_args, fake_kwargs)
+
+     def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+         return dtensor_dispatch(
+             self._resolvable,
+             self._propagate,
+             args,
+             kwargs,
+             device_mesh._active,
+             stream._active,
+         )
+
+     def call_on_shard_and_fetch(
+         self, *args, shard: Dict[str, int] | None = None, **kwargs
+     ) -> Future[R]:
+         return _call_on_shard_and_fetch(
+             self._resolvable, self._fetch_propagate, *args, shard=shard, **kwargs
+         )
+
+
+ # This can't just be Callable because otherwise we are not
+ # allowed to use type arguments in the return value.
+ class RemoteIfy(Protocol):
+     def __call__(self, function: Callable[P, R]) -> Remote[P, R]: ...
+
+
+ @overload
+ def remote(
+     function: Callable[P, R], *, propagate: Propagator = None
+ ) -> "Remote[P, R]": ...
+
+
+ @overload
+ def remote(
+     function: str, *, propagate: Literal["mocked", "cached", "inspect"] | None = None
+ ) -> "Remote": ...
+
+
+ @overload
+ def remote(function: str, *, propagate: Callable[P, R]) -> Remote[P, R]: ...
+
+
+ @overload
+ def remote(*, propagate: Propagator = None) -> RemoteIfy: ...  # type: ignore
+
+
+ # ignore because otherwise it claims that the actual implementation doesn't
+ # accept the above list of arguments
+
+
+ def remote(function: Any = None, *, propagate: Propagator = None) -> Any:
+     if function is None:
+         return functools.partial(remote, propagate=propagate)
+     return Remote(function, propagate)
+
+
+ def _call_on_shard_and_fetch(
+     rfunction: ResolvableFunction | None,
+     propagator: Any,
+     /,
+     *args: object,
+     shard: dict[str, int] | None = None,
+     **kwargs: object,
+ ) -> Future:
+     """
+     Call `function` at the coordinates `shard` of the current device mesh, and retrieve the result as a Future.
+     function - the remote function to call
+     *args/**kwargs - arguments to the function
+     shard - a dictionary from mesh dimension name to coordinate of the shard
+         If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)
+     """
+     ambient_mesh = device_mesh._active
+
+     if rfunction is None:
+         preprocess_message = None
+         rfunction = ResolvableFunctionFromPath("ident")
+     else:
+         preprocess_message = rfunction
+     _, dtensors, mutates, mesh = dtensor_check(
+         propagator, rfunction, args, kwargs, ambient_mesh, stream._active
+     )
+
+     client: "Client" = mesh.client
+     if _coalescing.is_active(client):
+         raise NotImplementedError("NYI: fetching results during a coalescing block")
+     return client.fetch(
+         mesh,
+         stream._active._to_ref(client),
+         shard,
+         preprocess_message,
+         args,
+         kwargs,
+         mutates,
+         dtensors,
+     )
+
+
+ @remote
+ def _propagate(
+     function: ResolvableFunction, args: Tuple[Any, ...], kwargs: Dict[str, Any]
+ ):
+     """
+     RF preprocess function
+     """
+     fn = function.resolve()
+
+     # XXX - in addition to the functional properties,
+     # and info about if any of the input tensors got mutated.
+     arg_tensors, _ = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
+     input_group = TensorGroup(arg_tensors)
+     result = fn(*args, **kwargs)
+     result_tensors, unflatten_result = flatten(
+         result, lambda x: isinstance(x, torch.Tensor)
+     )
+
+     output_group = TensorGroup(result_tensors, parent=input_group)
+
+     the_result = unflatten_result([tensor_placeholder for _ in result_tensors])
+     return (
+         the_result,
+         output_group.pattern,
+     )
+
+
+ class DummyProcessGroup(dist.ProcessGroup):
+     def __init__(self, dims: Dims, world_size: int):
+         # pyre-ignore
+         super().__init__(0, world_size)
+         self.dims = dims
+         self.world_size = world_size
+
+     def allreduce(self, tensor, op=dist.ReduceOp.SUM, async_op=False):
+         class DummyWork:
+             def wait(self):
+                 return tensor
+
+         return DummyWork()
+
+     def _allgather_base(self, output_tensor, input_tensor, opts):
+         class DummyWork:
+             def wait(self):
+                 return output_tensor
+
+         return DummyWork()
+
+     def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
+         class DummyWork:
+             def wait(self):
+                 return output_tensor
+
+         return DummyWork()
+
+     def __getstate__(self):
+         return {"dims": self.dims, "world_size": self.world_size}
+
+     def __setstate__(self, state):
+         self.__init__(state["dims"], state["world_size"])
+
+
+ def _mock_pgs(x):
+     if isinstance(x, autograd.function.FunctionCtx):
+         for attr in dir(x):
+             if not attr.startswith("__") and isinstance(attr, RemoteProcessGroup):
+                 setattr(x, attr, DummyProcessGroup(attr.dims, attr.size()))
+         return x
+     if isinstance(x, RemoteProcessGroup):
+         return DummyProcessGroup(x.dims, x.size())
+     return x
+
+
+ # for testing
+ _miss = 0
+ _hit = 0
+
+
+ def _cached_propagation(_cache, rfunction, args, kwargs):
+     tensors, shape_key = hashable_tensor_flatten(args, kwargs)
+     inputs_group = TensorGroup([t._fake for t in tensors])
+     requires_grads = tuple(t.requires_grad for t in tensors)
+     key = (shape_key, inputs_group.pattern, requires_grads)
+
+     global _miss, _hit
+     if key not in _cache:
+         _miss += 1
+         args_no_pg, kwargs_no_pg = tree_map(_mock_pgs, (args, kwargs))
+         result_with_placeholders, output_pattern = _propagate.call_on_shard_and_fetch(
+             function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
+         ).result()
+
+         _, unflatten_result = flatten(
+             result_with_placeholders, lambda x: isinstance(x, TensorPlaceholder)
+         )
+         _cache[key] = (unflatten_result, output_pattern)
+     else:
+         _hit += 1
+     # return fresh fake result every time to avoid spurious aliasing
+     unflatten_result, output_pattern = _cache[key]
+
+     output_tensors = fake_call(output_pattern.empty, [inputs_group.tensors])
+     return unflatten_result(output_tensors)
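
The remote()/Remote API above supports a few definition styles: decorating a Python callable, naming a worker-side function by dotted path, and choosing how result shapes are propagated ("cached", "inspect", "mocked", or an explicit callable). Below is a minimal usage sketch, not taken from the package: the module path my_lib.ops.add, the propagator _add_prop, and the mesh dimension names host/gpu are illustrative assumptions, and an active device mesh and stream are assumed to have been set up elsewhere.

    import torch
    from monarch.common.remote import remote

    # Explicit propagator: describes the output of the remote call on fake
    # tensors without executing the real implementation.
    def _add_prop(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(x)

    # Path-based definition: workers resolve "my_lib.ops.add" themselves
    # (hypothetical module path, shown only to illustrate the str overload).
    remote_add = remote("my_lib.ops.add", propagate=_add_prop)

    # Decorator form with cached propagation: the function is run once on
    # fake tensors per input shape/dtype pattern and the result layout cached.
    @remote(propagate="cached")
    def scale(t: torch.Tensor, factor: float) -> torch.Tensor:
        return t * factor

    # With a device mesh and stream active (set up elsewhere):
    #   out = scale(dt, 2.0)  # dispatched to every rank of the active mesh
    #   local = scale.call_on_shard_and_fetch(
    #       dt, 2.0, shard={"host": 0, "gpu": 0}
    #   ).result()            # fetch the value from a single shard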
monarch/common/selection.py
@@ -0,0 +1,9 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from monarch._rust_bindings.monarch_hyperactor.selection import Selection
+
+ __all__ = ["Selection"]
monarch/common/shape.py
@@ -0,0 +1,229 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import operator
+ from abc import ABC, abstractmethod
+
+ from typing import Dict, Generator, Sequence, Tuple, Union
+
+ from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
+
+ from typing_extensions import Self
+
+ NDSlice = Slice
+
+ Slices = Slice | list[Slice]
+
+
+ def iter_ranks(ranks: Slices) -> Generator[int, None, None]:
+     if isinstance(ranks, list):
+         seen = set()
+         for slice_ in ranks:
+             for rank in slice_:
+                 if rank not in seen:
+                     seen.add(rank)
+                     yield rank
+     else:
+         yield from ranks
+
+
+ class MeshTrait(ABC):
+     """
+     Mesh interface. Implemented via Shape.
+     """
+
+     @property
+     @abstractmethod
+     def _ndslice(self) -> NDSlice: ...
+
+     @property
+     @abstractmethod
+     def _labels(self) -> Tuple[str, ...]: ...
+
+     # mesh trait guarantees that its own calls to _new_with_shape
+     # will only ever select a shape that is a subspace of the
+     # current _ndslice.
+     @abstractmethod
+     def _new_with_shape(self, shape: Shape) -> Self: ...
+
+     def slice(self, **kwargs) -> Self:
+         """
+         mesh.slice(batch=3) or mesh.slice(batch=slice(3, None))
+         """
+         ndslice = self._ndslice
+         labels = self._labels
+         offset = ndslice.offset
+         names = []
+         sizes = []
+         strides = []
+         for name, size, stride in zip(labels, ndslice.sizes, ndslice.strides):
+             if name in kwargs:
+                 e = kwargs.pop(name)
+                 if isinstance(e, slice):
+                     start, stop, slice_stride = e.indices(size)
+                     offset += start * stride
+                     names.append(name)
+                     sizes.append((stop - start) // slice_stride)
+                     strides.append(slice_stride * stride)
+                 else:
+                     if e >= size or e < 0:
+                         raise IndexError("index out of range")
+                     offset += e * stride
+             else:
+                 names.append(name)
+                 sizes.append(size)
+                 strides.append(stride)
+
+         if kwargs:
+             raise TypeError(
+                 f"{self} does not have dimension(s) named {tuple(kwargs.keys())}"
+             )
+
+         new_ndslice = NDSlice(offset=offset, sizes=sizes, strides=strides)
+         return self._new_with_shape(Shape(names, new_ndslice))
+
+     def split(self, **kwargs) -> Self:
+         """
+         Returns a new device mesh with some dimensions of this mesh split.
+         For instance, this call splits the host dimension into dp and pp dimensions.
+         The size of 'pp' is specified and the size of 'dp' is derived from it:
+
+             new_mesh = mesh.split(host=('dp', 'pp'), gpu=('tp','cp'), pp=16, cp=2)
+
+         Dimensions not specified will remain unchanged.
+         """
+         splits: Dict[str, Sequence[str]] = {}
+         size_constraints: Dict[str, int] = {}
+         for key, value in kwargs.items():
+             if key in self._labels:
+                 if isinstance(value, str):
+                     raise ValueError(
+                         f"expected a sequence of dimensions, but got '{value}'"
+                     )
+                 splits[key] = value
+             else:
+                 if not isinstance(value, int):
+                     raise ValueError(
+                         f"'{key}' is not an existing dim. Expected an integer size constraint on a new dim."
+                     )
+                 size_constraints[key] = value
+
+         names = []
+         sizes = []
+         strides = []
+         ndslice = self._ndslice
+         for name, size, stride in zip(self._labels, ndslice.sizes, ndslice.strides):
+             to_names = splits.get(name, (name,))
+             total_size = 1
+             unknown_size_name = None
+             for to_name in to_names:
+                 if to_name in size_constraints:
+                     total_size *= size_constraints[to_name]
+                 elif unknown_size_name is None:
+                     unknown_size_name = to_name
+                 else:
+                     raise ValueError(
+                         f"Cannot infer size of {to_names} because both {to_name} and {unknown_size_name} have unknown size. Specify at least one as argument, e.g. {to_name}=4"
+                     )
+             if unknown_size_name is not None:
+                 inferred_size, m = divmod(size, total_size)
+                 if m != 0:
+                     to_sizes = tuple(
+                         (
+                             size_constraints[to_name]
+                             if to_name in size_constraints
+                             else "?"
+                         )
+                         for to_name in to_names
+                     )
+                     raise ValueError(
+                         f"Dimension '{name}' of size {size} is not evenly divided by {to_names!r} with sizes {to_sizes!r}"
+                     )
+                 size_constraints[unknown_size_name] = inferred_size
+             elif total_size != size:
+                 to_sizes = tuple(size_constraints[to_name] for to_name in to_names)
+                 raise ValueError(
+                     f"Dimension '{name}' of size {size} is not evenly divided by {to_names!r} with sizes {to_sizes!r}"
+                 )
+             new_sizes = [size_constraints.pop(to_name) for to_name in to_names]
+             new_strides_reversed = tuple(
+                 itertools.accumulate(reversed(new_sizes), operator.mul, initial=stride)
+             )
+             sizes.extend(new_sizes)
+             strides.extend(reversed(new_strides_reversed[:-1]))
+             for name in to_names:
+                 if name in names:
+                     raise ValueError(f"Duplicate dimension name '{name}'")
+             names.extend(to_names)
+         if size_constraints:
+             raise ValueError(
+                 f"unused size constraints: {tuple(size_constraints.keys())}"
+             )
+         return self._new_with_shape(
+             Shape(names, NDSlice(offset=ndslice.offset, sizes=sizes, strides=strides))
+         )
+
+     def flatten(self, name: str) -> Self:
+         """
+         Returns a new device mesh with all dimensions flattened into a single dimension
+         with the given name.
+
+         Currently this supports only dense meshes: that is, all ranks must be contiguous
+         in the mesh.
+         """
+         ndslice = self._ndslice
+         dense_strides = tuple(
+             itertools.accumulate(reversed(ndslice.sizes), operator.mul, initial=1)
+         )
+         dense_strides, total_size = (
+             list(reversed(dense_strides[:-1])),
+             dense_strides[-1],
+         )
+         if dense_strides != ndslice.strides:
+             raise ValueError(
+                 "cannot flatten sparse mesh: " f"{ndslice.strides=} != {dense_strides=}"
+             )
+
+         return self._new_with_shape(
+             Shape(
+                 [name], NDSlice(offset=ndslice.offset, sizes=[total_size], strides=[1])
+             )
+         )
+
+     def rename(self, **kwargs) -> Self:
+         """
+         Returns a new device mesh with some of its dimensions renamed.
+         Dimensions not mentioned are retained:
+
+             new_mesh = mesh.rename(host='dp', gpu='tp')
+         """
+         return self.split(**{k: (v,) for k, v in kwargs.items()})
+
+     def size(self, dim: Union[None, str, Sequence[str]] = None) -> int:
+         """
+         Returns the number of elements (total) of the subset of the mesh asked for.
+         If dim is None, returns the total number of devices in the mesh.
+         """
+
+         if dim is None:
+             dim = self._labels
+         if isinstance(dim, str):
+             if dim not in self._labels:
+                 raise KeyError(f"{self} does not have dimension {repr(dim)}")
+             return self._ndslice.sizes[self._labels.index(dim)]
+         else:
+             p = 1
+             for d in dim:
+                 p *= self.size(d)
+             return p
+
+     @property
+     def sizes(self) -> dict[str, int]:
+         return dict(zip(self._labels, self._ndslice.sizes))
+
+
+ __all__ = ["NDSlice", "Shape", "MeshTrait"]
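
To make the shape algebra above concrete, here is a short sketch of how slice(), split(), flatten() and rename() compose. It assumes mesh is some concrete MeshTrait implementation (for example a device mesh) whose dimensions are host=4 and gpu=8; the dimension names and sizes are illustrative only, and each comment states the resulting .sizes as derived from the code above.

    # assumed starting point: mesh.sizes == {"host": 4, "gpu": 8}

    m = mesh.split(host=("dp", "pp"), pp=2)
    # host (4) splits into dp x pp; pp=2 is given, dp=2 is inferred,
    # so m.sizes == {"dp": 2, "pp": 2, "gpu": 8}

    stage = m.slice(pp=1)
    # indexing with an int drops the dimension: {"dp": 2, "gpu": 8}

    half = m.slice(gpu=slice(0, 4))
    # indexing with a slice keeps the dimension: {"dp": 2, "pp": 2, "gpu": 4}

    flat = mesh.flatten("all")
    # dense meshes only: {"all": 32}

    renamed = mesh.rename(host="dp", gpu="tp")
    # sugar for split() with single-name tuples: {"dp": 4, "tp": 8}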
monarch/common/stream.py
@@ -0,0 +1,114 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from typing import Callable, List, Tuple, TYPE_CHECKING
+ from weakref import ref, WeakKeyDictionary
+
+ from . import messages
+ from .borrows import Borrow
+ from .context_manager import activate_first_context_manager
+ from .fake import fake_call
+ from .reference import Referenceable
+
+ if TYPE_CHECKING:
+     from monarch.common.client import Client  # @manual
+
+     from .tensor import Tensor
+
+
+ class Stream:
+     def __init__(self, name: str, _default=False):
+         self.name = name
+         self.default: bool = _default
+         self.clients: WeakKeyDictionary["Client", "StreamRef"] = WeakKeyDictionary()
+
+     def __repr__(self):
+         return f"<Stream({repr(self.name)}) at {hex(id(self))}>"
+
+     def __str__(self):
+         return f"stream {repr(self.name)}"
+
+     def activate(self):
+         return _active_stream(self)
+
+     def _to_ref(self, client: "Client"):
+         if client not in self.clients:
+             self.clients[client] = StreamRef(client, self.name, self.default)
+         return self.clients[client]
+
+     def borrow(self, t: "Tensor", mutable: bool = False) -> Tuple["Tensor", "Borrow"]:
+         """
+         borrowed_tensor, borrow = self.borrow(t)
+
+         Borrows tensor 't' for use on this stream.
+         The memory of t will stay alive until borrow.drop() is called, which will free t and
+         any of its aliases on stream `self` and will cause t.stream to wait on self at that point so
+         that the memory of t can be reused.
+
+         If `mutable` then self can write to the storage of `t`, but t.stream cannot read or write `t` until
+         the borrow is returned (becomes free and a wait_for has been issued).
+
+         If not `mutable` both `self` and `t.stream` can read from t's storage but neither can write to it.
+         """
+         client = t.mesh.client
+         aliases = t._aliases
+         r = type(t)(fake_call(t._fake.clone), t.mesh, self)
+         client.new_node((r,), (t,))
+         borrow = r._aliases.borrow_from(client.new_ref(), t.mesh, aliases, mutable)
+         client.new_borrow(borrow)
+         assert r.ref is not None
+         t.mesh._send(
+             messages.BorrowCreate(
+                 r, borrow._id, t, t.stream._to_ref(client), self._to_ref(client)
+             )
+         )
+         r._on_first_use = lambda t: borrow._use()
+
+         return r, borrow
+
+
+ class StreamRef(Referenceable):
+     def __init__(self, client: "Client", name: str, default: bool):
+         self.ref = client.new_ref()
+         self.client = ref(client)
+         self.name = name
+         self.default = default
+         client.send(
+             client.all_ranks,
+             messages.CreateStream(self, self.default),
+         )
+
+     def __repr__(self):
+         return f"<StreamRef {repr(self.name)} {self.ref}>"
+
+     def delete_ref(self, ref):
+         client = self.client()
+         if client is not None and not client._shutdown:
+             client.handle_deletes(client.all_ranks, [ref])
+
+
+ _active = Stream("main", _default=True)
+ _on_change: List[Callable] = []
+
+
+ def get_active_stream():
+     return _active
+
+
+ @activate_first_context_manager
+ def _active_stream(stream: Stream):
+     global _active
+     for on_change in _on_change:
+         on_change(_active, stream)
+
+     _active, old = stream, _active
+     try:
+         yield
+     finally:
+         for on_change in _on_change:
+             on_change(_active, old)
+         _active = old
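
Finally, a sketch of the borrow protocol documented in Stream.borrow() above. It assumes a monarch Tensor t created under an active device mesh on the default stream elsewhere, and the with comm.activate(): form assumes activate() (which wraps the generator _active_stream via activate_first_context_manager) behaves as a context manager.

    from monarch.common.stream import Stream, get_active_stream

    main = get_active_stream()          # the default "main" stream
    comm = Stream("comm")               # a second stream, e.g. for collectives

    # `t` is assumed to be a monarch Tensor living on `main` (created elsewhere).
    borrowed, borrow = comm.borrow(t, mutable=False)

    with comm.activate():
        # read-only work against `borrowed` is issued on the "comm" stream
        ...

    # drop() frees the borrow and makes t's stream wait on `comm`, so the
    # memory backing `t` can be reused (per the docstring above).
    borrow.drop()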