torchmonarch-nightly 2025.7.1__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.26__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +878 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +303 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +508 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +59 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +53 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +21 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/gradient/_gradient_generator.so +0 -0
  37. monarch/mesh_controller.py +263 -139
  38. monarch/monarch_controller +0 -0
  39. monarch/opaque_module.py +4 -6
  40. monarch/opaque_object.py +3 -3
  41. monarch/proc_mesh.py +6 -309
  42. monarch/python_local_mesh.py +1 -1
  43. monarch/rust_backend_mesh.py +2 -1
  44. monarch/rust_local_mesh.py +4 -2
  45. monarch/sim_mesh.py +10 -19
  46. monarch/simulator/command_history.py +1 -1
  47. monarch/simulator/interface.py +2 -1
  48. monarch/simulator/mock_controller.py +1 -1
  49. monarch/simulator/simulator.py +1 -1
  50. monarch/tensor_engine/__init__.py +23 -0
  51. monarch/tensor_worker_main.py +3 -1
  52. monarch/tools/cli.py +3 -1
  53. monarch/tools/commands.py +129 -47
  54. monarch/tools/components/hyperactor.py +5 -3
  55. monarch/tools/config/__init__.py +18 -1
  56. monarch/tools/config/defaults.py +2 -2
  57. monarch/tools/mesh_spec.py +59 -1
  58. monarch/tools/utils.py +38 -0
  59. monarch/worker/worker.py +1 -1
  60. monarch/world_mesh.py +2 -1
  61. monarch_supervisor/python_executable.py +6 -3
  62. tests/error_test_binary.py +48 -10
  63. tests/test_actor_error.py +370 -21
  64. tests/test_alloc.py +1 -1
  65. tests/test_allocator.py +369 -17
  66. tests/test_controller.py +2 -0
  67. tests/test_debugger.py +416 -0
  68. tests/test_env_before_cuda.py +161 -0
  69. tests/test_python_actors.py +184 -333
  70. tests/test_rdma.py +198 -0
  71. tests/test_remote_functions.py +40 -12
  72. tests/test_rust_backend.py +7 -5
  73. tests/test_sim_backend.py +1 -4
  74. tests/test_tensor_engine.py +81 -1
  75. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/METADATA +39 -1
  76. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/RECORD +84 -72
  77. torchmonarch_nightly-2025.7.26.dist-info/entry_points.txt +3 -0
  78. monarch/_monarch/hyperactor/__init__.py +0 -58
  79. monarch/_monarch/worker/debugger.py +0 -117
  80. monarch/_monarch/worker/logging.py +0 -107
  81. monarch/debugger.py +0 -379
  82. monarch/future.py +0 -76
  83. monarch/rdma.py +0 -162
  84. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  85. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  86. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  87. /monarch/{common → _src/actor}/shape.py +0 -0
  88. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  89. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/WHEEL +0 -0
  90. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/licenses/LICENSE +0 -0
  91. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,878 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import collections
10
+ import contextvars
11
+ import functools
12
+ import inspect
13
+ import itertools
14
+ import logging
15
+ import random
16
+ import traceback
17
+
18
+ from dataclasses import dataclass
19
+ from traceback import extract_tb, StackSummary
20
+ from typing import (
21
+ Any,
22
+ AsyncGenerator,
23
+ Awaitable,
24
+ Callable,
25
+ cast,
26
+ Concatenate,
27
+ Dict,
28
+ Generic,
29
+ Iterable,
30
+ Iterator,
31
+ List,
32
+ NamedTuple,
33
+ Optional,
34
+ ParamSpec,
35
+ Tuple,
36
+ Type,
37
+ TYPE_CHECKING,
38
+ TypeVar,
39
+ )
40
+
41
+ from monarch._rust_bindings.monarch_hyperactor.actor import (
42
+ PanicFlag,
43
+ PythonMessage,
44
+ PythonMessageKind,
45
+ )
46
+ from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
47
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
48
+ Mailbox,
49
+ OncePortReceiver,
50
+ OncePortRef,
51
+ PortReceiver as HyPortReceiver,
52
+ PortRef,
53
+ )
54
+
55
+ if TYPE_CHECKING:
56
+ from monarch._rust_bindings.monarch_hyperactor.actor import PortProtocol
57
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import PortReceiverBase
58
+
59
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
60
+ from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
61
+ from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError
62
+ from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span
63
+ from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
64
+ from monarch._src.actor.endpoint import (
65
+ Endpoint,
66
+ EndpointProperty,
67
+ Extent,
68
+ NotAnEndpoint,
69
+ Propagator,
70
+ Selection,
71
+ )
72
+ from monarch._src.actor.future import Future
73
+ from monarch._src.actor.pdb_wrapper import PdbWrapper
74
+
75
+ from monarch._src.actor.pickle import flatten, unflatten
76
+
77
+ from monarch._src.actor.shape import MeshTrait, NDSlice
78
+ from monarch._src.actor.sync_state import fake_sync_state
79
+
80
+ from monarch._src.actor.tensor_engine_shim import actor_rref, actor_send
81
+
82
+ if TYPE_CHECKING:
83
+ from monarch._src.actor.proc_mesh import ProcMesh
84
+
85
# Logger for this module.
logger: logging.Logger = logging.getLogger(__name__)

