torchmonarch-nightly 2025.7.1__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.26__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +878 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +303 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +508 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +59 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +53 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +21 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/gradient/_gradient_generator.so +0 -0
  37. monarch/mesh_controller.py +263 -139
  38. monarch/monarch_controller +0 -0
  39. monarch/opaque_module.py +4 -6
  40. monarch/opaque_object.py +3 -3
  41. monarch/proc_mesh.py +6 -309
  42. monarch/python_local_mesh.py +1 -1
  43. monarch/rust_backend_mesh.py +2 -1
  44. monarch/rust_local_mesh.py +4 -2
  45. monarch/sim_mesh.py +10 -19
  46. monarch/simulator/command_history.py +1 -1
  47. monarch/simulator/interface.py +2 -1
  48. monarch/simulator/mock_controller.py +1 -1
  49. monarch/simulator/simulator.py +1 -1
  50. monarch/tensor_engine/__init__.py +23 -0
  51. monarch/tensor_worker_main.py +3 -1
  52. monarch/tools/cli.py +3 -1
  53. monarch/tools/commands.py +129 -47
  54. monarch/tools/components/hyperactor.py +5 -3
  55. monarch/tools/config/__init__.py +18 -1
  56. monarch/tools/config/defaults.py +2 -2
  57. monarch/tools/mesh_spec.py +59 -1
  58. monarch/tools/utils.py +38 -0
  59. monarch/worker/worker.py +1 -1
  60. monarch/world_mesh.py +2 -1
  61. monarch_supervisor/python_executable.py +6 -3
  62. tests/error_test_binary.py +48 -10
  63. tests/test_actor_error.py +370 -21
  64. tests/test_alloc.py +1 -1
  65. tests/test_allocator.py +369 -17
  66. tests/test_controller.py +2 -0
  67. tests/test_debugger.py +416 -0
  68. tests/test_env_before_cuda.py +161 -0
  69. tests/test_python_actors.py +184 -333
  70. tests/test_rdma.py +198 -0
  71. tests/test_remote_functions.py +40 -12
  72. tests/test_rust_backend.py +7 -5
  73. tests/test_sim_backend.py +1 -4
  74. tests/test_tensor_engine.py +81 -1
  75. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/METADATA +39 -1
  76. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/RECORD +84 -72
  77. torchmonarch_nightly-2025.7.26.dist-info/entry_points.txt +3 -0
  78. monarch/_monarch/hyperactor/__init__.py +0 -58
  79. monarch/_monarch/worker/debugger.py +0 -117
  80. monarch/_monarch/worker/logging.py +0 -107
  81. monarch/debugger.py +0 -379
  82. monarch/future.py +0 -76
  83. monarch/rdma.py +0 -162
  84. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  85. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  86. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  87. /monarch/{common → _src/actor}/shape.py +0 -0
  88. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  89. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/WHEEL +0 -0
  90. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/licenses/LICENSE +0 -0
  91. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/top_level.txt +0 -0
monarch/bootstrap_main.py CHANGED
@@ -4,56 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""
-This is the main function for the boostrapping a new process using a ProcessAllocator.
-"""
+import warnings
 
-import asyncio
-import importlib.resources
-import logging
-import os
-import sys
+warnings.warn(
+    "monarch.bootstrap_main is deprecated, please use from monarch._src.actor.bootstrap_main instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)
 
-# Import torch to avoid import-time races if a spawned actor tries to import torch.
-import torch  # noqa[F401]
-
-
-async def main():
-    from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main
-
-    await bootstrap_main()
-
-
-def invoke_main():
-    # if this is invoked with the stdout piped somewhere, then print
-    # changes its buffering behavior. So we default to the standard
-    # behavior of std out as if it were a terminal.
-    sys.stdout.reconfigure(line_buffering=True)
-    global bootstrap_main
-
-    # TODO: figure out what from worker_main.py we should reproduce here.
-    from monarch.telemetry import TracingForwarder
-
-    if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
-        raise RuntimeError("Error during bootstrap for testing")
-
-    # forward logs to rust tracing. Defaults to on.
-    if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
-        logging.root.addHandler(TracingForwarder(level=logging.DEBUG))
-
-    try:
-        with (
-            importlib.resources.path("monarch", "py-spy") as pyspy,
-        ):
-            if pyspy.exists():
-                os.environ["PYSPY_BIN"] = str(pyspy)
-        # fallback to using local py-spy
-    except Exception as e:
-        logging.warning(f"Failed to set up py-spy: {e}")
-
-    # Start an event loop for PythonActors to use.
-    asyncio.run(main())
+from monarch._src.actor.bootstrap_main import *  # noqa
 
 
 if __name__ == "__main__":
+    # noqa
     invoke_main()  # pragma: no cover
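The module body above is now a pure deprecation shim: importing the old path warns once and re-exports the new module. A minimal, illustrative way to confirm the warning fires (assumes the wheel is installed; this snippet is ours, not part of the package):

import importlib
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Executes the shim module body on first import; a second call would
    # hit the module cache and record nothing.
    importlib.import_module("monarch.bootstrap_main")

assert any(issubclass(w.category, DeprecationWarning) for w in caught)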
monarch/common/client.py CHANGED
@@ -37,6 +37,7 @@ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monar
     LogLevel,
     WorldState,
 )
+from monarch._src.actor.shape import NDSlice
 from monarch.common import messages
 from monarch.common.borrows import Borrow, StorageAliases
 from monarch.common.controller_api import LogMessage, MessageResult, TController
@@ -47,7 +48,6 @@ from monarch.common.invocation import DeviceException, RemoteException, Seq
 from monarch.common.recording import flatten_messages, Recording
 
 from monarch.common.reference import Ref, Referenceable
-from monarch.common.shape import NDSlice
 from monarch.common.stream import StreamRef
 from monarch.common.tensor import Tensor
 from monarch.common.tree import tree_map
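The client.py change above is the first of many identical moves in this release: NDSlice and the other shape helpers now live in monarch._src.actor.shape rather than monarch.common.shape. Downstream code that must run against both wheel versions could use a fallback import; this shim is hypothetical, not shipped in the wheel:

try:
    from monarch._src.actor.shape import NDSlice  # 2025.7.26 layout
except ImportError:  # pragma: no cover - pre-2025.7.26 wheels
    from monarch.common.shape import NDSlice  # 2025.7.1 layout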
monarch/common/controller_api.py CHANGED
@@ -13,9 +13,10 @@ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monar
     WorldState,
 )
 
+from monarch._src.actor.shape import NDSlice
+
 from monarch.common.invocation import DeviceException, RemoteException, Seq
 from monarch.common.reference import Ref
-from monarch.common.shape import NDSlice
 from monarch.common.tensor import Tensor
 
 
monarch/common/device_mesh.py CHANGED
@@ -28,16 +28,16 @@ from typing import (
 
 import monarch.common.messages as messages
 import torch
-from monarch.common.shape import MeshTrait
+from monarch._src.actor.shape import MeshTrait, NDSlice, Shape
 
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
+from torch.utils.weak import weakref
 
 from ._tensor_to_table import tensor_to_table
 from .context_manager import activate_first_context_manager
 from .messages import Dims
 from .reference import Referenceable
-from .shape import NDSlice, Shape
 from .stream import Stream
 from .tensor import MeshSliceTensor, Tensor
 
@@ -171,6 +171,7 @@ class DeviceMesh(Referenceable, MeshTrait):
         self.exit = lambda: None
         self.ref = None
         self._active_mesh_context = None
+        self._subset_of: Optional[weakref.ReferenceType["DeviceMesh"]] = None
 
     def define_remotely(self):
         if self.ref is None:
@@ -228,8 +229,17 @@
     def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
         mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
         mesh.exit = self.exit
+        mesh._subset_of = weakref.ref(self)
         return mesh
 
+    def _is_subset_of(self, other: "DeviceMesh") -> bool:
+        p = self
+        while p is not None:
+            if p is other:
+                return True
+            p = None if p._subset_of is None else p._subset_of()
+        return False
+
     def __call__(self, **kwargs) -> "DeviceMesh":
         """
         device_mesh(batch=3) or device_mesh(batch=slice(3, None))
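The new _subset_of weak reference records which mesh a slice was carved from, and _is_subset_of walks that chain of parents (a mesh counts as a subset of itself). A self-contained toy sketch of the same pattern; the Mesh class here is illustrative only, not monarch's DeviceMesh:

import weakref
from typing import Optional

class Mesh:
    def __init__(self) -> None:
        # Weak link to the mesh this one was sliced from, if any.
        self._subset_of: Optional[weakref.ReferenceType["Mesh"]] = None

    def slice(self) -> "Mesh":
        child = Mesh()
        # Remember the parent without keeping it alive.
        child._subset_of = weakref.ref(self)
        return child

    def _is_subset_of(self, other: "Mesh") -> bool:
        p: Optional["Mesh"] = self
        while p is not None:
            if p is other:
                return True
            p = None if p._subset_of is None else p._subset_of()
        return False

root = Mesh()
sub = root.slice()
assert sub._is_subset_of(root)
assert not root._is_subset_of(sub)

remote.py (below) uses this check to reject calls whose ambient mesh is not a subset of the mesh that the argument tensors live on.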
monarch/common/messages.py CHANGED
@@ -17,18 +17,21 @@ from typing import (
     NamedTuple,
     Optional,
     Protocol,
+    Sequence,
     Tuple,
     TYPE_CHECKING,
 )
 
 from monarch._rust_bindings.monarch_extension import tensor_worker
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
+
+from monarch._src.actor.shape import NDSlice
 from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
 from monarch.common.invocation import DeviceException, RemoteException
 from monarch.common.reference import Referenceable
 from monarch.common.tree import flattener
 from pyre_extensions import none_throws
 
-from .shape import NDSlice
 from .tensor_factory import TensorFactory
 
 if TYPE_CHECKING:
@@ -424,6 +427,23 @@ class SendTensor(NamedTuple):
         )
 
 
+class SendResultOfActorCall(NamedTuple):
+    seq: int
+    broker_id: Tuple[str, int]
+    local_state: Sequence[Tensor | tensor_worker.Ref]
+    mutates: List[tensor_worker.Ref]
+    stream: tensor_worker.StreamRef
+
+
+class CallActorMethod(NamedTuple):
+    seq: int
+    result: object
+    broker_id: Tuple[str, int]
+    local_state: Sequence[Tensor | tensor_worker.Ref]
+    mutates: List[tensor_worker.Ref]
+    stream: tensor_worker.StreamRef
+
+
 class SplitComm(NamedTuple):
     dims: Dims
     device_mesh: DeviceMesh
monarch/common/recording.py CHANGED
@@ -10,9 +10,9 @@ import traceback
 from collections import defaultdict
 from typing import cast, Dict, Generator, List, NamedTuple, Tuple, TYPE_CHECKING, Union
 
-from monarch.common.reference import Ref
+from monarch._src.actor.shape import iter_ranks
 
-from monarch.common.shape import iter_ranks
+from monarch.common.reference import Ref
 
 from monarch.common.tensor import InputChecker
 
@@ -21,8 +21,9 @@ from . import messages
 if TYPE_CHECKING:
     from monarch.common.client import Client
 
+    from monarch._src.actor.shape import NDSlice
+
     from .reference import Referenceable
-    from .shape import NDSlice
     from .tensor import Tensor
 
 logger = logging.getLogger(__name__)
monarch/common/remote.py CHANGED
@@ -8,12 +8,12 @@
 
 import functools
 import logging
-import warnings
 
 from logging import Logger
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Generic,
     Literal,
@@ -28,12 +28,18 @@ from typing import (
 import monarch.common.messages as messages
 
 import torch
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
+from monarch._rust_bindings.monarch_hyperactor.shape import Shape
+from monarch._src.actor.actor_mesh import Port, PortTuple
+from monarch._src.actor.endpoint import Extent, Selection
 
-from monarch.common import _coalescing, device_mesh, messages, stream
+from monarch.common import _coalescing, device_mesh, stream
+from monarch.common.future import Future as OldFuture
 
 if TYPE_CHECKING:
     from monarch.common.client import Client
 
+    from monarch._src.actor.endpoint import Endpoint
     from monarch.common.device_mesh import RemoteProcessGroup
 from monarch.common.fake import fake_call
@@ -49,9 +55,9 @@ from monarch.common.function_caching import (
     TensorGroup,
     TensorPlaceholder,
 )
-from monarch.common.future import Future
 from monarch.common.messages import Dims
-from monarch.common.tensor import dtensor_check, dtensor_dispatch
+
+from monarch.common.tensor import dtensor_check, dtensor_dispatch, InputChecker
 from monarch.common.tree import flatten, tree_map
 from torch import autograd, distributed as dist
 from typing_extensions import ParamSpec
@@ -62,42 +68,96 @@ P = ParamSpec("P")
 R = TypeVar("R")
 T = TypeVar("T")
 
-Propagator = Callable | Literal["mocked", "cached", "inspect"] | None
-
 
-class Remote(Generic[P, R]):
+class Remote(Generic[P, R], Endpoint[P, R]):
     def __init__(self, impl: Any, propagator_arg: Propagator):
+        super().__init__(propagator_arg)
         self._remote_impl = impl
-        self._propagator_arg = propagator_arg
-        self._cache: Optional[dict] = None
+
+    def _call_name(self) -> Any:
+        return self._remote_impl
+
+    def _send(
+        self,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        port: "Optional[Port]" = None,
+        selection: Selection = "all",
+    ) -> Extent:
+        ambient_mesh = device_mesh._active
+        propagator = self._fetch_propagate
+        rfunction = self._maybe_resolvable
+        # a None rfunction is an optimization for the identity function (lambda x: x)
+        if rfunction is None:
+            preprocess_message = None
+            rfunction = ResolvableFunctionFromPath("ident")
+        else:
+            preprocess_message = rfunction
+        _, dtensors, mutates, tensor_mesh = dtensor_check(
+            propagator, rfunction, args, kwargs, ambient_mesh, stream._active
+        )
+
+        if ambient_mesh is None:
+            raise ValueError(
+                "Calling a 'remote' monarch function requires an active proc_mesh (`with proc_mesh.activate():`)"
+            )
+
+        if not ambient_mesh._is_subset_of(tensor_mesh):
+            raise ValueError(
+                f"The current mesh {ambient_mesh} is not a subset of the mesh on which the tensors being used are defined {tensor_mesh}"
+            )
+
+        client: "Client" = ambient_mesh.client
+        if _coalescing.is_active(client):
+            raise NotImplementedError("NYI: fetching results during a coalescing block")
+        stream_ref = stream._active._to_ref(client)
+
+        fut = (port, ambient_mesh._ndslice)
+
+        ident = client.new_node(mutates, dtensors, cast("OldFuture", fut))
+
+        client.send(
+            ambient_mesh._ndslice,
+            messages.SendValue(
                ident,
+                None,
+                mutates,
+                preprocess_message,
+                args,
+                kwargs,
+                stream_ref,
+            ),
+        )
+        # we have to ask for status updates
+        # from workers to be sure they have finished
+        # enough work to count this future as finished,
+        # and all potential errors have been reported
+        client._request_status()
+        return Extent(ambient_mesh._labels, ambient_mesh._ndslice.sizes)
+
+    def _port(self, once: bool = False) -> "PortTuple[R]":
+        ambient_mesh = device_mesh._active
+        if ambient_mesh is None:
+            raise ValueError(
+                "FIXME - cannot create a port without an active proc_mesh, because there is not way to create a port without a mailbox"
+            )
+        mesh_controller = getattr(ambient_mesh.client, "_mesh_controller", None)
+        if mesh_controller is None:
+            raise ValueError(
+                "Cannot create raw port objects with an old-style tensor engine controller."
+            )
+        mailbox: Mailbox = mesh_controller._mailbox
+        return PortTuple.create(mailbox, once)
 
     @property
     def _resolvable(self):
         return resolvable_function(self._remote_impl)
 
-    def _propagate(self, args, kwargs, fake_args, fake_kwargs):
-        if self._propagator_arg is None or self._propagator_arg == "cached":
-            if self._cache is None:
-                self._cache = {}
-            return _cached_propagation(self._cache, self._resolvable, args, kwargs)
-        elif self._propagator_arg == "inspect":
-            return None
-        elif self._propagator_arg == "mocked":
-            raise NotImplementedError("mocked propagation")
-        else:
-            return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)
-
-    def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
-        if self._propagator_arg is None:
-            return  # no propgator provided, so we just assume no mutations
-        return self._propagate(args, kwargs, fake_args, fake_kwargs)
-
-    def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
-        if not callable(self._propagator_arg):
-            raise ValueError("Must specify explicit callable for pipe")
-        return self._propagate(args, kwargs, fake_args, fake_kwargs)
+    @property
+    def _maybe_resolvable(self):
+        return None if self._remote_impl is None else self._resolvable
 
-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+    def _rref(self, args, kwargs):
         return dtensor_dispatch(
             self._resolvable,
             self._propagate,
@@ -107,12 +167,8 @@
             stream._active,
         )
 
-    def call_on_shard_and_fetch(
-        self, *args, shard: Dict[str, int] | None = None, **kwargs
-    ) -> Future[R]:
-        return _call_on_shard_and_fetch(
-            self._resolvable, self._fetch_propagate, *args, shard=shard, **kwargs
-        )
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+        return self.rref(*args, **kwargs)
 
 
 # This can't just be Callable because otherwise we are not
@@ -151,14 +207,43 @@ def remote(function: Any = None, *, propagate: Propagator = None) -> Any:
     return Remote(function, propagate)
 
 
-def _call_on_shard_and_fetch(
-    rfunction: ResolvableFunction | None,
-    propagator: Any,
+remote_identity = Remote(None, lambda x: x)
+
+
+def call_on_shard_and_fetch(
+    remote: Endpoint[P, R], *args, shard: Dict[str, int] | None = None, **kwargs
+) -> OldFuture[R]:
+    # We have to flatten the tensors twice: first to discover
+    # which mesh we are working on to shard it, and then again when doing the
+    # dtensor_check in send. This complexity is a consequence of doing
+    # implicit inference of the mesh from the tensors.
+    dtensors, unflatten = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
+    with InputChecker.from_flat_args(
+        remote._call_name(), dtensors, unflatten
+    ) as checker:
+        checker.check_mesh_stream_local(device_mesh._active, stream._active)
+
+    if not hasattr(checker.mesh.client, "_mesh_controller"):
+        return _old_call_on_shard_and_fetch(
+            cast("Remote[P, R]", remote),
+            *args,
+            shard=shard,
+            **kwargs,
+        )
+
+    selected_slice = checker.mesh._process(shard)
+    shard_mesh = checker.mesh._new_with_shape(Shape(["_"], selected_slice))
+    with shard_mesh.activate():
+        return cast("OldFuture[R]", remote.call_one(*args, **kwargs))
+
+
+def _old_call_on_shard_and_fetch(
+    remote_obj: Remote[P, R],
     /,
     *args: object,
     shard: dict[str, int] | None = None,
     **kwargs: object,
-) -> Future:
+) -> OldFuture[R]:
     """
     Call `function` at the coordinates `shard` of the current device mesh, and retrieve the result as a Future.
     function - the remote function to call
@@ -166,6 +251,9 @@ def _call_on_shard_and_fetch(
     shard - a dictionary from mesh dimension name to coordinate of the shard
         If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)
     """
+
+    rfunction = remote_obj._maybe_resolvable
+    propagator = remote_obj._fetch_propagate
     ambient_mesh = device_mesh._active
 
     if rfunction is None:
@@ -180,15 +268,9 @@
     client: "Client" = mesh.client
     if _coalescing.is_active(client):
         raise NotImplementedError("NYI: fetching results during a coalescing block")
+    stream_ref = stream._active._to_ref(client)
     return client.fetch(
-        mesh,
-        stream._active._to_ref(client),
-        shard,
-        preprocess_message,
-        args,
-        kwargs,
-        mutates,
-        dtensors,
+        mesh, stream_ref, shard, preprocess_message, args, kwargs, mutates, dtensors
     )
 
 
@@ -270,8 +352,9 @@ _miss = 0
 _hit = 0
 
 
-def _cached_propagation(_cache, rfunction, args, kwargs):
+def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):
     tensors, shape_key = hashable_tensor_flatten(args, kwargs)
+    # pyre-ignore
     inputs_group = TensorGroup([t._fake for t in tensors])
     requires_grads = tuple(t.requires_grad for t in tensors)
     key = (shape_key, inputs_group.pattern, requires_grads)
@@ -280,8 +363,8 @@ def _cached_propagation(_cache, rfunction, args, kwargs):
     if key not in _cache:
        _miss += 1
         args_no_pg, kwargs_no_pg = tree_map(_mock_pgs, (args, kwargs))
-        result_with_placeholders, output_pattern = _propagate.call_on_shard_and_fetch(
-            function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
+        result_with_placeholders, output_pattern = call_on_shard_and_fetch(
+            _propagate, function=rfunction, args=args_no_pg, kwargs=kwargs_no_pg
         ).result()
 
         _, unflatten_result = flatten(
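Taken together, the remote.py changes make Remote an implementation of the shared Endpoint interface, and call_on_shard_and_fetch moves from a method on Remote to a module-level function that takes the endpoint as its first argument. A hedged usage sketch based only on the signatures visible in this diff; the function and tensor names are ours:

from monarch.common.remote import call_on_shard_and_fetch, remote

@remote
def total(t):
    return t.sum()

# With an active device mesh, fetch one shard's result as a future:
# fut = call_on_shard_and_fetch(total, some_tensor, shard={"host": 0, "gpu": 0})
# value = fut.result()

Per the diff, on a new-style mesh (one whose client has a _mesh_controller) the function activates the selected single-rank mesh and delegates to the endpoint's call_one; old-style controllers fall back to _old_call_on_shard_and_fetch.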
monarch/common/tensor.py CHANGED
@@ -40,12 +40,13 @@ from .borrows import StorageAliases
 if TYPE_CHECKING:
     from monarch.common.device_mesh import DeviceMesh
 
+    from monarch._src.actor.shape import NDSlice
+
 from .fake import fake_call
 from .function import Propagator, ResolvableFunction
 from .invocation import Invocation
 from .messages import Dims
 from .reference import Referenceable
-from .shape import NDSlice
 from .stream import Stream
 from .tree import flatten
 
monarch/controller/backend.py CHANGED
@@ -13,9 +13,9 @@ import socket
 from abc import ABC, abstractmethod
 from typing import List, NamedTuple, Optional, Sequence, Tuple
 
-from monarch.common import messages
+from monarch._src.actor.shape import iter_ranks, Slices as Ranks
 
-from monarch.common.shape import iter_ranks, Slices as Ranks
+from monarch.common import messages
 from monarch_supervisor import (
     Context,
     FunctionCall,
monarch/controller/controller.py CHANGED
@@ -19,11 +19,12 @@ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarc
     ActorId,
 )
 
+from monarch._src.actor.shape import NDSlice
+
 from monarch.common import messages
 from monarch.common.controller_api import LogMessage, MessageResult
 from monarch.common.invocation import DeviceException, Seq
 from monarch.common.reference import Ref
-from monarch.common.shape import NDSlice
 from monarch.common.tensor import Tensor
 from monarch.controller import debugger
 
monarch/controller/rust_backend/controller.py CHANGED
@@ -29,11 +29,12 @@ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarc
 )
 
 from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+
+from monarch._src.actor.shape import NDSlice
 from monarch.common.controller_api import LogMessage, MessageResult
 from monarch.common.device_mesh import no_mesh
 from monarch.common.invocation import DeviceException, RemoteException
 from monarch.common.messages import SupportsToRustMessage
-from monarch.common.shape import NDSlice
 from monarch.common.tensor import Tensor
 from monarch.controller.debugger import read as debugger_read, write as debugger_write
 from pyre_extensions import none_throws
monarch/fetch.py CHANGED
@@ -9,13 +9,13 @@
 This is a utility file for fetching a shard of a tensor from remote.
 """
 
-from typing import TypeVar
+from typing import cast, TypeVar
 
 from monarch.common.device_mesh import no_mesh
 
 from monarch.common.future import Future
 
-from monarch.common.remote import _call_on_shard_and_fetch
+from monarch.common.remote import call_on_shard_and_fetch, remote_identity
 
 T = TypeVar("T")
 
@@ -37,9 +37,7 @@ def fetch_shard(
     shard = {}
     shard.update(kwargs)
 
-    return _call_on_shard_and_fetch(
-        None, lambda *args, **kwargs: None, obj, shard=shard
-    )
+    return cast("Future[T]", call_on_shard_and_fetch(remote_identity, obj, shard=shard))
 
 
 def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object:
  def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object:
Binary file