torchmonarch-nightly 2025.7.1__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.25__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +874 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +270 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +500 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +56 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +51 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +12 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/mesh_controller.py +201 -139
  37. monarch/monarch_controller +0 -0
  38. monarch/opaque_module.py +4 -6
  39. monarch/opaque_object.py +3 -3
  40. monarch/proc_mesh.py +6 -309
  41. monarch/python_local_mesh.py +1 -1
  42. monarch/rust_backend_mesh.py +2 -1
  43. monarch/rust_local_mesh.py +4 -2
  44. monarch/sim_mesh.py +10 -19
  45. monarch/simulator/command_history.py +1 -1
  46. monarch/simulator/interface.py +2 -1
  47. monarch/simulator/mock_controller.py +1 -1
  48. monarch/simulator/simulator.py +1 -1
  49. monarch/tensor_engine/__init__.py +23 -0
  50. monarch/tensor_worker_main.py +3 -1
  51. monarch/tools/cli.py +3 -1
  52. monarch/tools/commands.py +95 -35
  53. monarch/tools/mesh_spec.py +55 -0
  54. monarch/tools/utils.py +38 -0
  55. monarch/worker/worker.py +1 -1
  56. monarch/world_mesh.py +2 -1
  57. monarch_supervisor/python_executable.py +6 -3
  58. tests/error_test_binary.py +48 -10
  59. tests/test_actor_error.py +370 -21
  60. tests/test_alloc.py +1 -1
  61. tests/test_allocator.py +373 -17
  62. tests/test_controller.py +2 -0
  63. tests/test_debugger.py +416 -0
  64. tests/test_env_before_cuda.py +162 -0
  65. tests/test_python_actors.py +184 -333
  66. tests/test_rdma.py +198 -0
  67. tests/test_remote_functions.py +40 -12
  68. tests/test_rust_backend.py +7 -5
  69. tests/test_sim_backend.py +1 -4
  70. tests/test_tensor_engine.py +55 -1
  71. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
  72. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
  73. torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
  74. monarch/_monarch/hyperactor/__init__.py +0 -58
  75. monarch/_monarch/worker/debugger.py +0 -117
  76. monarch/_monarch/worker/logging.py +0 -107
  77. monarch/debugger.py +0 -379
  78. monarch/future.py +0 -76
  79. monarch/rdma.py +0 -162
  80. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  81. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  82. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  83. /monarch/{common → _src/actor}/shape.py +0 -0
  84. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  85. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
  86. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
  87. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
monarch/mesh_controller.py CHANGED
@@ -7,40 +7,63 @@
 import atexit
 import logging
 import os
-import time
+
+import pdb  # noqa
 import traceback
 from collections import deque
 from logging import Logger
-from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union
+from typing import (
+    Any,
+    cast,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    TYPE_CHECKING,
+    Union,
+)
 
 import torch.utils._python_dispatch
-
-from monarch import NDSlice
-from monarch._rust_bindings.monarch_extension import client, debugger
+from monarch._rust_bindings.monarch_extension import client
 from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
     WorldState,
 )
 from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller
+from monarch._rust_bindings.monarch_hyperactor.actor import (
+    PythonMessage,
+    PythonMessageKind,
+    UnflattenArg,
+)
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
 from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
     ActorId,
 )
+from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple
+from monarch._src.actor.endpoint import Selection
+from monarch._src.actor.shape import NDSlice
+from monarch.common import device_mesh, messages, stream
+from monarch.common.controller_api import TController
+from monarch.common.invocation import Seq
+from monarch.common.messages import Referenceable, SendResultOfActorCall
+from monarch.common.stream import StreamRef
+from monarch.common.tensor import InputChecker, Tensor
+from monarch.tensor_worker_main import _set_trace
 
 if TYPE_CHECKING:
     from monarch._rust_bindings.monarch_hyperactor.proc_mesh import (
         ProcMesh as HyProcMesh,
     )
-    from monarch.proc_mesh import ProcMesh
+    from monarch.actor import ProcMesh
 
 from monarch._rust_bindings.monarch_hyperactor.shape import Point
 
-from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
 from monarch.common.client import Client
 from monarch.common.controller_api import LogMessage, MessageResult
-from monarch.common.device_mesh import DeviceMesh, no_mesh
+from monarch.common.device_mesh import DeviceMesh
+from monarch.common.future import Future as OldFuture
 from monarch.common.invocation import DeviceException, RemoteException
-from monarch.controller.debugger import read as debugger_read, write as debugger_write
 from monarch.rust_local_mesh import _get_worker_exec_info
-from pyre_extensions import none_throws
 
 logger: Logger = logging.getLogger(__name__)
 
@@ -48,6 +71,7 @@ logger: Logger = logging.getLogger(__name__)
 class Controller(_Controller):
     def __init__(self, workers: "HyProcMesh") -> None:
         super().__init__()
+        self._mailbox: Mailbox = workers.client
         # Buffer for messages unrelated to debugging that are received while a
         # debugger session is active.
         self._non_debugger_pending_messages: deque[
@@ -58,19 +82,9 @@ class Controller(_Controller):
     def next_message(
         self, timeout: Optional[float]
     ) -> Optional[LogMessage | MessageResult]:
-        if self._non_debugger_pending_messages:
-            msg = self._non_debugger_pending_messages.popleft()
-        else:
-            msg = self._get_next_message(timeout_msec=int((timeout or 0.0) * 1000.0))
-        if msg is None:
-            return None
-
-        if isinstance(msg, client.WorkerResponse):
-            return _worker_response_to_result(msg)
-        elif isinstance(msg, client.LogMessage):
-            return LogMessage(msg.level, msg.message)
-        elif isinstance(msg, client.DebuggerMessage):
-            self._run_debugger_loop(msg)
+        raise RuntimeError(
+            "internal error: tensor engine does not produce futures that call next_message"
+        )
 
     def send(
         self,
@@ -86,56 +100,6 @@
             self._drain_and_stop()
         return []
 
-    def _run_debugger_loop(self, message: client.DebuggerMessage) -> None:
-        if not isinstance(message.action, DebuggerAction.Paused):
-            raise RuntimeError(
-                f"Unexpected debugger message {message} when no debugger session is running"
-            )
-
-        self._pending_debugger_sessions.append(message.debugger_actor_id)
-        while self._pending_debugger_sessions:
-            debugger_actor_id = self._pending_debugger_sessions.popleft()
-            rank = debugger_actor_id.rank
-            proc_id = debugger_actor_id.proc_id
-            debugger_write(
-                f"pdb attached to proc {proc_id} with rank {rank}, debugger actor {debugger_actor_id} \n"
-            )
-
-            self._debugger_attach(debugger_actor_id)
-            while True:
-                # TODO: Add appropriate timeout.
-                msg = self._get_next_message(timeout_msec=None)
-
-                if not isinstance(msg, client.DebuggerMessage):
-                    self._non_debugger_pending_messages.append(msg)
-                    continue
-
-                if msg.debugger_actor_id != debugger_actor_id:
-                    if isinstance(msg.action, DebuggerAction.Paused):
-                        self._pending_debugger_sessions.append(msg.debugger_actor_id)
-                        continue
-                    else:
-                        raise RuntimeError(
-                            f"unexpected debugger message {msg} from rank {msg.debugger_actor_id.rank} "
-                            f"when debugging rank {debugger_actor_id.rank}"
-                        )
-
-                action = msg.action
-                if isinstance(action, DebuggerAction.Detach):
-                    break
-                elif isinstance(action, DebuggerAction.Read):
-                    self._debugger_write(
-                        debugger_actor_id, debugger_read(action.requested_size)
-                    )
-                elif isinstance(action, DebuggerAction.Write):
-                    debugger_write(
-                        debugger.get_bytes_from_write_action(action).decode()
-                    )
-                else:
-                    raise RuntimeError(
-                        f"unexpected debugger message {msg} when debugging rank {debugger_actor_id.rank}"
-                    )
-
     def worker_world_state(self) -> WorldState:
         raise NotImplementedError("worker world state")
 
@@ -145,54 +109,6 @@
         pass
 
 
-# TODO: Handling conversion of the response can move to a separate module over time
-# especially as we have structured error messages.
-def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
-    if not result.is_exception():
-        # The result of the message needs to be unwrapped on a real device.
-        # Staying as a fake tensor will fail the tensor deserialization.
-        with no_mesh.activate():
-            return MessageResult(result.seq, result.result(), None)
-    exc = none_throws(result.exception())
-    if isinstance(exc, client.Error):
-        worker_frames = [
-            traceback.FrameSummary("<unknown>", None, frame)
-            for frame in exc.backtrace.split("\\n")
-        ]
-        return MessageResult(
-            seq=result.seq,
-            result=None,
-            error=RemoteException(
-                seq=exc.caused_by_seq,
-                exception=RuntimeError(exc.backtrace),
-                controller_frame_index=0,  # TODO: T225205291 fix this once we have recording support in rust
-                controller_frames=None,
-                worker_frames=worker_frames,
-                source_actor_id=exc.actor_id,
-                message=f"Remote function in {exc.actor_id} errored.",
-            ),
-        )
-    elif isinstance(exc, client.Failure):
-        frames = [
-            traceback.FrameSummary("<unknown>", None, frame)
-            for frame in exc.backtrace.split("\n")
-        ]
-        reason = f"Actor {exc.actor_id} crashed on {exc.address}, check the host log for details"
-        logger.error(reason)
-        return MessageResult(
-            seq=0,  # seq is not consumed for DeviceException; it will be directly thrown by the client
-            result=None,
-            error=DeviceException(
-                exception=RuntimeError(reason),
-                frames=frames,
-                source_actor_id=exc.actor_id,
-                message=reason,
-            ),
-        )
-    else:
-        raise RuntimeError(f"Unknown exception type: {type(exc)}")
-
-
 def _initialize_env(worker_point: Point, proc_id: str) -> None:
     worker_rank = worker_point.rank
     try:
@@ -213,12 +129,50 @@ def _initialize_env(worker_point: Point, proc_id: str) -> None:
             "LOCAL_WORLD_SIZE": str(gpus_per_host),
         }
         os.environ.update(process_env)
+        pdb.set_trace = _set_trace
+        # workaround for set_manual_seed somehow not working if cuda is not initialized\
+        if torch.cuda.is_available():
+            torch.cuda.init()
    except Exception:
        traceback.print_exc()
        raise


class MeshClient(Client):
+    def fetch(
+        self,
+        mesh: "DeviceMesh",
+        stream: "StreamRef",
+        shard,
+        preprocess_message,
+        args,
+        kwargs,
+        defs: Tuple["Tensor", ...],
+        uses: Tuple["Tensor", ...],
+    ) -> "OldFuture":  # the OldFuture is a lie
+        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+
+        ident = self.new_node(defs, uses, cast("OldFuture", sender))
+        process = mesh._process(shard)
+        self.send(
+            process,
+            messages.SendValue(
+                ident,
+                None,
+                defs,
+                preprocess_message,
+                args,
+                kwargs,
+                stream,
+            ),
+        )
+        # we have to ask for status updates
+        # from workers to be sure they have finished
+        # enough work to count this future as finished,
+        # and all potential errors have been reported
+        self._request_status()
+        return cast("OldFuture", receiver.recv())
+
    def shutdown(
        self,
        destroy_pg: bool = True,
@@ -232,27 +186,43 @@ class MeshClient(Client):
         atexit.unregister(self._atexit)
         self._shutdown = True
 
-        # ensure all pending work is finished.
-        # all errors must be messaged back at this point
-        self.new_node_nocoalesce([], [], None, [])
-        self._request_status()
-
-        ttl = 60
-        start_time = time.time()
-        end_time = start_time + ttl
-        while ttl > 0 and self.last_assigned_seq > self.last_processed_seq:
-            ttl = end_time - time.time()
-            self.handle_next_message(ttl)
-            if self._pending_shutdown_error:
-                raise self._pending_shutdown_error
-
-        if ttl <= 0:
-            raise RuntimeError("shutdown timed out")
-
+        sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True)
+        assert sender._port_ref is not None
+        self._mesh_controller.sync_at_exit(sender._port_ref.port_id)
+        receiver.recv().get(timeout=60)
         # we are not expecting anything more now, because we already
         # waited for the responses
         self.inner.drain_and_stop()
 
+    @property
+    def _mesh_controller(self) -> Controller:
+        return cast(Controller, self.inner)
+
+    def new_node_nocoalesce(
+        self,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+        future: Optional["OldFuture"],
+        tracebacks: List[List[traceback.FrameSummary]],
+    ) -> Seq:
+        seq = self._next_seq()
+        for d in defs:
+            d._seq = seq
+        response_port = None
+        if future is not None:
+            # method annotation is a lie to make Client happy
+            port, slice = cast("Tuple[Port[Any], NDSlice]", future)
+            assert port._port_ref is not None
+            response_port = (port._port_ref.port_id, slice)
+        self._mesh_controller.node(seq, defs, uses, response_port, tracebacks)
+        return seq
+
+    def handle_next_message(self, timeout: Optional[float]) -> bool:
+        """
+        Mesh controller message loop is handled by the tokio event loop.
+        """
+        return False
+
 
 def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # This argument to Controller
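Note: fetch and shutdown above share one pattern, a once-port handshake that replaces the old TTL polling loop: allocate a one-shot reply port on the controller's mailbox, hand the raw port id to the Rust-side controller, and block on the single reply. A minimal sketch of that pattern, with PortTuple semantics assumed from monarch._src.actor.actor_mesh (here `mailbox` stands for the client's Mailbox and `rust_controller` for the _Controller; the call that eventually replies is whichever Rust method was given the port id):

    from monarch._src.actor.actor_mesh import PortTuple

    sender, receiver = PortTuple.create(mailbox, once=True)  # one-shot port pair
    assert sender._port_ref is not None
    rust_controller.sync_at_exit(sender._port_ref.port_id)  # pass the raw port id down
    receiver.recv().get(timeout=60)  # recv() yields a future; get() blocks for the reply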
@@ -260,7 +230,7 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     # report the proc ID instead of the rank it currently does.
     gpus = proc_mesh.sizes.get("gpus", 1)
     backend_ctrl = Controller(proc_mesh._proc_mesh)
-    client = MeshClient(backend_ctrl, proc_mesh.size(), gpus)
+    client = MeshClient(cast("TController", backend_ctrl), proc_mesh.size(), gpus)
     dm = DeviceMesh(
         client,
         NDSlice.new_row_major(list(proc_mesh.sizes.values())),
@@ -268,3 +238,95 @@ def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
     )
     dm.exit = lambda: client.shutdown()
     return dm
+
+
+class RemoteException(Exception):
+    def __init__(
+        self,
+        worker_error_string: str,  # this should really be an exception + stacktrace but
+        # worker code needs major refactor to make this possible
+        controller_frames: List[traceback.FrameSummary],
+        rank: int,
+    ):
+        self.worker_error_string = worker_error_string
+        self.controller_frames = controller_frames
+        self.rank = rank
+
+    def __str__(self):
+        try:
+            controller_tb = "".join(traceback.format_list(self.controller_frames))
+            return (
+                f"A remote function has failed asynchronously on rank {self.rank}.\n"
+                f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}"
+                f"Error as reported from worker:\n{self.worker_error_string}"
+            )
+        except Exception:
+            traceback.print_exc()
+            return "<exception formatting RemoteException>"
+
+
+def actor_send(
+    endpoint: ActorEndpoint,
+    args_kwargs_tuple: bytes,
+    refs: Sequence[Any],
+    port: Optional[Port[Any]],
+    selection: Selection,
+):
+    unflatten_args = [
+        UnflattenArg.PyObject if isinstance(ref, Tensor) else UnflattenArg.Mailbox
+        for ref in refs
+    ]
+    tensors = [ref for ref in refs if isinstance(ref, Tensor)]
+    # we have some monarch references, we need to ensure their
+    # proc_mesh matches that of the tensors we sent to it
+    chosen_stream = stream._active
+    for t in tensors:
+        if hasattr(t, "stream"):
+            chosen_stream = t.stream
+            break
+    with InputChecker(refs, lambda x: f"actor_call({x})") as checker:
+        checker.check_mesh_stream_local(device_mesh._active, chosen_stream)
+        # TODO: move propagators into Endpoint abstraction and run the propagator to get the
+        # mutates
+        checker.check_permission(())
+    selected_device_mesh = (
+        endpoint._actor_mesh._proc_mesh and endpoint._actor_mesh._proc_mesh._device_mesh
+    )
+    if selected_device_mesh is not checker.mesh:
+        raise ValueError(
+            f"monarch Tensors sent to an actor must be located on the same process as the actor. However {checker.mesh} is not {selected_device_mesh}."
+            "NYI: better serialization of mesh names to make the mismatch more clear."
+        )
+
+    client = cast(MeshClient, checker.mesh.client)
+
+    broker_id: Tuple[str, int] = client._mesh_controller.broker_id
+
+    stream_ref = chosen_stream._to_ref(client)
+
+    fut = (port, checker.mesh._ndslice) if port is not None else None
+
+    ident = client.new_node([], tensors, cast("OldFuture", fut))
+
+    # To ensure that both the actor and the stream execute in order, we send a message
+    # to each at this point. The message to the worker will be handled on the stream actor where
+    # it will send the 'tensor's to the broker actor locally, along with a response port with the
+    # computed value.
+
+    # The message to the generic actor tells it to first wait on the broker to get the local arguments
+    # from the stream, then it will run the actor method, and send the result to response port.
+
+    actor_msg = PythonMessage(
+        PythonMessageKind.CallMethodIndirect(
+            endpoint._name, broker_id, ident, unflatten_args
+        ),
+        args_kwargs_tuple,
+    )
+    endpoint._actor_mesh.cast(actor_msg, selection)
+    worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref)
+    client.send(checker.mesh._ndslice, worker_msg)
+    # we have to ask for status updates
+    # from workers to be sure they have finished
+    # enough work to count this future as finished,
+    # and all potential errors have been reported
+    client._request_status()
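For orientation, a hedged usage sketch of the entry point this file defines; `proc_mesh` is assumed to be exported from monarch.actor in this release, and the gpu count is illustrative:

    from monarch.actor import proc_mesh  # assumed export in 2025.7.25
    from monarch.mesh_controller import spawn_tensor_engine

    pm = proc_mesh(gpus=2)  # allocate a proc mesh (size is illustrative)
    dm = spawn_tensor_engine(pm)  # DeviceMesh whose client is MeshClient over the actor mailbox
    with dm.activate():
        ...  # tensor-engine ops issued here; MeshClient.fetch round-trips results via a once-port
    dm.exit()  # wired above to MeshClient.shutdown(), which syncs via sync_at_exit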
monarch/monarch_controller CHANGED
Binary file
monarch/opaque_module.py CHANGED
@@ -9,7 +9,7 @@ from typing import List
 import torch
 from monarch.common.function_caching import TensorGroup, TensorGroupPattern
 from monarch.common.opaque_ref import OpaqueRef
-from monarch.common.remote import remote
+from monarch.common.remote import call_on_shard_and_fetch, remote
 from monarch.common.tensor_factory import TensorFactory
 from monarch.common.tree import flatten
 from monarch.opaque_object import _fresh_opaque_ref, OpaqueObject
@@ -144,11 +144,9 @@ class OpaqueModule:
 
     def parameters(self):
         if self._parameters is None:
-            tensor_group_pattern = (
-                remote(_get_parameters_shape)
-                .call_on_shard_and_fetch(self._object)
-                .result()
-            )
+            tensor_group_pattern = call_on_shard_and_fetch(
+                remote(_get_parameters_shape), self._object
+            ).result()
             self._parameters = [
                 p.requires_grad_(True)
                 for p in remote(
monarch/opaque_object.py CHANGED
@@ -14,7 +14,7 @@ from monarch.common.function import (
 )
 
 from monarch.common.opaque_ref import OpaqueRef
-from monarch.common.remote import remote
+from monarch.common.remote import call_on_shard_and_fetch, remote
 
 
 def _invoke_method(obj: OpaqueRef, method_name: str, *args, **kwargs):
@@ -83,6 +83,6 @@ class OpaqueObject(OpaqueRef):
         return endpoint(self, method_name, *args, **kwargs)
 
     def call_method_on_shard_and_fetch(self, method_name, *args, **kwargs):
-        return remote(_invoke_method).call_on_shard_and_fetch(
-            self, method_name, *args, **kwargs
+        return call_on_shard_and_fetch(
+            remote(_invoke_method), self, method_name, *args, **kwargs
         )
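Both opaque_module.py and opaque_object.py track the same API move: call_on_shard_and_fetch is now a free function in monarch.common.remote that takes the remote as its first argument, rather than a method on the remote handle. A hedged before/after sketch (`my_fn` is a hypothetical function already wrapped as a monarch remote):

    from monarch.common.remote import call_on_shard_and_fetch, remote

    # 2025.7.1:  result = remote(my_fn).call_on_shard_and_fetch(x, y).result()
    result = call_on_shard_and_fetch(remote(my_fn), x, y).result()  # 2025.7.25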