# Union of the two allocator classes this module accepts.
Allocator = ProcessAllocator | LocalAllocator

# IN_PAR is True only when the ``__manifest__`` module can be imported and its
# ``fbmake`` mapping carries a truthy ``par_style`` entry; in every other
# environment the import fails and IN_PAR is False.
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = bool(fbmake.get("par_style"))

T1, T2 = TypeVar("T1"), TypeVar("T2")
98
+
99
+
100
class Point(HyPoint, collections.abc.Mapping):
    """
    The Rust-backed ``HyPoint`` with ``collections.abc.Mapping`` mixed in.

    The Mapping mixin methods (``keys``, ``items``, ``get``, ...) presumably
    rely on ``HyPoint`` providing ``__getitem__``/``__iter__``/``__len__`` —
    TODO confirm against the Rust bindings.
    """
102
+
103
+
104
@dataclass
class MonarchContext:
    """
    Execution context of the actor message currently being handled.

    ``_Actor.handle`` builds one of these for every incoming message and stores
    it in the ``_context`` context variable.
    """

    mailbox: Mailbox
    proc_id: str
    point: Point

    @staticmethod
    def get() -> "MonarchContext":
        """
        Return the context stored in ``_context``.

        Raises ``LookupError`` when no context has been set in the current
        ``contextvars`` context.
        """
        return _context.get()


_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar("monarch.actor_mesh._context")
118
+
119
+
120
@dataclass
class DebugContext:
    """
    Debugger state attached to the message currently being handled.

    ``pdb_wrapper`` holds the active ``PdbWrapper``, if any. ``_Actor.handle``
    installs a fresh, empty ``DebugContext`` at the start of every message.
    """

    pdb_wrapper: Optional[PdbWrapper] = None

    @staticmethod
    def get() -> "DebugContext":
        """Return the current debug context (``LookupError`` if none was ever set)."""
        return _debug_context.get()

    @staticmethod
    def set(debug_context: "DebugContext") -> None:
        """Make ``debug_context`` the current debug context."""
        _debug_context.set(debug_context)


