torchmonarch-nightly 2025.7.1__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.26__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +878 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +303 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +508 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +59 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +53 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +21 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/gradient/_gradient_generator.so +0 -0
  37. monarch/mesh_controller.py +263 -139
  38. monarch/monarch_controller +0 -0
  39. monarch/opaque_module.py +4 -6
  40. monarch/opaque_object.py +3 -3
  41. monarch/proc_mesh.py +6 -309
  42. monarch/python_local_mesh.py +1 -1
  43. monarch/rust_backend_mesh.py +2 -1
  44. monarch/rust_local_mesh.py +4 -2
  45. monarch/sim_mesh.py +10 -19
  46. monarch/simulator/command_history.py +1 -1
  47. monarch/simulator/interface.py +2 -1
  48. monarch/simulator/mock_controller.py +1 -1
  49. monarch/simulator/simulator.py +1 -1
  50. monarch/tensor_engine/__init__.py +23 -0
  51. monarch/tensor_worker_main.py +3 -1
  52. monarch/tools/cli.py +3 -1
  53. monarch/tools/commands.py +129 -47
  54. monarch/tools/components/hyperactor.py +5 -3
  55. monarch/tools/config/__init__.py +18 -1
  56. monarch/tools/config/defaults.py +2 -2
  57. monarch/tools/mesh_spec.py +59 -1
  58. monarch/tools/utils.py +38 -0
  59. monarch/worker/worker.py +1 -1
  60. monarch/world_mesh.py +2 -1
  61. monarch_supervisor/python_executable.py +6 -3
  62. tests/error_test_binary.py +48 -10
  63. tests/test_actor_error.py +370 -21
  64. tests/test_alloc.py +1 -1
  65. tests/test_allocator.py +369 -17
  66. tests/test_controller.py +2 -0
  67. tests/test_debugger.py +416 -0
  68. tests/test_env_before_cuda.py +161 -0
  69. tests/test_python_actors.py +184 -333
  70. tests/test_rdma.py +198 -0
  71. tests/test_remote_functions.py +40 -12
  72. tests/test_rust_backend.py +7 -5
  73. tests/test_sim_backend.py +1 -4
  74. tests/test_tensor_engine.py +81 -1
  75. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/METADATA +39 -1
  76. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/RECORD +84 -72
  77. torchmonarch_nightly-2025.7.26.dist-info/entry_points.txt +3 -0
  78. monarch/_monarch/hyperactor/__init__.py +0 -58
  79. monarch/_monarch/worker/debugger.py +0 -117
  80. monarch/_monarch/worker/logging.py +0 -107
  81. monarch/debugger.py +0 -379
  82. monarch/future.py +0 -76
  83. monarch/rdma.py +0 -162
  84. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  85. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  86. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  87. /monarch/{common → _src/actor}/shape.py +0 -0
  88. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  89. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/WHEEL +0 -0
  90. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/licenses/LICENSE +0 -0
  91. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/top_level.txt +0 -0
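Most of the churn in the file list above is a repackaging rather than new behavior: implementation modules move under `monarch/_src/actor/`, a public facade package appears at `monarch/actor/__init__.py` (+53), and the old top-level modules (`monarch/actor_mesh.py` +6 -765, `monarch/proc_mesh.py` +6 -309, `monarch/bootstrap_main.py` +8 -47) shrink to thin shims. For downstream code the visible change is an import path; a minimal sketch, assuming (as the shims suggest) that `monarch.actor` re-exports the names the old modules held:

    # Before (2025.7.1): public names lived in top-level modules.
    from monarch.proc_mesh import ProcMesh

    # After (2025.7.26): the same name comes from the new public package.
    # This mirrors the change mesh_controller.py itself makes in its
    # TYPE_CHECKING import in the diff below.
    from monarch.actor import ProcMesh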
monarch/mesh_controller.py CHANGED
@@ -7,40 +7,66 @@
 import atexit
 import logging
 import os
-import time
+
+import pdb  # noqa
 import traceback
 from collections import deque
 from logging import Logger
-from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    cast,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 import torch.utils._python_dispatch
-
-from monarch import NDSlice
-from monarch._rust_bindings.monarch_extension import client, debugger
+from monarch._rust_bindings.monarch_extension import client
 from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
     WorldState,
 )
 from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller
+from monarch._rust_bindings.monarch_extension.tensor_worker import Ref
+from monarch._rust_bindings.monarch_hyperactor.actor import (
+    PythonMessage,
+    PythonMessageKind,
+    UnflattenArg,
+)
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
 from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
     ActorId,
 )
+from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple
+from monarch._src.actor.endpoint import Selection
+from monarch._src.actor.shape import NDSlice
+from monarch.common import device_mesh, messages, stream
+from monarch.common.controller_api import TController
+from monarch.common.function import ResolvableFunction
+from monarch.common.invocation import Seq
+from monarch.common.messages import Referenceable, SendResultOfActorCall
+from monarch.common.stream import StreamRef
+from monarch.common.tensor import dtensor_check, InputChecker, Tensor
+from monarch.common.tree import flatten
+from monarch.tensor_worker_main import _set_trace
 
 if TYPE_CHECKING:
     from monarch._rust_bindings.monarch_hyperactor.proc_mesh import (
         ProcMesh as HyProcMesh,
     )
-    from monarch.proc_mesh import ProcMesh
+    from monarch.actor import ProcMesh
 
 from monarch._rust_bindings.monarch_hyperactor.shape import Point
 
-from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
 from monarch.common.client import Client
 from monarch.common.controller_api import LogMessage, MessageResult
-from monarch.common.device_mesh import DeviceMesh, no_mesh
+from monarch.common.device_mesh import DeviceMesh
+from monarch.common.future import Future as OldFuture
 from monarch.common.invocation import DeviceException, RemoteException
-from monarch.controller.debugger import read as debugger_read, write as debugger_write
 from monarch.rust_local_mesh import _get_worker_exec_info
-from pyre_extensions import none_throws
 
 logger: Logger = logging.getLogger(__name__)
 
@@ -48,6 +74,7 @@ logger: Logger = logging.getLogger(__name__)
 class Controller(_Controller):
     def __init__(self, workers: "HyProcMesh") -> None:
         super().__init__()
+        self._mailbox: Mailbox = workers.client
         # Buffer for messages unrelated to debugging that are received while a
         # debugger session is active.
         self._non_debugger_pending_messages: deque[
@@ -58,19 +85,9 @@ class Controller(_Controller):
     def next_message(
         self, timeout: Optional[float]
     ) -> Optional[LogMessage | MessageResult]:
-        if self._non_debugger_pending_messages:
-            msg = self._non_debugger_pending_messages.popleft()
-        else:
-            msg = self._get_next_message(timeout_msec=int((timeout or 0.0) * 1000.0))
-        if msg is None:
-            return None
-
-        if isinstance(msg, client.WorkerResponse):
-            return _worker_response_to_result(msg)
-        elif isinstance(msg, client.LogMessage):
-            return LogMessage(msg.level, msg.message)
-        elif isinstance(msg, client.DebuggerMessage):
-            self._run_debugger_loop(msg)
+        raise RuntimeError(
+            "internal error: tensor engine does not produce futures that call next_message"
+        )
 
     def send(
         self,
@@ -86,56 +103,6 @@ class Controller(_Controller):
         self._drain_and_stop()
         return []
 
-    def _run_debugger_loop(self, message: client.DebuggerMessage) -> None:
-        if not isinstance(message.action, DebuggerAction.Paused):
-            raise RuntimeError(
-                f"Unexpected debugger message {message} when no debugger session is running"
-            )
-
-        self._pending_debugger_sessions.append(message.debugger_actor_id)
-        while self._pending_debugger_sessions:
-            debugger_actor_id = self._pending_debugger_sessions.popleft()
-            rank = debugger_actor_id.rank
-            proc_id = debugger_actor_id.proc_id
-            debugger_write(
-                f"pdb attached to proc {proc_id} with rank {rank}, debugger actor {debugger_actor_id} \n"
-            )
-
-            self._debugger_attach(debugger_actor_id)
-            while True:
-                # TODO: Add appropriate timeout.
-                msg = self._get_next_message(timeout_msec=None)
-
-                if not isinstance(msg, client.DebuggerMessage):
-                    self._non_debugger_pending_messages.append(msg)
-                    continue
-
-                if msg.debugger_actor_id != debugger_actor_id:
-                    if isinstance(msg.action, DebuggerAction.Paused):
-                        self._pending_debugger_sessions.append(msg.debugger_actor_id)
-                        continue
-                    else:
-                        raise RuntimeError(
-                            f"unexpected debugger message {msg} from rank {msg.debugger_actor_id.rank} "
-                            f"when debugging rank {debugger_actor_id.rank}"
-                        )
-
-                action = msg.action
-                if isinstance(action, DebuggerAction.Detach):
-                    break
-                elif isinstance(action, DebuggerAction.Read):
-                    self._debugger_write(
-                        debugger_actor_id, debugger_read(action.requested_size)
-                    )
-                elif isinstance(action, DebuggerAction.Write):
-                    debugger_write(
-                        debugger.get_bytes_from_write_action(action).decode()
-                    )
-                else:
-                    raise RuntimeError(
-                        f"unexpected debugger message {msg} when debugging rank {debugger_actor_id.rank}"
-                    )
-
     def worker_world_state(self) -> WorldState:
         raise NotImplementedError("worker world state")
 
@@ -145,54 +112,6 @@ class Controller(_Controller):
         pass
 
 
-# TODO: Handling conversion of the response can move to a separate module over time
-# especially as we have structured error messages.
-def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
-    if not result.is_exception():
-        # The result of the message needs to be unwrapped on a real device.
-        # Staying as a fake tensor will fail the tensor deserialization.
-        with no_mesh.activate():
-            return MessageResult(result.seq, result.result(), None)
-    exc = none_throws(result.exception())
-    if isinstance(exc, client.Error):
-        worker_frames = [
-            traceback.FrameSummary("<unknown>", None, frame)
-            for frame in exc.backtrace.split("\\n")
-        ]
-        return MessageResult(
-            seq=result.seq,
-            result=None,
-            error=RemoteException(
-                seq=exc.caused_by_seq,
-                exception=RuntimeError(exc.backtrace),
-                controller_frame_index=0,  # TODO: T225205291 fix this once we have recording support in rust
-                controller_frames=None,
-                worker_frames=worker_frames,
-                source_actor_id=exc.actor_id,
-                message=f"Remote function in {exc.actor_id} errored.",
-            ),
-        )
-    elif isinstance(exc, client.Failure):
-        frames = [
-            traceback.FrameSummary("<unknown>", None, frame)
-            for frame in exc.backtrace.split("\n")
-        ]
-        reason = f"Actor {exc.actor_id} crashed on {exc.address}, check the host log for details"
-        logger.error(reason)
-        return MessageResult(
-            seq=0,  # seq is not consumed for DeviceException; it will be directly thrown by the client
-            result=None,
-            error=DeviceException(
-                exception=RuntimeError(reason),
-                frames=frames,
-                source_actor_id=exc.actor_id,
-                message=reason,
-            ),
-        )
-    else:
-        raise RuntimeError(f"Unknown exception type: {type(exc)}")
-
-
 def _initialize_env(worker_point: Point, proc_id: str) -> None:
     worker_rank = worker_point.rank
     try:
@@ -213,12 +132,50 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None:
             "LOCAL_WORLD_SIZE": str(gpus_per_host),
         }
         os.environ.update(process_env)
+        pdb.set_trace = _set_trace
+        # workaround for set_manual_seed somehow not working if cuda is not initialized
+        if torch.cuda.is_available():
+            torch.cuda.init()
     except Exception:
         traceback.print_exc()
         raise
 
 
 class MeshClient(Client):
+    def fetch(
+        self,
+        mesh: "DeviceMesh",
+        stream: "StreamRef",
+        shard,
+        preprocess_message,
+        args,
+        kwargs,
+        defs: Tuple["Tensor", ...],
+        uses: Tuple["Tensor", ...],
+    ) -> "OldFuture":  # the OldFuture is a lie
+        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+
+        ident = self.new_node(defs, uses, cast("OldFuture", sender))
+        process = mesh._process(shard)
+        self.send(
+            process,
+            messages.SendValue(
+                ident,
+                None,
+                defs,
+                preprocess_message,
+                args,
+                kwargs,
+                stream,
+            ),
+        )
+        # we have to ask for status updates
+        # from workers to be sure they have finished
+        # enough work to count this future as finished,
+        # and all potential errors have been reported
+        self._request_status()
+        return cast("OldFuture", receiver.recv())
+
     def shutdown(
         self,
         destroy_pg: bool = True,
@@ -232,27 +189,43 @@ class MeshClient(Client):
         atexit.unregister(self._atexit)
         self._shutdown = True
 
-        # ensure all pending work is finished.
-        # all errors must be messaged back at this point
-        self.new_node_nocoalesce([], [], None, [])
-        self._request_status()
-
-        ttl = 60
-        start_time = time.time()
-        end_time = start_time + ttl
-        while ttl > 0 and self.last_assigned_seq > self.last_processed_seq:
-            ttl = end_time - time.time()
-            self.handle_next_message(ttl)
-            if self._pending_shutdown_error:
-                raise self._pending_shutdown_error
-
-        if ttl <= 0:
-            raise RuntimeError("shutdown timed out")
-
+        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+        assert sender._port_ref is not None
+        self._mesh_controller.sync_at_exit(sender._port_ref.port_id)
+        receiver.recv().get(timeout=60)
         # we are not expecting anything more now, because we already
         # waited for the responses
         self.inner.drain_and_stop()
 
+    @property
+    def _mesh_controller(self) -> Controller:
+        return cast(Controller, self.inner)
+
+    def new_node_nocoalesce(
+        self,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+        future: Optional["OldFuture"],
+        tracebacks: List[List[traceback.FrameSummary]],
+    ) -> Seq:
+        seq = self._next_seq()
+        for d in defs:
+            d._seq = seq
+        response_port = None
+        if future is not None:
+            # method annotation is a lie to make Client happy
+            port, slice = cast("Tuple[Port[Any], NDSlice]", future)
+            assert port._port_ref is not None
+            response_port = (port._port_ref.port_id, slice)
+        self._mesh_controller.node(seq, defs, uses, response_port, tracebacks)
+        return seq
+
+    def handle_next_message(self, timeout: Optional[float]) -> bool:
+        """
+        Mesh controller message loop is handled by the tokio event loop.
+        """
+        return False
+
 
 def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # This argument to Controller
@@ -260,7 +233,7 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # report the proc ID instead of the rank it currently does.
     gpus = proc_mesh.sizes.get("gpus", 1)
     backend_ctrl = Controller(proc_mesh._proc_mesh)
-    client = MeshClient(backend_ctrl, proc_mesh.size(), gpus)
+    client = MeshClient(cast("TController", backend_ctrl), proc_mesh.size(), gpus)
     dm = DeviceMesh(
         client,
         NDSlice.new_row_major(list(proc_mesh.sizes.values())),
@@ -268,3 +241,154 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     )
     dm.exit = lambda: client.shutdown()
     return dm
+
+
+class RemoteException(Exception):
+    def __init__(
+        self,
+        worker_error_string: str,  # this should really be an exception + stacktrace but
+        # worker code needs major refactor to make this possible
+        controller_frames: List[traceback.FrameSummary],
+        rank: int,
+    ):
+        self.worker_error_string = worker_error_string
+        self.controller_frames = controller_frames
+        self.rank = rank
+
+    def __str__(self):
+        try:
+            controller_tb = "".join(traceback.format_list(self.controller_frames))
+            return (
+                f"A remote function has failed asynchronously on rank {self.rank}.\n"
+                f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}"
+                f"Error as reported from worker:\n{self.worker_error_string}"
+            )
+        except Exception:
+            traceback.print_exc()
+            return "<exception formatting RemoteException>"
+
+
+def _cast_call_method_indirect(
+    endpoint: ActorEndpoint,
+    selection: Selection,
+    client: MeshClient,
+    seq: Seq,
+    args_kwargs_tuple: bytes,
+    refs: Sequence[Any],
+) -> Tuple[str, int]:
+    unflatten_args = [
+        UnflattenArg.PyObject if isinstance(ref, Tensor) else UnflattenArg.Mailbox
+        for ref in refs
+    ]
+    broker_id: Tuple[str, int] = client._mesh_controller.broker_id
+    actor_msg = PythonMessage(
+        PythonMessageKind.CallMethodIndirect(
+            endpoint._name, broker_id, seq, unflatten_args
+        ),
+        args_kwargs_tuple,
+    )
+    endpoint._actor_mesh.cast(actor_msg, selection)
+    return broker_id
+
+
+def actor_send(
+    endpoint: ActorEndpoint,
+    args_kwargs_tuple: bytes,
+    refs: Sequence[Any],
+    port: Optional[Port[Any]],
+    selection: Selection,
+):
+    tensors = [ref for ref in refs if isinstance(ref, Tensor)]
+    # we have some monarch references, we need to ensure their
+    # proc_mesh matches that of the tensors we sent to it
+    chosen_stream = stream._active
+    for t in tensors:
+        if hasattr(t, "stream"):
+            chosen_stream = t.stream
+            break
+    with InputChecker(tensors, lambda x: f"actor_call({x})") as checker:
+        checker.check_mesh_stream_local(device_mesh._active, chosen_stream)
+        # TODO: move propagators into Endpoint abstraction and run the propagator to get the
+        # mutates
+        checker.check_permission(())
+    selected_device_mesh = (
+        endpoint._actor_mesh._proc_mesh and endpoint._actor_mesh._proc_mesh._device_mesh
+    )
+    if selected_device_mesh is not checker.mesh:
+        raise ValueError(
+            f"monarch Tensors sent to an actor must be located on the same process as the actor. However {checker.mesh} is not {selected_device_mesh}."
+            "NYI: better serialization of mesh names to make the mismatch more clear."
+        )
+
+    client = cast(MeshClient, checker.mesh.client)
+
+    stream_ref = chosen_stream._to_ref(client)
+
+    fut = (port, checker.mesh._ndslice) if port is not None else None
+
+    ident = client.new_node([], tensors, cast("OldFuture", fut))
+
+    # To ensure that both the actor and the stream execute in order, we send a message
+    # to each at this point. The message to the worker will be handled on the stream actor where
+    # it will send the 'tensor's to the broker actor locally, along with a response port with the
+    # computed value.
+
+    # The message to the generic actor tells it to first wait on the broker to get the local arguments
+    # from the stream, then it will run the actor method, and send the result to response port.
+
+    broker_id = _cast_call_method_indirect(
+        endpoint, selection, client, ident, args_kwargs_tuple, refs
+    )
+    worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref)
+    client.send(checker.mesh._ndslice, worker_msg)
+    # we have to ask for status updates
+    # from workers to be sure they have finished
+    # enough work to count this future as finished,
+    # and all potential errors have been reported
+    client._request_status()
+
+
+def actor_rref(endpoint, args_kwargs_tuple: bytes, refs: Sequence[Any]):
+    chosen_stream = stream._active
+    fake_result, dtensors, mutates, mesh = dtensor_check(
+        endpoint._propagate,
+        cast(ResolvableFunction, endpoint._name),
+        refs,
+        {},
+        device_mesh._active,
+        chosen_stream,
+    )
+    assert mesh is not None
+
+    fake_result_dtensors, unflatten_result = flatten(
+        fake_result, lambda x: isinstance(x, torch.Tensor)
+    )
+    result_dtensors = tuple(
+        Tensor(fake, mesh, chosen_stream) for fake in fake_result_dtensors
+    )
+    seq = mesh.client.new_node(result_dtensors + mutates, dtensors)
+    assert all(t.ref is not None for t in result_dtensors)
+    assert all(t.ref is not None for t in mutates)
+    result = result_msg = unflatten_result(result_dtensors)
+    if len(result_dtensors) == 0:
+        result_msg = None
+
+    broker_id = _cast_call_method_indirect(
+        endpoint, "all", mesh.client, seq, args_kwargs_tuple, refs
+    )
+    # note the device mesh has to be defined regardless so the remote functions
+    # can invoke mesh.rank("...")
+
+    mesh.define_remotely()
+
+    mesh._send(
+        messages.CallActorMethod(
+            seq,
+            result_msg,
+            broker_id,
+            refs,
+            cast("List[Ref]", mutates),
+            stream._active._to_ref(mesh.client),
+        )
+    )
+    return result
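The net effect of the mesh_controller.py changes: the controller no longer runs its own message and debugger loop (`next_message` now raises, `handle_next_message` returns False), and instead completes futures through hyperactor ports on the tokio event loop, while the new `actor_send`/`actor_rref` paths let actor endpoints exchange tensors with the tensor engine via a broker. The public entry point is unchanged; a minimal usage sketch, assuming an already-allocated `proc_mesh` and that `DeviceMesh.activate()` behaves as in prior releases:

    from monarch.mesh_controller import spawn_tensor_engine

    dm = spawn_tensor_engine(proc_mesh)  # wraps the ProcMesh in a MeshClient-backed DeviceMesh
    try:
        with dm.activate():
            ...  # issue remote tensor work against the mesh
    finally:
        dm.exit()  # MeshClient.shutdown(): sync_at_exit over a port, then drain_and_stop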
monarch/monarch_controller CHANGED
Binary file
monarch/opaque_module.py CHANGED
@@ -9,7 +9,7 @@ from typing import List
 import torch
 from monarch.common.function_caching import TensorGroup, TensorGroupPattern
 from monarch.common.opaque_ref import OpaqueRef
-from monarch.common.remote import remote
+from monarch.common.remote import call_on_shard_and_fetch, remote
 from monarch.common.tensor_factory import TensorFactory
 from monarch.common.tree import flatten
 from monarch.opaque_object import _fresh_opaque_ref, OpaqueObject
@@ -144,11 +144,9 @@ class OpaqueModule:
 
     def parameters(self):
         if self._parameters is None:
-            tensor_group_pattern = (
-                remote(_get_parameters_shape)
-                .call_on_shard_and_fetch(self._object)
-                .result()
-            )
+            tensor_group_pattern = call_on_shard_and_fetch(
+                remote(_get_parameters_shape), self._object
+            ).result()
             self._parameters = [
                 p.requires_grad_(True)
                 for p in remote(
monarch/opaque_object.py CHANGED
@@ -14,7 +14,7 @@ from monarch.common.function import (
 )
 
 from monarch.common.opaque_ref import OpaqueRef
-from monarch.common.remote import remote
+from monarch.common.remote import call_on_shard_and_fetch, remote
 
 
 def _invoke_method(obj: OpaqueRef, method_name: str, *args, **kwargs):
@@ -83,6 +83,6 @@ class OpaqueObject(OpaqueRef):
         return endpoint(self, method_name, *args, **kwargs)
 
     def call_method_on_shard_and_fetch(self, method_name, *args, **kwargs):
-        return remote(_invoke_method).call_on_shard_and_fetch(
-            self, method_name, *args, **kwargs
+        return call_on_shard_and_fetch(
+            remote(_invoke_method), self, method_name, *args, **kwargs
         )
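Both opaque-ref call sites above migrate the same way: `call_on_shard_and_fetch` moves from a method on the handle returned by `remote(...)` to a module-level function in `monarch.common.remote` (see the `monarch/common/remote.py +135 -52` entry in the file list). A before/after sketch for any downstream caller, with `fn` standing in for a user-defined remote function:

    from monarch.common.remote import call_on_shard_and_fetch, remote

    # 2025.7.1: fetch via a method on the remote handle.
    result = remote(fn).call_on_shard_and_fetch(arg).result()

    # 2025.7.26: fetch via the module-level function, handle passed first.
    result = call_on_shard_and_fetch(remote(fn), arg).result()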