torchmonarch-nightly 2025.7.25__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/_rust_bindings.so +0 -0
- monarch/_src/actor/actor_mesh.py +109 -52
- monarch/_src/actor/endpoint.py +99 -8
- monarch/_src/actor/event_loop.py +1 -1
- monarch/_src/actor/proc_mesh.py +17 -9
- monarch/_src/actor/tensor_engine_shim.py +5 -2
- monarch/actor/__init__.py +2 -0
- monarch/common/messages.py +9 -0
- monarch/common/remote.py +2 -2
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/mesh_controller.py +76 -14
- monarch/monarch_controller +0 -0
- monarch/tools/cli.py +2 -2
- monarch/tools/commands.py +49 -27
- monarch/tools/components/hyperactor.py +5 -3
- monarch/tools/config/__init__.py +18 -1
- monarch/tools/config/defaults.py +2 -2
- monarch/tools/mesh_spec.py +4 -1
- tests/test_allocator.py +11 -15
- tests/test_env_before_cuda.py +2 -3
- tests/test_python_actors.py +12 -0
- tests/test_tensor_engine.py +27 -1
- {torchmonarch_nightly-2025.7.25.dist-info → torchmonarch_nightly-2025.7.27.dist-info}/METADATA +34 -1
- {torchmonarch_nightly-2025.7.25.dist-info → torchmonarch_nightly-2025.7.27.dist-info}/RECORD +28 -28
- {torchmonarch_nightly-2025.7.25.dist-info → torchmonarch_nightly-2025.7.27.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.7.25.dist-info → torchmonarch_nightly-2025.7.27.dist-info}/entry_points.txt +0 -0
- {torchmonarch_nightly-2025.7.25.dist-info → torchmonarch_nightly-2025.7.27.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.7.25.dist-info → torchmonarch_nightly-2025.7.27.dist-info}/top_level.txt +0 -0
monarch/_rust_bindings.so
CHANGED
Binary file
|
monarch/_src/actor/actor_mesh.py
CHANGED
@@ -29,8 +29,10 @@ from typing import (
|
|
29
29
|
Iterable,
|
30
30
|
Iterator,
|
31
31
|
List,
|
32
|
+
Literal,
|
32
33
|
NamedTuple,
|
33
34
|
Optional,
|
35
|
+
overload,
|
34
36
|
ParamSpec,
|
35
37
|
Tuple,
|
36
38
|
Type,
|
@@ -39,6 +41,7 @@ from typing import (
|
|
39
41
|
)
|
40
42
|
|
41
43
|
from monarch._rust_bindings.monarch_hyperactor.actor import (
|
44
|
+
MethodSpecifier,
|
42
45
|
PanicFlag,
|
43
46
|
PythonMessage,
|
44
47
|
PythonMessageKind,
|
@@ -65,6 +68,7 @@ from monarch._src.actor.endpoint import (
|
|
65
68
|
Endpoint,
|
66
69
|
EndpointProperty,
|
67
70
|
Extent,
|
71
|
+
NotAnEndpoint,
|
68
72
|
Propagator,
|
69
73
|
Selection,
|
70
74
|
)
|
@@ -76,7 +80,7 @@ from monarch._src.actor.pickle import flatten, unflatten
|
|
76
80
|
from monarch._src.actor.shape import MeshTrait, NDSlice
|
77
81
|
from monarch._src.actor.sync_state import fake_sync_state
|
78
82
|
|
79
|
-
from monarch._src.actor.tensor_engine_shim import actor_send
|
83
|
+
from monarch._src.actor.tensor_engine_shim import actor_rref, actor_send
|
80
84
|
|
81
85
|
if TYPE_CHECKING:
|
82
86
|
from monarch._src.actor.proc_mesh import ProcMesh
|
@@ -281,16 +285,18 @@ class ActorEndpoint(Endpoint[P, R]):
|
|
281
285
|
def __init__(
|
282
286
|
self,
|
283
287
|
actor_mesh_ref: _ActorMeshRefImpl,
|
284
|
-
name:
|
288
|
+
name: MethodSpecifier,
|
285
289
|
impl: Callable[Concatenate[Any, P], Awaitable[R]],
|
286
290
|
mailbox: Mailbox,
|
287
|
-
propagator: Propagator
|
291
|
+
propagator: Propagator,
|
292
|
+
explicit_response_port: bool,
|
288
293
|
) -> None:
|
289
294
|
super().__init__(propagator)
|
290
295
|
self._actor_mesh = actor_mesh_ref
|
291
296
|
self._name = name
|
292
297
|
self._signature: inspect.Signature = inspect.signature(impl)
|
293
298
|
self._mailbox = mailbox
|
299
|
+
self._explicit_response_port = explicit_response_port
|
294
300
|
|
295
301
|
def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
|
296
302
|
mesh = self._actor_mesh._actor_mesh
|
@@ -299,6 +305,12 @@ class ActorEndpoint(Endpoint[P, R]):
|
|
299
305
|
def _call_name(self) -> Any:
|
300
306
|
return self._name
|
301
307
|
|
308
|
+
def _check_arguments(self, args, kwargs):
|
309
|
+
if self._explicit_response_port:
|
310
|
+
self._signature.bind(None, None, *args, **kwargs)
|
311
|
+
else:
|
312
|
+
self._signature.bind(None, *args, **kwargs)
|
313
|
+
|
302
314
|
def _send(
|
303
315
|
self,
|
304
316
|
args: Tuple[Any, ...],
|
@@ -311,10 +323,9 @@ class ActorEndpoint(Endpoint[P, R]):
|
|
311
323
|
|
312
324
|
This sends the message to all actors but does not wait for any result.
|
313
325
|
"""
|
314
|
-
self.
|
326
|
+
self._check_arguments(args, kwargs)
|
315
327
|
objects, bytes = flatten((args, kwargs), _is_ref_or_mailbox)
|
316
|
-
|
317
|
-
if not refs:
|
328
|
+
if all(not hasattr(obj, "__monarch_ref__") for obj in objects):
|
318
329
|
message = PythonMessage(
|
319
330
|
PythonMessageKind.CallMethod(
|
320
331
|
self._name, None if port is None else port._port_ref
|
@@ -323,7 +334,7 @@ class ActorEndpoint(Endpoint[P, R]):
|
|
323
334
|
)
|
324
335
|
self._actor_mesh.cast(message, selection)
|
325
336
|
else:
|
326
|
-
actor_send(self, bytes,
|
337
|
+
actor_send(self, bytes, objects, port, selection)
|
327
338
|
shape = self._actor_mesh._shape
|
328
339
|
return Extent(shape.labels, shape.ndslice.sizes)
|
329
340
|
|
@@ -335,6 +346,53 @@ class ActorEndpoint(Endpoint[P, R]):
|
|
335
346
|
), "unexpected receiver type"
|
336
347
|
return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
|
337
348
|
|
349
|
+
def _rref(self, args, kwargs):
|
350
|
+
self._check_arguments(args, kwargs)
|
351
|
+
refs, bytes = flatten((args, kwargs), _is_ref_or_mailbox)
|
352
|
+
|
353
|
+
return actor_rref(self, bytes, refs)
|
354
|
+
|
355
|
+
|
356
|
+
@overload
|
357
|
+
def as_endpoint(
|
358
|
+
not_an_endpoint: Callable[P, R],
|
359
|
+
*,
|
360
|
+
propagate: Propagator = None,
|
361
|
+
explicit_response_port: Literal[False] = False,
|
362
|
+
) -> Endpoint[P, R]: ...
|
363
|
+
|
364
|
+
|
365
|
+
@overload
|
366
|
+
def as_endpoint(
|
367
|
+
not_an_endpoint: Callable[Concatenate["PortProtocol[R]", P], None],
|
368
|
+
*,
|
369
|
+
propagate: Propagator = None,
|
370
|
+
explicit_response_port: Literal[True],
|
371
|
+
) -> Endpoint[P, R]: ...
|
372
|
+
|
373
|
+
|
374
|
+
def as_endpoint(
|
375
|
+
not_an_endpoint: Any,
|
376
|
+
*,
|
377
|
+
propagate: Propagator = None,
|
378
|
+
explicit_response_port: bool = False,
|
379
|
+
):
|
380
|
+
if not isinstance(not_an_endpoint, NotAnEndpoint):
|
381
|
+
raise ValueError("expected an method of a spawned actor")
|
382
|
+
kind = (
|
383
|
+
MethodSpecifier.ExplicitPort
|
384
|
+
if explicit_response_port
|
385
|
+
else MethodSpecifier.ReturnsResponse
|
386
|
+
)
|
387
|
+
return ActorEndpoint(
|
388
|
+
not_an_endpoint._ref._actor_mesh_ref,
|
389
|
+
kind(not_an_endpoint._name),
|
390
|
+
getattr(not_an_endpoint._ref, not_an_endpoint._name),
|
391
|
+
not_an_endpoint._ref._mailbox,
|
392
|
+
propagate,
|
393
|
+
explicit_response_port,
|
394
|
+
)
|
395
|
+
|
338
396
|
|
339
397
|
class Accumulator(Generic[P, R, A]):
|
340
398
|
def __init__(
|
@@ -578,7 +636,7 @@ class _Actor:
|
|
578
636
|
mailbox: Mailbox,
|
579
637
|
rank: int,
|
580
638
|
shape: Shape,
|
581
|
-
|
639
|
+
method_spec: MethodSpecifier,
|
582
640
|
message: bytes,
|
583
641
|
panic_flag: PanicFlag,
|
584
642
|
local_state: Iterable[Any],
|
@@ -596,17 +654,23 @@ class _Actor:
|
|
596
654
|
|
597
655
|
args, kwargs = unflatten(message, local_state)
|
598
656
|
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
657
|
+
match method_spec:
|
658
|
+
case MethodSpecifier.Init():
|
659
|
+
Class, *args = args
|
660
|
+
try:
|
661
|
+
self.instance = Class(*args, **kwargs)
|
662
|
+
except Exception as e:
|
663
|
+
self._saved_error = ActorError(
|
664
|
+
e, f"Remote actor {Class}.__init__ call failed."
|
665
|
+
)
|
666
|
+
raise e
|
667
|
+
port.send(None)
|
668
|
+
return None
|
669
|
+
case MethodSpecifier.ReturnsResponse(name=method):
|
670
|
+
pass
|
671
|
+
case MethodSpecifier.ExplicitPort(name=method):
|
672
|
+
args = (port, *args)
|
673
|
+
port = DroppingPort()
|
610
674
|
|
611
675
|
if self.instance is None:
|
612
676
|
# This could happen because of the following reasons. Both
|
@@ -625,18 +689,23 @@ class _Actor:
|
|
625
689
|
f" This is likely due to an earlier error: {self._saved_error}"
|
626
690
|
)
|
627
691
|
raise AssertionError(error_message)
|
628
|
-
the_method = getattr(self.instance, method)
|
692
|
+
the_method = getattr(self.instance, method)
|
693
|
+
if isinstance(the_method, EndpointProperty):
|
694
|
+
module = the_method._method.__module__
|
695
|
+
the_method = functools.partial(the_method._method, self.instance)
|
696
|
+
else:
|
697
|
+
module = the_method.__module__
|
629
698
|
|
630
699
|
if inspect.iscoroutinefunction(the_method):
|
631
700
|
|
632
701
|
async def instrumented():
|
633
702
|
enter_span(
|
634
|
-
|
703
|
+
module,
|
635
704
|
method,
|
636
705
|
str(ctx.mailbox.actor_id),
|
637
706
|
)
|
638
707
|
try:
|
639
|
-
result = await the_method(
|
708
|
+
result = await the_method(*args, **kwargs)
|
640
709
|
self._maybe_exit_debugger()
|
641
710
|
except Exception as e:
|
642
711
|
logging.critical(
|
@@ -649,9 +718,9 @@ class _Actor:
|
|
649
718
|
|
650
719
|
result = await instrumented()
|
651
720
|
else:
|
652
|
-
enter_span(
|
721
|
+
enter_span(module, method, str(ctx.mailbox.actor_id))
|
653
722
|
with fake_sync_state():
|
654
|
-
result = the_method(
|
723
|
+
result = the_method(*args, **kwargs)
|
655
724
|
self._maybe_exit_debugger()
|
656
725
|
exit_span()
|
657
726
|
|
@@ -750,43 +819,29 @@ class ActorMeshRef(MeshTrait):
|
|
750
819
|
for attr_name in dir(self._class):
|
751
820
|
attr_value = getattr(self._class, attr_name, None)
|
752
821
|
if isinstance(attr_value, EndpointProperty):
|
822
|
+
# Convert string method name to appropriate MethodSpecifier
|
823
|
+
kind = (
|
824
|
+
MethodSpecifier.ExplicitPort
|
825
|
+
if attr_value._explicit_response_port
|
826
|
+
else MethodSpecifier.ReturnsResponse
|
827
|
+
)
|
753
828
|
setattr(
|
754
829
|
self,
|
755
830
|
attr_name,
|
756
831
|
ActorEndpoint(
|
757
832
|
self._actor_mesh_ref,
|
758
|
-
attr_name,
|
833
|
+
kind(attr_name),
|
759
834
|
attr_value._method,
|
760
835
|
self._mailbox,
|
836
|
+
attr_value._propagator,
|
837
|
+
attr_value._explicit_response_port,
|
761
838
|
),
|
762
839
|
)
|
763
840
|
|
764
|
-
def __getattr__(self,
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
# At runtime, we still want to raise AttributeError for truly missing attributes
|
769
|
-
|
770
|
-
# Check if this is a method on the underlying class
|
771
|
-
if hasattr(self._class, name):
|
772
|
-
attr = getattr(self._class, name)
|
773
|
-
if isinstance(attr, EndpointProperty):
|
774
|
-
# Dynamically create the endpoint
|
775
|
-
endpoint = ActorEndpoint(
|
776
|
-
self._actor_mesh_ref,
|
777
|
-
name,
|
778
|
-
attr._method,
|
779
|
-
self._mailbox,
|
780
|
-
propagator=attr._propagator,
|
781
|
-
)
|
782
|
-
# Cache it for future use
|
783
|
-
setattr(self, name, endpoint)
|
784
|
-
return endpoint
|
785
|
-
|
786
|
-
# If we get here, it's truly not found
|
787
|
-
raise AttributeError(
|
788
|
-
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
789
|
-
)
|
841
|
+
def __getattr__(self, attr: str) -> NotAnEndpoint:
|
842
|
+
if attr in dir(self._class):
|
843
|
+
return NotAnEndpoint(self, attr)
|
844
|
+
raise AttributeError(attr)
|
790
845
|
|
791
846
|
def _create(
|
792
847
|
self,
|
@@ -798,9 +853,11 @@ class ActorMeshRef(MeshTrait):
|
|
798
853
|
|
799
854
|
ep = ActorEndpoint(
|
800
855
|
self._actor_mesh_ref,
|
801
|
-
|
856
|
+
MethodSpecifier.Init(),
|
802
857
|
null_func,
|
803
858
|
self._mailbox,
|
859
|
+
None,
|
860
|
+
False,
|
804
861
|
)
|
805
862
|
send(ep, (self._class, *args), kwargs)
|
806
863
|
|
monarch/_src/actor/endpoint.py
CHANGED
@@ -34,6 +34,7 @@ from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
|
|
34
34
|
|
35
35
|
if TYPE_CHECKING:
|
36
36
|
from monarch._src.actor.actor_mesh import (
|
37
|
+
ActorMeshRef,
|
37
38
|
HyPortReceiver,
|
38
39
|
OncePortReceiver,
|
39
40
|
Port,
|
@@ -182,11 +183,22 @@ class Endpoint(ABC, Generic[P, R]):
|
|
182
183
|
# pyre-ignore
|
183
184
|
send(self, args, kwargs)
|
184
185
|
|
186
|
+
@abstractmethod
|
187
|
+
def _rref(self, args, kwargs) -> Any: ...
|
188
|
+
|
189
|
+
def rref(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
190
|
+
return self._rref(args, kwargs)
|
191
|
+
|
185
192
|
def _propagate(self, args, kwargs, fake_args, fake_kwargs):
|
186
193
|
if self._propagator_arg is None or self._propagator_arg == "cached":
|
187
194
|
if self._cache is None:
|
188
195
|
self._cache = {}
|
189
|
-
|
196
|
+
resolvable = getattr(self, "_resolvable", None)
|
197
|
+
if resolvable is None:
|
198
|
+
raise NotImplementedError(
|
199
|
+
"Cached propagation is not implemented for actor endpoints."
|
200
|
+
)
|
201
|
+
return _cached_propagation(self._cache, resolvable, args, kwargs)
|
190
202
|
elif self._propagator_arg == "inspect":
|
191
203
|
return None
|
192
204
|
elif self._propagator_arg == "mocked":
|
@@ -211,16 +223,23 @@ class EndpointProperty(Generic[P, R]):
|
|
211
223
|
self,
|
212
224
|
method: Callable[Concatenate[Any, P], Awaitable[R]],
|
213
225
|
propagator: Propagator,
|
226
|
+
explicit_response_port: bool,
|
214
227
|
) -> None: ...
|
215
228
|
|
216
229
|
@overload
|
217
230
|
def __init__(
|
218
|
-
self,
|
231
|
+
self,
|
232
|
+
method: Callable[Concatenate[Any, P], R],
|
233
|
+
propagator: Propagator,
|
234
|
+
explicit_response_port: bool,
|
219
235
|
) -> None: ...
|
220
236
|
|
221
|
-
def __init__(
|
237
|
+
def __init__(
|
238
|
+
self, method: Any, propagator: Propagator, explicit_response_port: bool
|
239
|
+
) -> None:
|
222
240
|
self._method = method
|
223
241
|
self._propagator = propagator
|
242
|
+
self._explicit_response_port = explicit_response_port
|
224
243
|
|
225
244
|
def __get__(self, instance, owner) -> Endpoint[P, R]:
|
226
245
|
# this is a total lie, but we have to actually
|
@@ -229,13 +248,50 @@ class EndpointProperty(Generic[P, R]):
|
|
229
248
|
return cast(Endpoint[P, R], self)
|
230
249
|
|
231
250
|
|
251
|
+
class NotAnEndpoint:
|
252
|
+
"""
|
253
|
+
Used as the dynamic value of functions on an ActorMeshRef that were not marked as endpoints.
|
254
|
+
This is used both to give a better error message (since we cannot prevent the type system from thinking they are methods),
|
255
|
+
and to provide the oppurtunity for someone to do endpoint(x.foo) on something that wasn't marked as an endpoint.
|
256
|
+
"""
|
257
|
+
|
258
|
+
def __init__(self, ref: "ActorMeshRef", name: str):
|
259
|
+
self._ref = ref
|
260
|
+
self._name = name
|
261
|
+
|
262
|
+
def __call__(self, *args, **kwargs) -> None:
|
263
|
+
raise RuntimeError(
|
264
|
+
f"Actor {self._ref._class}.{self._name} is not annotated as an endpoint. To call it as one, add a @endpoint decorator to it, or directly wrap it in one as_endpoint(obj.method).call(...)"
|
265
|
+
)
|
266
|
+
|
267
|
+
|
232
268
|
# This can't just be Callable because otherwise we are not
|
233
269
|
# allowed to use type arguments in the return value.
|
234
270
|
class EndpointIfy:
|
235
271
|
@overload
|
236
|
-
def __call__(
|
272
|
+
def __call__(
|
273
|
+
self, function: Callable[Concatenate[Any, P], Awaitable[R]]
|
274
|
+
) -> Endpoint[P, R]: ...
|
237
275
|
@overload
|
238
|
-
def __call__(
|
276
|
+
def __call__(
|
277
|
+
self, function: Callable[Concatenate[Any, P], R]
|
278
|
+
) -> Endpoint[P, R]: ...
|
279
|
+
|
280
|
+
def __call__(self, function: Any):
|
281
|
+
pass
|
282
|
+
|
283
|
+
|
284
|
+
class PortedEndpointIfy:
|
285
|
+
@overload
|
286
|
+
def __call__(
|
287
|
+
self,
|
288
|
+
function: Callable[Concatenate[Any, "Port[R]", P], Awaitable[None]],
|
289
|
+
) -> Endpoint[P, R]: ...
|
290
|
+
|
291
|
+
@overload
|
292
|
+
def __call__(
|
293
|
+
self, function: Callable[Concatenate[Any, "Port[R]", P], None]
|
294
|
+
) -> Endpoint[P, R]: ...
|
239
295
|
|
240
296
|
def __call__(self, function: Any):
|
241
297
|
pass
|
@@ -246,6 +302,7 @@ def endpoint(
|
|
246
302
|
method: Callable[Concatenate[Any, P], Awaitable[R]],
|
247
303
|
*,
|
248
304
|
propagate: Propagator = None,
|
305
|
+
explicit_response_port: Literal[False] = False,
|
249
306
|
) -> EndpointProperty[P, R]: ...
|
250
307
|
|
251
308
|
|
@@ -254,6 +311,7 @@ def endpoint(
|
|
254
311
|
method: Callable[Concatenate[Any, P], R],
|
255
312
|
*,
|
256
313
|
propagate: Propagator = None,
|
314
|
+
explicit_response_port: Literal[False] = False,
|
257
315
|
) -> EndpointProperty[P, R]: ...
|
258
316
|
|
259
317
|
|
@@ -261,10 +319,43 @@ def endpoint(
|
|
261
319
|
def endpoint(
|
262
320
|
*,
|
263
321
|
propagate: Propagator = None,
|
322
|
+
explicit_response_port: Literal[False] = False,
|
264
323
|
) -> EndpointIfy: ...
|
265
324
|
|
266
325
|
|
267
|
-
|
326
|
+
@overload
|
327
|
+
def endpoint(
|
328
|
+
method: Callable[Concatenate[Any, "Port[R]", P], Awaitable[None]],
|
329
|
+
*,
|
330
|
+
propagate: Propagator = None,
|
331
|
+
explicit_response_port: Literal[True],
|
332
|
+
) -> EndpointProperty[P, R]: ...
|
333
|
+
|
334
|
+
|
335
|
+
@overload
|
336
|
+
def endpoint(
|
337
|
+
method: Callable[Concatenate[Any, "Port[R]", P], None],
|
338
|
+
*,
|
339
|
+
propagate: Propagator = None,
|
340
|
+
explicit_response_port: Literal[True],
|
341
|
+
) -> EndpointProperty[P, R]: ...
|
342
|
+
|
343
|
+
|
344
|
+
@overload
|
345
|
+
def endpoint(
|
346
|
+
*,
|
347
|
+
propagate: Propagator = None,
|
348
|
+
explicit_response_port: Literal[True],
|
349
|
+
) -> PortedEndpointIfy: ...
|
350
|
+
|
351
|
+
|
352
|
+
def endpoint(method=None, *, propagate=None, explicit_response_port: bool = False):
|
268
353
|
if method is None:
|
269
|
-
return functools.partial(
|
270
|
-
|
354
|
+
return functools.partial(
|
355
|
+
endpoint,
|
356
|
+
propagate=propagate,
|
357
|
+
explicit_response_port=explicit_response_port,
|
358
|
+
)
|
359
|
+
return EndpointProperty(
|
360
|
+
method, propagator=propagate, explicit_response_port=explicit_response_port
|
361
|
+
)
|
monarch/_src/actor/event_loop.py
CHANGED
monarch/_src/actor/proc_mesh.py
CHANGED
@@ -43,7 +43,6 @@ from monarch._src.actor.actor_mesh import (
|
|
43
43
|
Actor,
|
44
44
|
ActorMeshRef,
|
45
45
|
fake_sync_state,
|
46
|
-
MonarchContext,
|
47
46
|
)
|
48
47
|
|
49
48
|
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator
|
@@ -89,7 +88,7 @@ class SetupActor(Actor):
|
|
89
88
|
Typically used to setup the environment variables.
|
90
89
|
"""
|
91
90
|
|
92
|
-
def __init__(self, env: Callable[[
|
91
|
+
def __init__(self, env: Callable[[], None]) -> None:
|
93
92
|
"""
|
94
93
|
Initialize the setup actor with the user defined setup method.
|
95
94
|
"""
|
@@ -100,8 +99,7 @@ class SetupActor(Actor):
|
|
100
99
|
"""
|
101
100
|
Call the user defined setup method with the monarch context.
|
102
101
|
"""
|
103
|
-
|
104
|
-
self._setup_method(ctx)
|
102
|
+
self._setup_method()
|
105
103
|
|
106
104
|
|
107
105
|
T = TypeVar("T")
|
@@ -114,7 +112,7 @@ except ImportError:
|
|
114
112
|
|
115
113
|
|
116
114
|
async def _allocate_nonblocking(
|
117
|
-
alloc: Alloc, setup: Callable[[
|
115
|
+
alloc: Alloc, setup: Callable[[], None] | None = None
|
118
116
|
) -> "ProcMesh":
|
119
117
|
_proc_mesh = await HyProcMesh.allocate_nonblocking(alloc)
|
120
118
|
if setup is None:
|
@@ -211,7 +209,7 @@ class ProcMesh(MeshTrait):
|
|
211
209
|
|
212
210
|
@classmethod
|
213
211
|
def from_alloc(
|
214
|
-
self, alloc: Alloc, setup: Callable[[
|
212
|
+
self, alloc: Alloc, setup: Callable[[], None] | None = None
|
215
213
|
) -> Future["ProcMesh"]:
|
216
214
|
"""
|
217
215
|
Allocate a process mesh according to the provided alloc.
|
@@ -219,7 +217,17 @@ class ProcMesh(MeshTrait):
|
|
219
217
|
|
220
218
|
Arguments:
|
221
219
|
- `alloc`: The alloc to allocate according to.
|
222
|
-
- `setup`:
|
220
|
+
- `setup`: An optional lambda function to configure environment variables on the allocated mesh.
|
221
|
+
Use the `current_rank()` method within the lambda to obtain the rank.
|
222
|
+
|
223
|
+
Example of a setup method to initialize torch distributed environment variables:
|
224
|
+
```
|
225
|
+
def setup():
|
226
|
+
rank = current_rank()
|
227
|
+
os.environ["RANK"] = str(rank)
|
228
|
+
os.environ["WORLD_SIZE"] = str(len(rank.shape))
|
229
|
+
os.environ["LOCAL_RANK"] = str(rank["gpus"])
|
230
|
+
```
|
223
231
|
"""
|
224
232
|
return Future(
|
225
233
|
impl=lambda: _allocate_nonblocking(alloc, setup),
|
@@ -428,7 +436,7 @@ async def proc_mesh_nonblocking(
|
|
428
436
|
gpus: Optional[int] = None,
|
429
437
|
hosts: int = 1,
|
430
438
|
env: dict[str, str] | None = None,
|
431
|
-
setup: Callable[[
|
439
|
+
setup: Callable[[], None] | None = None,
|
432
440
|
) -> ProcMesh:
|
433
441
|
if gpus is None:
|
434
442
|
gpus = _local_device_count()
|
@@ -457,7 +465,7 @@ def proc_mesh(
|
|
457
465
|
gpus: Optional[int] = None,
|
458
466
|
hosts: int = 1,
|
459
467
|
env: dict[str, str] | None = None,
|
460
|
-
setup: Callable[[
|
468
|
+
setup: Callable[[], None] | None = None,
|
461
469
|
) -> Future[ProcMesh]:
|
462
470
|
return Future(
|
463
471
|
impl=lambda: proc_mesh_nonblocking(
|
@@ -19,7 +19,6 @@ time it is used.
|
|
19
19
|
|
20
20
|
if TYPE_CHECKING:
|
21
21
|
from monarch._src.actor.actor_mesh import ActorEndpoint, Port, Selection
|
22
|
-
from monarch._src.actor.endpoint import Endpoint
|
23
22
|
|
24
23
|
|
25
24
|
def shim(fn=None, *, module=None):
|
@@ -48,8 +47,12 @@ def actor_send(
|
|
48
47
|
) -> None: ...
|
49
48
|
|
50
49
|
|
50
|
+
@shim(module="monarch.mesh_controller")
|
51
|
+
def actor_rref(endpoint, args_kwargs_tuple: bytes, refs: Sequence[Any]): ...
|
52
|
+
|
53
|
+
|
51
54
|
@shim(module="monarch.common.remote")
|
52
|
-
def _cached_propagation(_cache, rfunction
|
55
|
+
def _cached_propagation(_cache, rfunction, args, kwargs) -> Any: ...
|
53
56
|
|
54
57
|
|
55
58
|
@shim(module="monarch.common.fake")
|
monarch/actor/__init__.py
CHANGED
@@ -12,6 +12,7 @@ from monarch._src.actor.actor_mesh import (
|
|
12
12
|
Accumulator,
|
13
13
|
Actor,
|
14
14
|
ActorError,
|
15
|
+
as_endpoint,
|
15
16
|
current_actor_name,
|
16
17
|
current_rank,
|
17
18
|
current_size,
|
@@ -35,6 +36,7 @@ __all__ = [
|
|
35
36
|
"Actor",
|
36
37
|
"ActorError",
|
37
38
|
"current_actor_name",
|
39
|
+
"as_endpoint",
|
38
40
|
"current_rank",
|
39
41
|
"current_size",
|
40
42
|
"endpoint",
|
monarch/common/messages.py
CHANGED
@@ -435,6 +435,15 @@ class SendResultOfActorCall(NamedTuple):
|
|
435
435
|
stream: tensor_worker.StreamRef
|
436
436
|
|
437
437
|
|
438
|
+
class CallActorMethod(NamedTuple):
|
439
|
+
seq: int
|
440
|
+
result: object
|
441
|
+
broker_id: Tuple[str, int]
|
442
|
+
local_state: Sequence[Tensor | tensor_worker.Ref]
|
443
|
+
mutates: List[tensor_worker.Ref]
|
444
|
+
stream: tensor_worker.StreamRef
|
445
|
+
|
446
|
+
|
438
447
|
class SplitComm(NamedTuple):
|
439
448
|
dims: Dims
|
440
449
|
device_mesh: DeviceMesh
|
monarch/common/remote.py
CHANGED
@@ -157,7 +157,7 @@ class Remote(Generic[P, R], Endpoint[P, R]):
|
|
157
157
|
def _maybe_resolvable(self):
|
158
158
|
return None if self._remote_impl is None else self._resolvable
|
159
159
|
|
160
|
-
def
|
160
|
+
def _rref(self, args, kwargs):
|
161
161
|
return dtensor_dispatch(
|
162
162
|
self._resolvable,
|
163
163
|
self._propagate,
|
@@ -352,7 +352,7 @@ _miss = 0
|
|
352
352
|
_hit = 0
|
353
353
|
|
354
354
|
|
355
|
-
def _cached_propagation(_cache, rfunction:
|
355
|
+
def _cached_propagation(_cache, rfunction: ResolvableFunction, args, kwargs):
|
356
356
|
tensors, shape_key = hashable_tensor_flatten(args, kwargs)
|
357
357
|
# pyre-ignore
|
358
358
|
inputs_group = TensorGroup([t._fake for t in tensors])
|
Binary file
|