_debug_context: contextvars.ContextVar[DebugContext] = contextvars.ContextVar("monarch.actor_mesh._debug_context")
136
+
137
+ T = TypeVar("T")
138
+ P = ParamSpec("P")
139
+ R = TypeVar("R")
140
+ A = TypeVar("A")
141
+
142
+ # keep this load balancing deterministic, but
143
+ # equally distributed.
144
+ _load_balancing_seed = random.Random(4)
145
+
146
+
147
# Stand-in for whatever serializable Python object we end up using to name an
# actor mesh. It is hacked together for now because ActorMesh is not yet
# plumbed through to non-client processes.
class _ActorMeshRefImpl:
    """
    Picklable reference to the actors of a mesh, addressed through a ``Mailbox``.

    Holds one ``ActorId`` per rank plus the ``Shape`` that selects which of
    those ranks this reference covers. Messages are posted directly through
    the mailbox.

    The live ``PythonActorMesh`` and the ``ProcMesh`` are only kept on
    references built with ``from_hyperactor_mesh``. References built with
    ``from_actor_id`` or ``from_actor_ref_with_shape``, or recovered by
    unpickling, have ``None`` for both.
    """

    def __init__(
        self,
        mailbox: Mailbox,
        hy_actor_mesh: Optional[PythonActorMesh],
        proc_mesh: "Optional[ProcMesh]",
        shape: Shape,
        actor_ids: List[ActorId],
    ) -> None:
        self._mailbox = mailbox
        self._actor_mesh = hy_actor_mesh
        # actor meshes do not have a way to look this up at the moment,
        # so we fake it here
        self._proc_mesh = proc_mesh
        self._shape = shape
        self._please_replace_me_actor_ids = actor_ids

    @staticmethod
    def from_hyperactor_mesh(
        mailbox: Mailbox, hy_actor_mesh: PythonActorMesh, proc_mesh: "ProcMesh"
    ) -> "_ActorMeshRefImpl":
        """Build a reference that keeps the live mesh for supervision checks."""
        shape: Shape = hy_actor_mesh.shape
        return _ActorMeshRefImpl(
            mailbox,
            hy_actor_mesh,
            proc_mesh,
            shape,
            [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))],
        )

    @staticmethod
    def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl":
        """Build a reference to a single actor, using the zero-dimensional shape."""
        return _ActorMeshRefImpl(mailbox, None, None, singleton_shape, [actor_id])

    @staticmethod
    def from_actor_ref_with_shape(
        ref: "_ActorMeshRefImpl", shape: Shape
    ) -> "_ActorMeshRefImpl":
        """
        Return a copy of ``ref`` restricted to ``shape``.

        The full actor-id list is shared. The live mesh and proc mesh are
        dropped.
        """
        return _ActorMeshRefImpl(
            ref._mailbox, None, None, shape, ref._please_replace_me_actor_ids
        )

    def __getstate__(
        self,
    ) -> Tuple[Shape, List[ActorId], Mailbox]:
        # Only the shape, the actor ids and the mailbox are pickled; the live
        # mesh and proc mesh are not.
        return self._shape, self._please_replace_me_actor_ids, self._mailbox

    def __setstate__(
        self,
        state: Tuple[Shape, List[ActorId], Mailbox],
    ) -> None:
        # Restore every attribute __init__ defines. The unpickled side has no
        # live mesh or proc mesh, so both are None; otherwise reading
        # ``_proc_mesh`` on an unpickled reference would raise AttributeError.
        self._actor_mesh = None
        self._proc_mesh = None
        self._shape, self._please_replace_me_actor_ids, self._mailbox = state

    def _check_state(self) -> None:
        """
        Raise ``SupervisionError`` if the live mesh (when known) is unhealthy.
        """
        # This is temporary until we have real cast integration here. We need to actively check
        # supervision error here is because all communication is done through direct mailbox sending
        # and not through comm actor casting.
        # TODO: remove this when casting integration is done.
        if self._actor_mesh is not None:
            if self._actor_mesh.stopped:
                raise SupervisionError(
                    "actor mesh is not in a healthy state: `ActorMesh` has been stopped"
                )

            event = self._actor_mesh.get_supervision_event()
            if event is not None:
                raise SupervisionError(f"actor mesh is not in a healthy state: {event}")

    def send(self, rank: int, message: PythonMessage) -> None:
        """Post ``message`` to the actor with global index ``rank``."""
        self._check_state()
        actor = self._please_replace_me_actor_ids[rank]
        self._mailbox.post(actor, message)

    def cast(
        self,
        message: PythonMessage,
        selection: Selection,
    ) -> None:
        """
        Deliver ``message`` to the actors picked by ``selection``.

        ``"choose"`` posts to one rank of the shape, picked by the seeded RNG.
        ``"all"`` posts to every rank of the shape, tagging each post with its
        index in the shape. An ``int`` posts to that index of the actor-id list.

        Raises:
            IndexError: an int selection outside ``[0, number of actors)``.
            ValueError: any other selection value.
        """
        self._check_state()

        # TODO: use the actual actor mesh when available. We cannot currently use it
        # directly because we risk bifurcating the message delivery paths from the same
        # client, since slicing the mesh will produce a reference, which calls actors
        # directly. The reason these paths are bifurcated is that actor meshes will
        # use multicasting, while direct actor comms do not. Separately we need to decide
        # whether actor meshes are ordered with actor references.
        #
        # The fix is to provide a first-class reference into Python, and always call "cast"
        # on it, including for load balanced requests.
        if selection == "choose":
            idx = _load_balancing_seed.randrange(len(self._shape))
            actor_rank = self._shape.ndslice[idx]
            self._mailbox.post(self._please_replace_me_actor_ids[actor_rank], message)
        elif selection == "all":
            # replace me with actual remote actor mesh
            call_shape = Shape(
                self._shape.labels, NDSlice.new_row_major(self._shape.ndslice.sizes)
            )
            for i, rank in enumerate(self._shape.ranks()):
                self._mailbox.post_cast(
                    self._please_replace_me_actor_ids[rank],
                    i,
                    call_shape,
                    message,
                )
        elif isinstance(selection, int):
            num_actors = len(self._please_replace_me_actor_ids)
            # Check explicitly: a negative rank would otherwise index from the
            # end of the list and silently reach the wrong actor.
            if not 0 <= selection < num_actors:
                raise IndexError(
                    f"Tried to send to an out-of-range rank {selection}: "
                    f"mesh has {num_actors} elements."
                )
            self._mailbox.post(self._please_replace_me_actor_ids[selection], message)
        else:
            raise ValueError(f"invalid selection: {selection}")

    def __len__(self) -> int:
        return len(self._shape)

    @property
    def _name_pid(self):
        """``(actor_name, pid)`` of the first actor id in this reference."""
        actor_id0 = self._please_replace_me_actor_ids[0]
        return actor_id0.actor_name, actor_id0.pid

    async def stop(self):
        """
        Stop the underlying live mesh.

        Only works on references built with ``from_hyperactor_mesh``. On other
        references ``_actor_mesh`` is None, so this fails.
        """
        await self._actor_mesh.stop()
