torchmonarch-nightly 2025.7.25__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.27__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monarch/_rust_bindings.so CHANGED
Binary file
@@ -29,8 +29,10 @@ from typing import (
29
29
  Iterable,
30
30
  Iterator,
31
31
  List,
32
+ Literal,
32
33
  NamedTuple,
33
34
  Optional,
35
+ overload,
34
36
  ParamSpec,
35
37
  Tuple,
36
38
  Type,
@@ -39,6 +41,7 @@ from typing import (
39
41
  )
40
42
 
41
43
  from monarch._rust_bindings.monarch_hyperactor.actor import (
44
+ MethodSpecifier,
42
45
  PanicFlag,
43
46
  PythonMessage,
44
47
  PythonMessageKind,
@@ -65,6 +68,7 @@ from monarch._src.actor.endpoint import (
65
68
  Endpoint,
66
69
  EndpointProperty,
67
70
  Extent,
71
+ NotAnEndpoint,
68
72
  Propagator,
69
73
  Selection,
70
74
  )
@@ -76,7 +80,7 @@ from monarch._src.actor.pickle import flatten, unflatten
76
80
  from monarch._src.actor.shape import MeshTrait, NDSlice
77
81
  from monarch._src.actor.sync_state import fake_sync_state
78
82
 
79
- from monarch._src.actor.tensor_engine_shim import actor_send
83
+ from monarch._src.actor.tensor_engine_shim import actor_rref, actor_send
80
84
 
81
85
  if TYPE_CHECKING:
82
86
  from monarch._src.actor.proc_mesh import ProcMesh
@@ -281,16 +285,18 @@ class ActorEndpoint(Endpoint[P, R]):
def __init__(
    self,
    actor_mesh_ref: _ActorMeshRefImpl,
    name: MethodSpecifier,
    impl: Callable[Concatenate[Any, P], Awaitable[R]],
    mailbox: Mailbox,
    propagator: Propagator,
    explicit_response_port: bool,
) -> None:
    """Bind a method of an actor class to a mesh so it can be called as an endpoint.

    Args:
        actor_mesh_ref: the mesh whose actors receive the call messages.
        name: method specifier (``Init``, ``ReturnsResponse`` or
            ``ExplicitPort``) sent along with each call.
        impl: local function whose signature is captured to validate call
            arguments before anything is sent.
        mailbox: mailbox used when building response port receivers.
        propagator: propagation strategy, forwarded to ``Endpoint.__init__``.
        explicit_response_port: when true, ``impl`` takes a response port as
            its first argument after ``self``.
    """
    super().__init__(propagator)
    # Inspect the implementation once; every call re-uses this signature.
    signature: inspect.Signature = inspect.signature(impl)
    self._actor_mesh = actor_mesh_ref
    self._name = name
    self._signature: inspect.Signature = signature
    self._mailbox = mailbox
    self._explicit_response_port = explicit_response_port
294
300
 
295
301
  def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
296
302
  mesh = self._actor_mesh._actor_mesh
@@ -299,6 +305,12 @@ class ActorEndpoint(Endpoint[P, R]):
def _call_name(self) -> Any:
    """Return the method specifier this endpoint was constructed with."""
    specifier = self._name
    return specifier
301
307
 
308
def _check_arguments(self, args, kwargs):
    """Check that ``args``/``kwargs`` bind to the wrapped method's signature.

    ``inspect.Signature.bind`` raises ``TypeError`` when they do not. ``None``
    placeholders stand in for ``self`` and, for explicit-response-port
    endpoints, for the port argument that the receiving side prepends.
    """
    leading = (None, None) if self._explicit_response_port else (None,)
    self._signature.bind(*leading, *args, **kwargs)
313
+
302
314
  def _send(
303
315
  self,
304
316
  args: Tuple[Any, ...],
@@ -311,10 +323,9 @@ class ActorEndpoint(Endpoint[P, R]):
311
323
 
312
324
  This sends the message to all actors but does not wait for any result.
313
325
  """
314
- self._signature.bind(None, *args, **kwargs)
326
+ self._check_arguments(args, kwargs)
315
327
  objects, bytes = flatten((args, kwargs), _is_ref_or_mailbox)
316
- refs = [obj for obj in objects if hasattr(obj, "__monarch_ref__")]
317
- if not refs:
328
+ if all(not hasattr(obj, "__monarch_ref__") for obj in objects):
318
329
  message = PythonMessage(
319
330
  PythonMessageKind.CallMethod(
320
331
  self._name, None if port is None else port._port_ref
@@ -323,7 +334,7 @@ class ActorEndpoint(Endpoint[P, R]):
323
334
  )
324
335
  self._actor_mesh.cast(message, selection)
325
336
  else:
326
- actor_send(self, bytes, refs, port, selection)
337
+ actor_send(self, bytes, objects, port, selection)
327
338
  shape = self._actor_mesh._shape
328
339
  return Extent(shape.labels, shape.ndslice.sizes)
329
340
 
@@ -335,6 +346,53 @@ class ActorEndpoint(Endpoint[P, R]):
335
346
  ), "unexpected receiver type"
336
347
  return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
337
348
 
349
def _rref(self, args, kwargs):
    """Validate a call and dispatch it through ``actor_rref``.

    The arguments are flattened with ``_is_ref_or_mailbox``. The objects it
    selects are passed to ``actor_rref`` separately from the serialized
    ``(args, kwargs)`` bytes.
    """
    self._check_arguments(args, kwargs)
    extracted, payload = flatten((args, kwargs), _is_ref_or_mailbox)
    return actor_rref(self, payload, extracted)
354
+
355
+
356
@overload
def as_endpoint(
    not_an_endpoint: Callable[P, R],
    *,
    propagate: Propagator = None,
    explicit_response_port: Literal[False] = False,
) -> Endpoint[P, R]: ...


@overload
def as_endpoint(
    not_an_endpoint: Callable[Concatenate["PortProtocol[R]", P], None],
    *,
    propagate: Propagator = None,
    explicit_response_port: Literal[True],
) -> Endpoint[P, R]: ...


def as_endpoint(
    not_an_endpoint: Any,
    *,
    propagate: Propagator = None,
    explicit_response_port: bool = False,
):
    """Wrap a non-endpoint method of a spawned actor so it can be called remotely.

    Looking up a method that is not decorated with ``@endpoint`` on an
    ``ActorMeshRef`` yields a ``NotAnEndpoint`` placeholder. This turns that
    placeholder into an ``ActorEndpoint``, e.g. ``as_endpoint(mesh.method).call(...)``.

    Args:
        not_an_endpoint: the ``NotAnEndpoint`` placeholder taken from an
            ``ActorMeshRef`` attribute.
        propagate: propagator for the resulting endpoint.
        explicit_response_port: if true, the method receives a response port
            as its first argument (after ``self``) instead of returning its result.

    Returns:
        An ``ActorEndpoint`` for the method.

    Raises:
        ValueError: if ``not_an_endpoint`` is not a ``NotAnEndpoint``.
    """
    if not isinstance(not_an_endpoint, NotAnEndpoint):
        raise ValueError("expected a method of a spawned actor")
    ref = not_an_endpoint._ref
    name = not_an_endpoint._name
    kind = (
        MethodSpecifier.ExplicitPort
        if explicit_response_port
        else MethodSpecifier.ReturnsResponse
    )
    # Take the implementation from the actor class itself. ``getattr(ref, name)``
    # would go through ``ActorMeshRef.__getattr__`` and return the
    # ``NotAnEndpoint`` placeholder again. Its ``(*args, **kwargs)`` signature
    # would make the endpoint's argument validation accept anything.
    impl = getattr(ref._class, name)
    return ActorEndpoint(
        ref._actor_mesh_ref,
        kind(name),
        impl,
        ref._mailbox,
        propagate,
        explicit_response_port,
    )
395
+
338
396
 
339
397
  class Accumulator(Generic[P, R, A]):
340
398
  def __init__(
@@ -578,7 +636,7 @@ class _Actor:
578
636
  mailbox: Mailbox,
579
637
  rank: int,
580
638
  shape: Shape,
581
- method: str,
639
+ method_spec: MethodSpecifier,
582
640
  message: bytes,
583
641
  panic_flag: PanicFlag,
584
642
  local_state: Iterable[Any],
@@ -596,17 +654,23 @@ class _Actor:
596
654
 
597
655
  args, kwargs = unflatten(message, local_state)
598
656
 
599
- if method == "__init__":
600
- Class, *args = args
601
- try:
602
- self.instance = Class(*args, **kwargs)
603
- except Exception as e:
604
- self._saved_error = ActorError(
605
- e, f"Remote actor {Class}.__init__ call failed."
606
- )
607
- raise e
608
- port.send(None)
609
- return None
657
+ match method_spec:
658
+ case MethodSpecifier.Init():
659
+ Class, *args = args
660
+ try:
661
+ self.instance = Class(*args, **kwargs)
662
+ except Exception as e:
663
+ self._saved_error = ActorError(
664
+ e, f"Remote actor {Class}.__init__ call failed."
665
+ )
666
+ raise e
667
+ port.send(None)
668
+ return None
669
+ case MethodSpecifier.ReturnsResponse(name=method):
670
+ pass
671
+ case MethodSpecifier.ExplicitPort(name=method):
672
+ args = (port, *args)
673
+ port = DroppingPort()
610
674
 
611
675
  if self.instance is None:
612
676
  # This could happen because of the following reasons. Both
@@ -625,18 +689,23 @@ class _Actor:
625
689
  f" This is likely due to an earlier error: {self._saved_error}"
626
690
  )
627
691
  raise AssertionError(error_message)
628
- the_method = getattr(self.instance, method)._method
692
+ the_method = getattr(self.instance, method)
693
+ if isinstance(the_method, EndpointProperty):
694
+ module = the_method._method.__module__
695
+ the_method = functools.partial(the_method._method, self.instance)
696
+ else:
697
+ module = the_method.__module__
629
698
 
630
699
  if inspect.iscoroutinefunction(the_method):
631
700
 
632
701
  async def instrumented():
633
702
  enter_span(
634
- the_method.__module__,
703
+ module,
635
704
  method,
636
705
  str(ctx.mailbox.actor_id),
637
706
  )
638
707
  try:
639
- result = await the_method(self.instance, *args, **kwargs)
708
+ result = await the_method(*args, **kwargs)
640
709
  self._maybe_exit_debugger()
641
710
  except Exception as e:
642
711
  logging.critical(
@@ -649,9 +718,9 @@ class _Actor:
649
718
 
650
719
  result = await instrumented()
651
720
  else:
652
- enter_span(the_method.__module__, method, str(ctx.mailbox.actor_id))
721
+ enter_span(module, method, str(ctx.mailbox.actor_id))
653
722
  with fake_sync_state():
654
- result = the_method(self.instance, *args, **kwargs)
723
+ result = the_method(*args, **kwargs)
655
724
  self._maybe_exit_debugger()
656
725
  exit_span()
657
726
 
@@ -750,43 +819,29 @@ class ActorMeshRef(MeshTrait):
750
819
  for attr_name in dir(self._class):
751
820
  attr_value = getattr(self._class, attr_name, None)
752
821
  if isinstance(attr_value, EndpointProperty):
822
+ # Convert string method name to appropriate MethodSpecifier
823
+ kind = (
824
+ MethodSpecifier.ExplicitPort
825
+ if attr_value._explicit_response_port
826
+ else MethodSpecifier.ReturnsResponse
827
+ )
753
828
  setattr(
754
829
  self,
755
830
  attr_name,
756
831
  ActorEndpoint(
757
832
  self._actor_mesh_ref,
758
- attr_name,
833
+ kind(attr_name),
759
834
  attr_value._method,
760
835
  self._mailbox,
836
+ attr_value._propagator,
837
+ attr_value._explicit_response_port,
761
838
  ),
762
839
  )
763
840
 
764
def __getattr__(self, attr: str) -> NotAnEndpoint:
    """Return a ``NotAnEndpoint`` placeholder for actor methods that are not endpoints.

    Python only calls this when normal lookup fails. Endpoint methods are
    installed as ``ActorEndpoint`` instance attributes, so they never reach
    here. A name defined on the actor class yields a placeholder that raises a
    helpful error when called and can be wrapped with ``as_endpoint``. Any
    other name is a genuine ``AttributeError``.
    """
    # Read ``_class`` without re-entering ``__getattr__``. If ``_class`` is not
    # set yet, ``self._class`` would call ``__getattr__("_class")`` again and
    # recurse until RecursionError instead of raising AttributeError.
    # NOTE(review): this can happen on a partially constructed instance, e.g.
    # during default unpickling/copy lookups of ``__setstate__``. Confirm
    # whether ActorMeshRef defines its own reduce.
    try:
        actor_class = object.__getattribute__(self, "_class")
    except AttributeError:
        raise AttributeError(attr) from None
    if attr in dir(actor_class):
        return NotAnEndpoint(self, attr)
    raise AttributeError(attr)
790
845
 
791
846
  def _create(
792
847
  self,
@@ -798,9 +853,11 @@ class ActorMeshRef(MeshTrait):
798
853
 
799
854
  ep = ActorEndpoint(
800
855
  self._actor_mesh_ref,
801
- "__init__",
856
+ MethodSpecifier.Init(),
802
857
  null_func,
803
858
  self._mailbox,
859
+ None,
860
+ False,
804
861
  )
805
862
  send(ep, (self._class, *args), kwargs)
806
863
 
@@ -34,6 +34,7 @@ from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
34
34
 
35
35
  if TYPE_CHECKING:
36
36
  from monarch._src.actor.actor_mesh import (
37
+ ActorMeshRef,
37
38
  HyPortReceiver,
38
39
  OncePortReceiver,
39
40
  Port,
@@ -182,11 +183,22 @@ class Endpoint(ABC, Generic[P, R]):
182
183
  # pyre-ignore
183
184
  send(self, args, kwargs)
184
185
 
186
@abstractmethod
def _rref(self, args, kwargs) -> Any:
    """Subclass hook behind ``rref``.

    Args:
        args: positional call arguments, as a tuple.
        kwargs: keyword call arguments, as a dict.
    """
188
+
189
def rref(self, *args: P.args, **kwargs: P.kwargs) -> R:
    """Invoke the endpoint through the subclass's ``_rref`` hook.

    The variadic arguments are packed into a tuple and a dict and handed to
    ``_rref`` unchanged.
    """
    packed_args, packed_kwargs = args, kwargs
    return self._rref(packed_args, packed_kwargs)
191
+
185
192
  def _propagate(self, args, kwargs, fake_args, fake_kwargs):
186
193
  if self._propagator_arg is None or self._propagator_arg == "cached":
187
194
  if self._cache is None:
188
195
  self._cache = {}
189
- return _cached_propagation(self._cache, self._resolvable, args, kwargs)
196
+ resolvable = getattr(self, "_resolvable", None)
197
+ if resolvable is None:
198
+ raise NotImplementedError(
199
+ "Cached propagation is not implemented for actor endpoints."
200
+ )
201
+ return _cached_propagation(self._cache, resolvable, args, kwargs)
190
202
  elif self._propagator_arg == "inspect":
191
203
  return None
192
204
  elif self._propagator_arg == "mocked":
@@ -211,16 +223,23 @@ class EndpointProperty(Generic[P, R]):
211
223
  self,
212
224
  method: Callable[Concatenate[Any, P], Awaitable[R]],
213
225
  propagator: Propagator,
226
+ explicit_response_port: bool,
214
227
  ) -> None: ...
215
228
 
@overload
def __init__(
    self,
    method: Callable[Concatenate[Any, P], R],
    propagator: Propagator,
    explicit_response_port: bool,
) -> None: ...

def __init__(
    self, method: Any, propagator: Propagator, explicit_response_port: bool
) -> None:
    """Store the decorated method together with its ``@endpoint`` options.

    Args:
        method: the undecorated function (its first parameter is ``self``).
        propagator: propagation strategy for calls through this endpoint.
        explicit_response_port: whether ``method`` replies through an explicit
            port argument rather than its return value.
    """
    self._explicit_response_port = explicit_response_port
    self._propagator = propagator
    self._method = method
224
243
 
225
244
  def __get__(self, instance, owner) -> Endpoint[P, R]:
226
245
  # this is a total lie, but we have to actually
@@ -229,13 +248,50 @@ class EndpointProperty(Generic[P, R]):
229
248
  return cast(Endpoint[P, R], self)
230
249
 
231
250
 
251
class NotAnEndpoint:
    """
    Used as the dynamic value of functions on an ActorMeshRef that were not marked as endpoints.

    This serves two purposes. It gives a better error message when such a method
    is called, since we cannot stop the type system from treating it as a method.
    It also gives callers the opportunity to opt in explicitly with
    ``as_endpoint(x.foo)`` on a method that wasn't marked as an endpoint.

    ``_ref`` is the mesh reference the method was looked up on; ``_name`` is the
    method's name.
    """

    def __init__(self, ref: "ActorMeshRef", name: str):
        self._ref = ref
        self._name = name

    def __repr__(self) -> str:
        # Only uses ``_name`` so repr() cannot fail on an unusual ``ref``.
        return f"{type(self).__name__}({self._name!r})"

    def __call__(self, *args, **kwargs) -> None:
        """Always raise: the method is not an endpoint and cannot be called directly."""
        raise RuntimeError(
            f"Actor {self._ref._class}.{self._name} is not annotated as an endpoint. To call it as one, add a @endpoint decorator to it, or directly wrap it in one as_endpoint(obj.method).call(...)"
        )
266
+
267
+
# This can't just be Callable because otherwise we are not
# allowed to use type arguments in the return value.
class EndpointIfy:
    """Static return type of ``endpoint(propagate=...)`` called without a method.

    The overloads describe the decorator's typing: a method whose first
    parameter is ``self`` becomes an ``Endpoint`` over its remaining
    parameters. At runtime ``endpoint`` returns a ``functools.partial``
    instead; the ``__call__`` body here is a no-op placeholder.
    """

    @overload
    def __call__(
        self, function: Callable[Concatenate[Any, P], Awaitable[R]]
    ) -> Endpoint[P, R]: ...

    @overload
    def __call__(
        self, function: Callable[Concatenate[Any, P], R]
    ) -> Endpoint[P, R]: ...

    def __call__(self, function: Any):
        return None
282
+
283
+
284
class PortedEndpointIfy:
    """Static return type of ``endpoint(explicit_response_port=True)`` without a method.

    The decorated function takes a response port (``Port[R]``) right after
    ``self`` and returns nothing (or an awaitable of nothing). The resulting
    endpoint's result type is the port's ``R``. The ``__call__`` body is a no-op
    placeholder.
    """

    @overload
    def __call__(
        self,
        function: Callable[Concatenate[Any, "Port[R]", P], Awaitable[None]],
    ) -> Endpoint[P, R]: ...

    @overload
    def __call__(
        self, function: Callable[Concatenate[Any, "Port[R]", P], None]
    ) -> Endpoint[P, R]: ...

    def __call__(self, function: Any):
        return None
@@ -246,6 +302,7 @@ def endpoint(
246
302
  method: Callable[Concatenate[Any, P], Awaitable[R]],
247
303
  *,
248
304
  propagate: Propagator = None,
305
+ explicit_response_port: Literal[False] = False,
249
306
  ) -> EndpointProperty[P, R]: ...
250
307
 
251
308
 
@@ -254,6 +311,7 @@ def endpoint(
254
311
  method: Callable[Concatenate[Any, P], R],
255
312
  *,
256
313
  propagate: Propagator = None,
314
+ explicit_response_port: Literal[False] = False,
257
315
  ) -> EndpointProperty[P, R]: ...
258
316
 
259
317
 
@@ -261,10 +319,43 @@ def endpoint(
261
319
  def endpoint(
262
320
  *,
263
321
  propagate: Propagator = None,
322
+ explicit_response_port: Literal[False] = False,
264
323
  ) -> EndpointIfy: ...
265
324
 
266
325
 
@overload
def endpoint(
    method: Callable[Concatenate[Any, "Port[R]", P], Awaitable[None]],
    *,
    propagate: Propagator = None,
    explicit_response_port: Literal[True],
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    method: Callable[Concatenate[Any, "Port[R]", P], None],
    *,
    propagate: Propagator = None,
    explicit_response_port: Literal[True],
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    *,
    propagate: Propagator = None,
    explicit_response_port: Literal[True],
) -> PortedEndpointIfy: ...


def endpoint(method=None, *, propagate=None, explicit_response_port: bool = False):
    """Mark an actor method as a remotely callable endpoint.

    Works bare (``@endpoint``) or with options
    (``@endpoint(propagate=..., explicit_response_port=True)``). In the
    second form ``method`` is ``None``, and a decorator with the options
    pre-bound is returned.

    Args:
        method: the method being decorated, or ``None`` when only options are given.
        propagate: propagator stored on the resulting ``EndpointProperty``.
        explicit_response_port: when true, the method takes a response port as
            its first argument after ``self`` and replies through it.
    """
    if method is not None:
        return EndpointProperty(
            method,
            propagator=propagate,
            explicit_response_port=explicit_response_port,
        )
    # Used as ``@endpoint(...)``: wait for the method to be supplied.
    return functools.partial(
        endpoint,
        propagate=propagate,
        explicit_response_port=explicit_response_port,
    )
@@ -14,7 +14,7 @@ import logging
14
14
  import threading
15
15
  from typing import Optional
16
16
 
17
- from libfb.py.pyre import none_throws
17
+ from pyre_extensions import none_throws
18
18
 
19
19
  logger = logging.getLogger(__name__)
20
20
 
@@ -43,7 +43,6 @@ from monarch._src.actor.actor_mesh import (
43
43
  Actor,
44
44
  ActorMeshRef,
45
45
  fake_sync_state,
46
- MonarchContext,
47
46
  )
48
47
 
49
48
  from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator
@@ -89,7 +88,7 @@ class SetupActor(Actor):
89
88
  Typically used to setup the environment variables.
90
89
  """
91
90
 
92
- def __init__(self, env: Callable[[MonarchContext], None]) -> None:
91
+ def __init__(self, env: Callable[[], None]) -> None:
93
92
  """
94
93
  Initialize the setup actor with the user defined setup method.
95
94
  """
@@ -100,8 +99,7 @@ class SetupActor(Actor):
100
99
  """
101
100
  Call the user defined setup method with the monarch context.
102
101
  """
103
- ctx = MonarchContext.get()
104
- self._setup_method(ctx)
102
+ self._setup_method()
105
103
 
106
104
 
107
105
  T = TypeVar("T")
@@ -114,7 +112,7 @@ except ImportError:
114
112
 
115
113
 
116
114
  async def _allocate_nonblocking(
117
- alloc: Alloc, setup: Callable[[MonarchContext], None] | None = None
115
+ alloc: Alloc, setup: Callable[[], None] | None = None
118
116
  ) -> "ProcMesh":
119
117
  _proc_mesh = await HyProcMesh.allocate_nonblocking(alloc)
120
118
  if setup is None:
@@ -211,7 +209,7 @@ class ProcMesh(MeshTrait):
211
209
 
212
210
  @classmethod
213
211
  def from_alloc(
214
- self, alloc: Alloc, setup: Callable[[MonarchContext], None] | None = None
212
+ self, alloc: Alloc, setup: Callable[[], None] | None = None
215
213
  ) -> Future["ProcMesh"]:
216
214
  """
217
215
  Allocate a process mesh according to the provided alloc.
@@ -219,7 +217,17 @@ class ProcMesh(MeshTrait):
219
217
 
220
218
  Arguments:
221
219
  - `alloc`: The alloc to allocate according to.
222
- - `setup`: A lambda taking MonarchContext as param, can be used to setup env vars on the allocated mesh
220
+ - `setup`: An optional lambda function to configure environment variables on the allocated mesh.
221
+ Use the `current_rank()` method within the lambda to obtain the rank.
222
+
223
+ Example of a setup method to initialize torch distributed environment variables:
224
+ ```
225
+ def setup():
226
+ rank = current_rank()
227
+ os.environ["RANK"] = str(rank)
228
+ os.environ["WORLD_SIZE"] = str(len(rank.shape))
229
+ os.environ["LOCAL_RANK"] = str(rank["gpus"])
230
+ ```
223
231
  """
224
232
  return Future(
225
233
  impl=lambda: _allocate_nonblocking(alloc, setup),
@@ -428,7 +436,7 @@ async def proc_mesh_nonblocking(
428
436
  gpus: Optional[int] = None,
429
437
  hosts: int = 1,
430
438
  env: dict[str, str] | None = None,
431
- setup: Callable[[MonarchContext], None] | None = None,
439
+ setup: Callable[[], None] | None = None,
432
440
  ) -> ProcMesh:
433
441
  if gpus is None:
434
442
  gpus = _local_device_count()
@@ -457,7 +465,7 @@ def proc_mesh(
457
465
  gpus: Optional[int] = None,
458
466
  hosts: int = 1,
459
467
  env: dict[str, str] | None = None,
460
- setup: Callable[[MonarchContext], None] | None = None,
468
+ setup: Callable[[], None] | None = None,
461
469
  ) -> Future[ProcMesh]:
462
470
  return Future(
463
471
  impl=lambda: proc_mesh_nonblocking(
@@ -19,7 +19,6 @@ time it is used.
19
19
 
20
20
  if TYPE_CHECKING:
21
21
  from monarch._src.actor.actor_mesh import ActorEndpoint, Port, Selection
22
- from monarch._src.actor.endpoint import Endpoint
23
22
 
24
23
 
25
24
  def shim(fn=None, *, module=None):
@@ -48,8 +47,12 @@ def actor_send(
48
47
  ) -> None: ...
49
48
 
50
49
 
50
@shim(module="monarch.mesh_controller")
def actor_rref(
    endpoint,
    args_kwargs_tuple: bytes,
    refs: Sequence[Any],
):
    """Placeholder for ``actor_rref``, declared through ``shim``.

    ``shim`` is given ``module="monarch.mesh_controller"``, presumably the
    home of the real implementation (confirm against ``shim``). Callers pass
    the endpoint, the serialized ``(args, kwargs)`` bytes, and the objects
    extracted while flattening them.
    """
    ...
52
+
53
+
@shim(module="monarch.common.remote")
def _cached_propagation(_cache, rfunction, args, kwargs) -> Any:
    """Placeholder for ``_cached_propagation``, declared through ``shim``.

    ``rfunction`` is left unannotated here. The ``monarch.common.remote``
    implementation types it as ``ResolvableFunction``.
    """
    ...
53
56
 
54
57
 
55
58
  @shim(module="monarch.common.fake")
monarch/actor/__init__.py CHANGED
@@ -12,6 +12,7 @@ from monarch._src.actor.actor_mesh import (
12
12
  Accumulator,
13
13
  Actor,
14
14
  ActorError,
15
+ as_endpoint,
15
16
  current_actor_name,
16
17
  current_rank,
17
18
  current_size,
@@ -35,6 +36,7 @@ __all__ = [
35
36
  "Actor",
36
37
  "ActorError",
37
38
  "current_actor_name",
39
+ "as_endpoint",
38
40
  "current_rank",
39
41
  "current_size",
40
42
  "endpoint",
@@ -435,6 +435,15 @@ class SendResultOfActorCall(NamedTuple):
435
435
  stream: tensor_worker.StreamRef
436
436
 
437
437
 
438
class CallActorMethod(NamedTuple):
    """Tensor-engine message describing a call to an actor method.

    NOTE(review): the field notes below are inferred from names and types
    only; confirm them against the worker that consumes this message.
    """

    seq: int  # presumably the message sequence number — confirm
    result: object  # call result or its destination — TODO confirm
    broker_id: Tuple[str, int]  # (str, int) identifier of the broker
    local_state: Sequence[Tensor | tensor_worker.Ref]  # tensors/refs supplied to the call
    mutates: List[tensor_worker.Ref]  # refs the call may mutate — confirm
    stream: tensor_worker.StreamRef  # stream the call is issued on — confirm
445
+
446
+
438
447
  class SplitComm(NamedTuple):
439
448
  dims: Dims
440
449
  device_mesh: DeviceMesh
monarch/common/remote.py CHANGED
@@ -157,7 +157,7 @@ class Remote(Generic[P, R], Endpoint[P, R]):
157
157
  def _maybe_resolvable(self):
158
158
  return None if self._remote_impl is None else self._resolvable
159
159
 
160
- def rref(self, *args: P.args, **kwargs: P.kwargs) -> R:
160
+ def _rref(self, args, kwargs):
161
161
  return dtensor_dispatch(
162
162
  self._resolvable,
163
163
  self._propagate,
@@ -352,7 +352,7 @@ _miss = 0
352
352
  _hit = 0
353
353
 
354
354
 
355
- def _cached_propagation(_cache, rfunction: Endpoint, args, kwargs):
355
+ def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):
356
356
  tensors, shape_key = hashable_tensor_flatten(args, kwargs)
357
357
  # pyre-ignore
358
358
  inputs_group = TensorGroup([t._fake for t in tensors])
Binary file