279
+
280
+
281
class ActorEndpoint(Endpoint[P, R]):
    """
    Endpoint bound to one method name on the actors of an ``_ActorMeshRefImpl``.

    ``impl`` is only used for its signature, so that call arguments are
    validated locally before anything is sent.
    """

    def __init__(
        self,
        actor_mesh_ref: _ActorMeshRefImpl,
        name: str,
        impl: Callable[Concatenate[Any, P], Awaitable[R]],
        mailbox: Mailbox,
        propagator: Propagator = None,
    ) -> None:
        super().__init__(propagator)
        self._actor_mesh = actor_mesh_ref
        self._name = name
        self._signature: inspect.Signature = inspect.signature(impl)
        self._mailbox = mailbox

    def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
        """Wrap ``r`` with the live mesh's supervision, when a live mesh is known."""
        mesh = self._actor_mesh._actor_mesh
        return r if mesh is None else mesh.supervise(r)

    def _call_name(self) -> Any:
        return self._name

    def _send(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        port: "Optional[Port]" = None,
        selection: Selection = "all",
    ) -> Extent:
        """
        Fire-and-forget invocation of the endpoint on the actors chosen by ``selection``.

        Nothing is awaited. If ``port`` is given, its port reference is
        embedded in the call message.

        Returns:
            The ``Extent`` (labels and sizes) of the referenced mesh shape.

        Raises:
            TypeError: if ``args``/``kwargs`` do not fit the endpoint signature.
        """
        # ``None`` stands in for ``self`` so the actor method's own signature
        # validates the arguments before they are sent.
        self._signature.bind(None, *args, **kwargs)
        objects, payload = flatten((args, kwargs), _is_ref_or_mailbox)
        if all(not hasattr(obj, "__monarch_ref__") for obj in objects):
            message = PythonMessage(
                PythonMessageKind.CallMethod(
                    self._name, None if port is None else port._port_ref
                ),
                payload,
            )
            self._actor_mesh.cast(message, selection)
        else:
            # Arguments contain objects exposing ``__monarch_ref__``; hand them
            # to the tensor-engine send path instead of a plain mailbox cast.
            actor_send(self, payload, objects, port, selection)
        shape = self._actor_mesh._shape
        return Extent(shape.labels, shape.ndslice.sizes)

    def _port(self, once: bool = False) -> "PortTuple[R]":
        """Open a port whose receiver is wrapped with mesh supervision."""
        p, r = PortTuple.create(self._mailbox, once)
        if TYPE_CHECKING:
            assert isinstance(
                r._receiver, (HyPortReceiver | OncePortReceiver)
            ), "unexpected receiver type"
        return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))

    def _rref(self, args, kwargs):
        """Validate the arguments, then delegate to the tensor engine's ``actor_rref``."""
        self._signature.bind(None, *args, **kwargs)
        refs, payload = flatten((args, kwargs), _is_ref_or_mailbox)

        return actor_rref(self, payload, refs)
343
+
344
+
345
def as_endpoint(
    not_an_endpoint: Callable[P, R], *, propagate: Propagator = None
) -> Endpoint[P, R]:
    """
    Turn a ``NotAnEndpoint`` into a callable ``ActorEndpoint``.

    ``ActorMeshRef.__getattr__`` returns a ``NotAnEndpoint`` for class
    attributes that are not endpoints. This wraps one of those so the method
    can be invoked as an endpoint.

    Raises:
        ValueError: if ``not_an_endpoint`` is not a ``NotAnEndpoint``.
    """
    if isinstance(not_an_endpoint, NotAnEndpoint):
        mesh_ref = not_an_endpoint._ref
        method_name = not_an_endpoint._name
        return ActorEndpoint(
            mesh_ref._actor_mesh_ref,
            method_name,
            getattr(mesh_ref, method_name),
            mesh_ref._mailbox,
            propagate,
        )
    raise ValueError("expected an method of a spawned actor")
357
+
358
+
359
class Accumulator(Generic[P, R, A]):
    """
    Folds the stream of responses from an endpoint into a single value.

    ``combine(acc, response)`` is applied left to right, starting from
    ``identity``.
    """

    def __init__(
        self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
    ) -> None:
        self._endpoint: Endpoint[P, R] = endpoint
        self._identity: A = identity
        self._combine: Callable[[A, R], A] = combine

    def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
        """Start streaming the endpoint and return a Future of the folded result."""
        responses: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs)

        async def impl() -> A:
            acc = self._identity
            async for response in responses:
                acc = self._combine(acc, response)
            return acc

        return Future(impl=impl)
377
+
378
+
379
class ValueMesh(MeshTrait, Generic[R]):
    """
    Container of return values, indexed by rank.

    ``values`` is indexed by global rank. ``shape`` selects which of those
    ranks this mesh (or slice of it) exposes.
    """

    def __init__(self, shape: Shape, values: List[R]) -> None:
        self._shape = shape
        self._values = values

    def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
        # Slicing only narrows the shape; the underlying list is shared.
        return ValueMesh(shape, self._values)

    def item(self, **kwargs) -> R:
        """
        Return the value at the given coordinates, one keyword per label.

        Raises:
            KeyError: if a label is missing, or if extra keywords are given.
        """
        unused = dict(kwargs)
        coordinates = [unused.pop(label) for label in self._labels]
        if unused:
            raise KeyError(f"item has extra dimensions: {list(unused)}")
        return self._values[self._ndslice.nditem(coordinates)]

    def items(self) -> Iterable[Tuple[Point, R]]:
        """Yield ``(Point, value)`` for each rank of the shape, in rank order."""
        shape = self._shape
        values = self._values
        for rank in shape.ranks():
            yield Point(rank, shape), values[rank]

    def __iter__(self) -> Iterator[Tuple[Point, R]]:
        yield from self.items()

    def __len__(self) -> int:
        return len(self._shape)

    def __repr__(self) -> str:
        return f"ValueMesh({self._shape})"

    @property
    def _ndslice(self) -> NDSlice:
        return self._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._shape.labels
418
+
419
+
420
def send(
    endpoint: Endpoint[P, R],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    port: "Optional[Port]" = None,
    selection: Selection = "all",
) -> None:
    """
    Fire-and-forget invocation of ``endpoint`` on the actors chosen by ``selection``.

    Nothing is awaited. ``port``, if given, is forwarded to ``endpoint._send``.
    The value ``_send`` returns is discarded.
    """
    endpoint._send(args, kwargs, port, selection)
    return None
433
+
434
+
435
class Port(Generic[R]):
    """
    Sending half of a port.

    Each value or exception is pickled with ``_pickle`` and posted to
    ``port_ref`` through ``mailbox``, tagged with ``rank`` (may be None).
    """

    def __init__(
        self,
        port_ref: PortRef | OncePortRef,
        mailbox: Mailbox,
        rank: Optional[int],
    ) -> None:
        self._port_ref = port_ref
        self._mailbox = mailbox
        self._rank = rank

    def send(self, obj: R) -> None:
        """Deliver ``obj`` as a successful result."""
        self._post(PythonMessageKind.Result(self._rank), obj)

    def exception(self, obj: Exception) -> None:
        """Deliver ``obj`` as an error; the receiving side raises it."""
        self._post(PythonMessageKind.Exception(self._rank), obj)

    def _post(self, kind: Any, obj: object) -> None:
        # Shared by send/exception: wrap the pickled payload in a message of
        # the given kind and post it through the port reference.
        self._port_ref.send(self._mailbox, PythonMessage(kind, _pickle(obj)))
459
+
460
+
461
class DroppingPort:
    """
    Stand-in for a real port when a message has no response port.

    Results are discarded. Exceptions are raised right here, in the current
    actor, so an error with nowhere to go is still reported exactly once.
    """

    def __init__(self):
        """Nothing to set up."""

    def send(self, obj: Any) -> None:
        """Discard ``obj``; there is no one to deliver it to."""
        return None

    def exception(self, obj: Exception) -> None:
        """Raise ``obj`` in the current actor, with exception context suppressed."""
        raise obj from None
477
+
478
+
479
R = TypeVar("R")

T = TypeVar("T")

if TYPE_CHECKING:
    # Python <= 3.10 cannot inherit from Generic[R] and NamedTuple at the same
    # time. The generic form is only needed for type checking, so the class is
    # duplicated here until 3.11 is the minimum supported version.
    class PortTuple(NamedTuple, Generic[R]):
        sender: "Port[R]"
        receiver: "PortReceiver[R]"

        @staticmethod
        def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
            opener = mailbox.open_once_port if once else mailbox.open_port
            port_handle, raw_receiver = opener()
            return PortTuple(
                Port(port_handle.bind(), mailbox, rank=None),
                PortReceiver(mailbox, raw_receiver),
            )
else:

    class PortTuple(NamedTuple):
        """``(sender, receiver)`` pair for a freshly opened mailbox port."""

        sender: "Port[Any]"
        receiver: "PortReceiver[Any]"

        @staticmethod
        def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
            """Open a port on ``mailbox`` (a once-port if ``once``) and wrap both ends."""
            opener = mailbox.open_once_port if once else mailbox.open_port
            port_handle, raw_receiver = opener()
            return PortTuple(
                Port(port_handle.bind(), mailbox, rank=None),
                PortReceiver(mailbox, raw_receiver),
            )
512
+
513
+
514
# Advanced, lower-level API for sending messages. This is intentionally not
# part of the Endpoint API because the way it accepts arguments and handles
# concerns is different.
def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]":
    """Open a ``(sender, receiver)`` port pair for ``endpoint`` via ``endpoint._port``."""
    port_pair = endpoint._port(once)
    return port_pair
519
+
520
+
521
def ranked_port(
    endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
    """Like ``port``, but the receiver yields ``(rank, value)`` pairs."""
    sender, plain_receiver = port(endpoint, once)
    ranked_receiver = RankedPortReceiver[R](
        plain_receiver._mailbox, plain_receiver._receiver
    )
    return sender, ranked_receiver
526
+
527
+
528
class PortReceiver(Generic[R]):
    """
    Receiving half of a port.

    Decodes incoming messages: ``Result`` messages yield their payload, and
    ``Exception`` messages raise it.
    """

    def __init__(
        self,
        mailbox: Mailbox,
        receiver: "PortReceiverBase",
    ) -> None:
        self._mailbox: Mailbox = mailbox
        self._receiver = receiver

    async def _recv(self) -> R:
        incoming = await self._receiver.recv_task()
        return self._process(incoming)

    def _process(self, msg: PythonMessage) -> R:
        # Every extracted slot in the payload is refilled with this receiver's
        # mailbox (``_pickle`` only extracts Mailbox objects).
        # TODO: Try to do something more structured than a cast here
        payload = cast(R, unflatten(msg.message, itertools.repeat(self._mailbox)))
        kind = msg.kind
        if isinstance(kind, PythonMessageKind.Result):
            return payload
        if isinstance(kind, PythonMessageKind.Exception):
            raise cast(Exception, payload)
        raise ValueError(f"Unexpected message kind: {msg.kind}")

    def recv(self) -> "Future[R]":
        """Return a Future that resolves to the next decoded message."""
        return Future(impl=lambda: self._recv(), requires_loop=False)
553
+
554
+
555
class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
    """Port receiver that yields ``(rank, value)`` pairs instead of bare values."""

    def _process(self, msg: PythonMessage) -> Tuple[int, R]:
        sender_rank = getattr(msg.kind, "rank", None)
        if sender_rank is not None:
            return sender_rank, super()._process(msg)
        raise ValueError(
            f"RankedPort receiver got a message without a rank {msg}",
        )
563
+
564
+
565
# Zero-dimensional shape (no labels, no sizes, no strides); used for
# references to a single actor.
singleton_shape = Shape([], NDSlice(sizes=[], strides=[], offset=0))
566
+
567
+
568
+ # Currently the synchronous functions of actors are run on a Python thread that has an active event loop.
569
+ # Technically it is unsafe for them to block at all because they will block the loop of other
570
+ # calls, so all calls to .get() should be failing.
571
+ # But in the meantime, to implement get() by reusing async functions,
572
+ # we need to signal to the consumer of the PythonTask object that the thread really isn't in an async context.
573
+ # We do this by blanking out the running event loop during the call to the synchronous actor function.
574
+
575
+
576
class _Actor:
    """
    This is the message handling implementation of a Python actor.

    The layering goes:
        Rust `PythonActor` -> `_Actor` -> user-provided `Actor` instance

    Messages are received from the Rust backend, and forwarded to the `handle`
    methods on this class.

    This class wraps the actual `Actor` instance provided by the user, and
    routes messages to it, managing argument serialization/deserialization and
    error handling.
    """

    def __init__(self) -> None:
        self.instance: object | None = None
        # TODO: (@pzhang) remove this with T229200522
        self._saved_error: ActorError | None = None

    async def handle(
        self,
        mailbox: Mailbox,
        rank: int,
        shape: Shape,
        method: str,
        message: bytes,
        panic_flag: PanicFlag,
        local_state: Iterable[Any],
        port: "PortProtocol",
    ) -> None:
        """
        Run one message: construct the actor, or call one of its methods.

        ``method == "__init__"`` constructs the user actor from the first
        unpickled argument. Any other name is looked up on that instance and
        called. The result is sent to ``port``. ``Exception``s are wrapped in
        ``ActorError`` and sent to ``port.exception``. Any other
        ``BaseException`` is signalled to ``panic_flag`` and re-raised.
        """
        # response_port can be None. If so, then sending to port will drop the response,
        # and raise any exceptions to the caller.
        try:
            ctx: MonarchContext = MonarchContext(
                mailbox, mailbox.actor_id.proc_id, Point(rank, shape)
            )
            _context.set(ctx)

            DebugContext.set(DebugContext())

            args, kwargs = unflatten(message, local_state)

            if method == "__init__":
                Class, *args = args
                try:
                    self.instance = Class(*args, **kwargs)
                except Exception as e:
                    self._saved_error = ActorError(
                        e, f"Remote actor {Class}.__init__ call failed."
                    )
                    raise
                port.send(None)
                return None

            if self.instance is None:
                # This could happen because of the following reasons. Both
                # indicates a possible bug in the framework:
                # 1. the execution of the previous message for "__init__" failed,
                #    but that error is not surfaced to the caller.
                #      - TODO(T229200522): there is a known bug. fix it.
                # 2. this message is delivered to this actor before the previous
                #    message of "__init__" is delivered. Out-of-order delivery
                #    should never happen. It indicates either a bug in the
                #    message delivery mechanism, or the framework accidentally
                #    mixed the usage of cast and direct send.
                error_message = f"Actor object is missing when executing method {method} on actor {mailbox.actor_id}."
                if self._saved_error is not None:
                    error_message += (
                        f" This is likely due to an earlier error: {self._saved_error}"
                    )
                raise AssertionError(error_message)
            the_method = getattr(self.instance, method)
            if isinstance(the_method, EndpointProperty):
                module = the_method._method.__module__
                the_method = functools.partial(the_method._method, self.instance)
            else:
                module = the_method.__module__

            if inspect.iscoroutinefunction(the_method):

                async def instrumented():
                    enter_span(
                        module,
                        method,
                        str(ctx.mailbox.actor_id),
                    )
                    try:
                        result = await the_method(*args, **kwargs)
                        self._maybe_exit_debugger()
                    except Exception as e:
                        logger.critical(
                            "Unhandled exception in actor endpoint",
                            exc_info=e,
                        )
                        raise
                    finally:
                        # Close the span on both success and failure, so a
                        # raising endpoint does not leave it open.
                        exit_span()
                    return result

                result = await instrumented()
            else:
                enter_span(module, method, str(ctx.mailbox.actor_id))
                try:
                    with fake_sync_state():
                        result = the_method(*args, **kwargs)
                    self._maybe_exit_debugger()
                finally:
                    exit_span()

            port.send(result)
        except Exception as e:
            self._post_mortem_debug(e.__traceback__)
            traceback.print_exc()
            port.exception(ActorError(e))
        except BaseException as e:
            self._post_mortem_debug(e.__traceback__)
            # A BaseException can be thrown in the case of a Rust panic.
            # In this case, we need a way to signal the panic to the Rust side.
            # See [Panics in async endpoints]
            try:
                panic_flag.signal_panic(e)
            except Exception:
                # The channel might be closed if the Rust side has already detected the error
                pass
            raise

    def _maybe_exit_debugger(self, do_continue=True) -> None:
        """
        End the active pdb session, if any, and reset the debug context.

        With ``do_continue``, breakpoints are cleared and execution is resumed
        before the session ends.
        """
        if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
            if do_continue:
                pdb_wrapper.clear_all_breaks()
                pdb_wrapper.do_continue("")
            pdb_wrapper.end_debug_session()
        DebugContext.set(DebugContext())

    def _post_mortem_debug(self, exc_tb) -> None:
        """
        Run a post-mortem pdb session on ``exc_tb`` if a debug session was active.

        A fresh ``PdbWrapper`` is connected to the debug client obtained from
        ``DebugManager``.
        """
        from monarch._src.actor.debugger import DebugManager

        if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
            with fake_sync_state():
                ctx = MonarchContext.get()
                pdb_wrapper = PdbWrapper(
                    ctx.point.rank,
                    ctx.point.shape.coordinates(ctx.point.rank),
                    ctx.mailbox.actor_id,
                    DebugManager.ref().get_debug_client.call_one().get(),
                )
                DebugContext.set(DebugContext(pdb_wrapper))
                pdb_wrapper.post_mortem(exc_tb)
                self._maybe_exit_debugger(do_continue=False)
723
+
724
+
725
def _is_mailbox(x: object) -> bool:
    """
    ``flatten`` predicate used by ``_pickle``: selects ``Mailbox`` objects.

    Raises:
        NotImplementedError: if ``x`` has ``__monarch_ref__``; such references
            cannot be sent directly to a port.
    """
    if not hasattr(x, "__monarch_ref__"):
        return isinstance(x, Mailbox)
    raise NotImplementedError(
        "Sending monarch tensor references directly to a port."
    )
731
+
732
+
733
def _is_ref_or_mailbox(x: object) -> bool:
    """``flatten`` predicate: True for objects with ``__monarch_ref__`` or for a ``Mailbox``."""
    if hasattr(x, "__monarch_ref__"):
        return True
    return isinstance(x, Mailbox)
735
+
736
+
737
def _pickle(obj: object) -> bytes:
    """
    Serialize ``obj`` for a port message.

    Mailboxes are pulled out by ``flatten`` and dropped; ``PortReceiver``
    refills those slots with its own mailbox when unflattening.
    """
    _extracted, payload = flatten(obj, _is_mailbox)
    return payload
740
+
741
+
742
class Actor(MeshTrait):
    """
    Base class for user-provided actors.

    It derives from ``MeshTrait`` only to satisfy the type checker. An actor
    implementation is not a mesh, so the mesh hooks below always raise.
    """

    @functools.cached_property
    def logger(self) -> logging.Logger:
        """Logger named after the concrete actor class, set to DEBUG level."""
        actor_logger = logging.getLogger(self.__class__.__name__)
        actor_logger.setLevel(logging.DEBUG)
        return actor_logger

    @property
    def _ndslice(self) -> NDSlice:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @property
    def _labels(self) -> Tuple[str, ...]:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )
765
+
766
+
767
class ActorMeshRef(MeshTrait):
    """
    Client-side handle to a mesh of actors of type ``Class``.

    Every ``EndpointProperty`` attribute of the class is exposed as an
    ``ActorEndpoint`` instance attribute. Other class attributes resolve to a
    ``NotAnEndpoint`` placeholder (see ``as_endpoint``).
    """

    def __init__(
        self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
    ) -> None:
        self.__name__: str = Class.__name__
        self._class: Type[T] = Class
        self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref
        self._mailbox: Mailbox = mailbox
        for attr_name in dir(self._class):
            attr_value = getattr(self._class, attr_name, None)
            if isinstance(attr_value, EndpointProperty):
                setattr(
                    self,
                    attr_name,
                    ActorEndpoint(
                        self._actor_mesh_ref,
                        attr_name,
                        attr_value._method,
                        self._mailbox,
                        attr_value._propagator,
                    ),
                )

    def __getattr__(self, attr: str) -> NotAnEndpoint:
        """
        Return a ``NotAnEndpoint`` for non-endpoint attributes of the actor class.

        Raises:
            AttributeError: if the class has no such attribute.
        """
        # Read ``_class`` through ``__dict__``. When it is missing (e.g. an
        # instance created without running ``__init__``), ``self._class`` would
        # re-enter ``__getattr__`` and recurse until RecursionError.
        actor_class = self.__dict__.get("_class")
        if actor_class is not None and attr in dir(actor_class):
            return NotAnEndpoint(self, attr)
        raise AttributeError(attr)

    def _create(
        self,
        args: Iterable[Any],
        kwargs: Dict[str, Any],
    ) -> None:
        """
        Send the ``__init__`` message that builds ``Class(*args, **kwargs)``.

        The message goes to every actor of this reference, using the default
        ``"all"`` selection.
        """

        # Accepts any arguments, so signature binding never rejects the
        # constructor arguments; validation happens in ``Class`` remotely.
        async def null_func(*_args: Any, **_kwargs: Any) -> None:
            return None

        ep = ActorEndpoint(
            self._actor_mesh_ref,
            "__init__",
            null_func,
            self._mailbox,
        )
        send(ep, (self._class, *args), kwargs)

    def __reduce_ex__(
        self, protocol: ...
    ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]":
        # Rebuild through __init__ so endpoints are re-created on unpickle.
        return ActorMeshRef, (
            self._class,
            self._actor_mesh_ref,
            self._mailbox,
        )

    @property
    def _ndslice(self) -> NDSlice:
        return self._actor_mesh_ref._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._actor_mesh_ref._shape.labels

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        return ActorMeshRef(
            self._class,
            _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape),
            self._mailbox,
        )

    def __repr__(self) -> str:
        return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})"

    async def stop(self):
        """Stop the underlying actor mesh (see ``_ActorMeshRefImpl.stop``)."""
        await self._actor_mesh_ref.stop()
840
+
841
+
842
+ class ActorError(Exception):
843
+ """
844
+ Deterministic problem with the user's code.
845
+ For example, an OOM resulting in trying to allocate too much GPU memory, or violating
846
+ some invariant enforced by the various APIs.
847
+ """
848
+
849
+ def __init__(
850
+ self,
851
+ exception: Exception,
852
+ message: str = "A remote actor call has failed.",
853
+ ) -> None:
854
+ self.exception = exception
855
+ self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__)
856
+ self.message = message
857
+
858
+ def __str__(self) -> str:
859
+ exe = str(self.exception)
860
+ actor_mesh_ref_tb = "".join(traceback.format_list(self.actor_mesh_ref_frames))
861
+ return (
862
+ f"{self.message}\n"
863
+ f"Traceback of where the remote call failed (most recent call last):\n{actor_mesh_ref_tb}{type(self.exception).__name__}: {exe}"
864
+ )
865
+
866
+
867
def current_actor_name() -> str:
    """Return the string form of the current actor's id (from ``MonarchContext``)."""
    actor_id = MonarchContext.get().mailbox.actor_id
    return str(actor_id)
869
+
870
+
871
def current_rank() -> Point:
    """Return the ``Point`` of the actor handling the current message."""
    return MonarchContext.get().point
874
+
875
+
876
def current_size() -> Dict[str, int]:
    """Map each label of the current actor's shape to that dimension's size."""
    shape = MonarchContext.get().point.shape
    return {label: size for label, size in zip(shape.labels, shape.ndslice.sizes)}