torchmonarch-nightly 2025.7.1__cp313-cp313-manylinux2014_x86_64.whl → 2025.7.25__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +874 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +270 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +500 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +56 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +51 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +12 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/mesh_controller.py +201 -139
  37. monarch/monarch_controller +0 -0
  38. monarch/opaque_module.py +4 -6
  39. monarch/opaque_object.py +3 -3
  40. monarch/proc_mesh.py +6 -309
  41. monarch/python_local_mesh.py +1 -1
  42. monarch/rust_backend_mesh.py +2 -1
  43. monarch/rust_local_mesh.py +4 -2
  44. monarch/sim_mesh.py +10 -19
  45. monarch/simulator/command_history.py +1 -1
  46. monarch/simulator/interface.py +2 -1
  47. monarch/simulator/mock_controller.py +1 -1
  48. monarch/simulator/simulator.py +1 -1
  49. monarch/tensor_engine/__init__.py +23 -0
  50. monarch/tensor_worker_main.py +3 -1
  51. monarch/tools/cli.py +3 -1
  52. monarch/tools/commands.py +95 -35
  53. monarch/tools/mesh_spec.py +55 -0
  54. monarch/tools/utils.py +38 -0
  55. monarch/worker/worker.py +1 -1
  56. monarch/world_mesh.py +2 -1
  57. monarch_supervisor/python_executable.py +6 -3
  58. tests/error_test_binary.py +48 -10
  59. tests/test_actor_error.py +370 -21
  60. tests/test_alloc.py +1 -1
  61. tests/test_allocator.py +373 -17
  62. tests/test_controller.py +2 -0
  63. tests/test_debugger.py +416 -0
  64. tests/test_env_before_cuda.py +162 -0
  65. tests/test_python_actors.py +184 -333
  66. tests/test_rdma.py +198 -0
  67. tests/test_remote_functions.py +40 -12
  68. tests/test_rust_backend.py +7 -5
  69. tests/test_sim_backend.py +1 -4
  70. tests/test_tensor_engine.py +55 -1
  71. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
  72. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
  73. torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
  74. monarch/_monarch/hyperactor/__init__.py +0 -58
  75. monarch/_monarch/worker/debugger.py +0 -117
  76. monarch/_monarch/worker/logging.py +0 -107
  77. monarch/debugger.py +0 -379
  78. monarch/future.py +0 -76
  79. monarch/rdma.py +0 -162
  80. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  81. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  82. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  83. /monarch/{common → _src/actor}/shape.py +0 -0
  84. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  85. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
  86. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
  87. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,874 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import collections
10
+ import contextvars
11
+ import functools
12
+ import inspect
13
+ import itertools
14
+ import logging
15
+ import random
16
+ import traceback
17
+
18
+ from dataclasses import dataclass
19
+ from traceback import extract_tb, StackSummary
20
+ from typing import (
21
+ Any,
22
+ AsyncGenerator,
23
+ Awaitable,
24
+ Callable,
25
+ cast,
26
+ Concatenate,
27
+ Dict,
28
+ Generic,
29
+ Iterable,
30
+ Iterator,
31
+ List,
32
+ NamedTuple,
33
+ Optional,
34
+ ParamSpec,
35
+ Tuple,
36
+ Type,
37
+ TYPE_CHECKING,
38
+ TypeVar,
39
+ )
40
+
41
+ from monarch._rust_bindings.monarch_hyperactor.actor import (
42
+ PanicFlag,
43
+ PythonMessage,
44
+ PythonMessageKind,
45
+ )
46
+ from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
47
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
48
+ Mailbox,
49
+ OncePortReceiver,
50
+ OncePortRef,
51
+ PortReceiver as HyPortReceiver,
52
+ PortRef,
53
+ )
54
+
55
+ if TYPE_CHECKING:
56
+ from monarch._rust_bindings.monarch_hyperactor.actor import PortProtocol
57
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import PortReceiverBase
58
+
59
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
60
+ from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
61
+ from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError
62
+ from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span
63
+ from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
64
+ from monarch._src.actor.endpoint import (
65
+ Endpoint,
66
+ EndpointProperty,
67
+ Extent,
68
+ Propagator,
69
+ Selection,
70
+ )
71
+ from monarch._src.actor.future import Future
72
+ from monarch._src.actor.pdb_wrapper import PdbWrapper
73
+
74
+ from monarch._src.actor.pickle import flatten, unflatten
75
+
76
+ from monarch._src.actor.shape import MeshTrait, NDSlice
77
+ from monarch._src.actor.sync_state import fake_sync_state
78
+
79
+ from monarch._src.actor.tensor_engine_shim import actor_send
80
+
81
+ if TYPE_CHECKING:
82
+ from monarch._src.actor.proc_mesh import ProcMesh
83
+
84
# Logger for this module's internals.
logger: logging.Logger = logging.getLogger(__name__)

# Type alias covering both allocator implementations imported above.
Allocator = ProcessAllocator | LocalAllocator

# `__manifest__` is presumably only importable from inside a packaged (PAR)
# binary; when it is absent we are not running from one.
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = bool(fbmake.get("par_style"))

T1 = TypeVar("T1")
T2 = TypeVar("T2")
97
+
98
+
99
class Point(HyPoint, collections.abc.Mapping):
    """A hyperactor `Point` that can also be used through the `Mapping` interface."""
101
+
102
+
103
@dataclass
class MonarchContext:
    """Context of the actor handling the current message.

    Stored in the `_context` context variable; `_Actor.handle` installs a
    fresh one for every message it processes.
    """

    mailbox: Mailbox
    proc_id: str
    point: Point

    @staticmethod
    def get() -> "MonarchContext":
        """Return the current context (raises LookupError if none was set)."""
        current: "MonarchContext" = _context.get()
        return current


# Context variable backing `MonarchContext.get()`.
_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar("monarch.actor_mesh._context")
117
+
118
+
119
@dataclass
class DebugContext:
    """Holds the pdb session (if any) attached to the current actor call.

    `_Actor.handle` resets it to an empty context for every message.
    """

    pdb_wrapper: Optional[PdbWrapper] = None

    @staticmethod
    def get() -> "DebugContext":
        """Return the current debug context (raises LookupError if never set)."""
        current: "DebugContext" = _debug_context.get()
        return current

    @staticmethod
    def set(debug_context: "DebugContext") -> None:
        """Make `debug_context` the debug context of the current task."""
        _token = _debug_context.set(debug_context)


# Context variable backing `DebugContext.get()` / `DebugContext.set()`.
_debug_context: contextvars.ContextVar[DebugContext] = contextvars.ContextVar("monarch.actor_mesh._debug_context")
135
+
136
# Generic type variables shared by the classes and functions below.
T = TypeVar("T")
P = ParamSpec("P")  # parameters of an endpoint's underlying method
R = TypeVar("R")  # value returned by an endpoint
A = TypeVar("A")  # accumulated value type (see `Accumulator`)

# Keep this load balancing deterministic (fixed seed), but equally
# distributed. Used by `_ActorMeshRefImpl.cast` for the "choose" selection.
_load_balancing_seed = random.Random(4)
144
+
145
+
146
# Stand-in class for whatever serializable Python object we end up using to
# name an actor mesh. Hacked up today because ActorMesh isn't plumbed to
# non-clients.
class _ActorMeshRefImpl:
    """Client-side handle used to deliver messages to the actors of a mesh.

    Holds the mailbox used for delivery, the shape of the (possibly sliced)
    mesh, and the flat list of actor ids indexed by global rank. When built
    from a live `PythonActorMesh`, that mesh is kept so its supervision state
    can be checked before every send.
    """

    def __init__(
        self,
        mailbox: Mailbox,
        hy_actor_mesh: Optional[PythonActorMesh],
        proc_mesh: "Optional[ProcMesh]",
        shape: Shape,
        actor_ids: List[ActorId],
    ) -> None:
        self._mailbox = mailbox
        self._actor_mesh = hy_actor_mesh
        # actor meshes do not have a way to look this up at the moment,
        # so we fake it here
        self._proc_mesh = proc_mesh
        self._shape = shape
        # Indexed by global rank (the values yielded by `Shape.ranks()` and by
        # indexing `NDSlice`), not by position within `self._shape`.
        self._please_replace_me_actor_ids = actor_ids

    @staticmethod
    def from_hyperactor_mesh(
        mailbox: Mailbox, hy_actor_mesh: PythonActorMesh, proc_mesh: "ProcMesh"
    ) -> "_ActorMeshRefImpl":
        """Wrap a live hyperactor mesh, recording the id of every actor in it."""
        shape: Shape = hy_actor_mesh.shape
        return _ActorMeshRefImpl(
            mailbox,
            hy_actor_mesh,
            proc_mesh,
            shape,
            [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))],
        )

    @staticmethod
    def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl":
        """Build a ref addressing exactly one actor."""
        return _ActorMeshRefImpl(mailbox, None, None, singleton_shape, [actor_id])

    @staticmethod
    def from_actor_ref_with_shape(
        ref: "_ActorMeshRefImpl", shape: Shape
    ) -> "_ActorMeshRefImpl":
        """Build a ref sharing `ref`'s actors and mailbox but viewing them through `shape`."""
        return _ActorMeshRefImpl(
            ref._mailbox, None, None, shape, ref._please_replace_me_actor_ids
        )

    def __getstate__(
        self,
    ) -> Tuple[Shape, List[ActorId], Mailbox]:
        # Only the shape, actor ids and mailbox travel; the live actor mesh and
        # proc mesh are local to the process that created them.
        return self._shape, self._please_replace_me_actor_ids, self._mailbox

    def __setstate__(
        self,
        state: Tuple[Shape, List[ActorId], Mailbox],
    ) -> None:
        self._actor_mesh = None
        # Not serialized (see __getstate__); set explicitly so the attribute
        # exists on unpickled instances just like on freshly built ones.
        self._proc_mesh = None
        self._shape, self._please_replace_me_actor_ids, self._mailbox = state

    def _check_state(self) -> None:
        # This is temporary until we have real cast integration here. We need to actively check
        # supervision error here is because all communication is done through direct mailbox sending
        # and not through comm actor casting.
        # TODO: remove this when casting integration is done.
        if self._actor_mesh is not None:
            if self._actor_mesh.stopped:
                raise SupervisionError(
                    "actor mesh is not in a healthy state: `ActorMesh` has been stopped"
                )

            event = self._actor_mesh.get_supervision_event()
            if event is not None:
                raise SupervisionError(f"actor mesh is not in a healthy state: {event}")

    def send(self, rank: int, message: PythonMessage) -> None:
        """Post `message` to the actor with global rank `rank`."""
        self._check_state()
        actor = self._please_replace_me_actor_ids[rank]
        self._mailbox.post(actor, message)

    def cast(
        self,
        message: PythonMessage,
        selection: Selection,
    ) -> None:
        """Deliver `message` to the actors picked by `selection`.

        `selection` is "choose" (one actor of the shape, picked pseudo-randomly),
        "all" (every actor of the shape), or an int global rank.

        Raises:
            SupervisionError: if the underlying live mesh is unhealthy.
            IndexError: if an int selection is not a valid rank.
            ValueError: for any other selection.
        """
        self._check_state()

        # TODO: use the actual actor mesh when available. We cannot currently use it
        # directly because we risk bifurcating the message delivery paths from the same
        # client, since slicing the mesh will produce a reference, which calls actors
        # directly. The reason these paths are bifurcated is that actor meshes will
        # use multicasting, while direct actor comms do not. Separately we need to decide
        # whether actor meshes are ordered with actor references.
        #
        # The fix is to provide a first-class reference into Python, and always call "cast"
        # on it, including for load balanced requests.
        if selection == "choose":
            idx = _load_balancing_seed.randrange(len(self._shape))
            actor_rank = self._shape.ndslice[idx]
            self._mailbox.post(self._please_replace_me_actor_ids[actor_rank], message)
        elif selection == "all":
            # replace me with actual remote actor mesh
            call_shape = Shape(
                self._shape.labels, NDSlice.new_row_major(self._shape.ndslice.sizes)
            )
            for i, rank in enumerate(self._shape.ranks()):
                self._mailbox.post_cast(
                    self._please_replace_me_actor_ids[rank],
                    i,
                    call_shape,
                    message,
                )
        elif isinstance(selection, int):
            num_actors = len(self._please_replace_me_actor_ids)
            # Check the range explicitly: plain list indexing would silently
            # wrap negative ranks around to the end of the list.
            if not 0 <= selection < num_actors:
                raise IndexError(
                    f"Tried to send to an out-of-range rank {selection}: "
                    f"mesh has {num_actors} elements."
                )
            self._mailbox.post(self._please_replace_me_actor_ids[selection], message)
        else:
            raise ValueError(f"invalid selection: {selection}")

    def __len__(self) -> int:
        return len(self._shape)

    @property
    def _name_pid(self):
        """(actor name, pid) of the first actor id in this ref."""
        actor_id0 = self._please_replace_me_actor_ids[0]
        return actor_id0.actor_name, actor_id0.pid

    async def stop(self):
        """Stop the underlying live actor mesh.

        Raises:
            RuntimeError: if this ref is not backed by a live actor mesh (for
                example it was sliced or unpickled), so there is nothing to stop.
        """
        if self._actor_mesh is None:
            raise RuntimeError(
                "cannot stop this actor mesh: the reference is not backed by a live ActorMesh"
            )
        await self._actor_mesh.stop()
278
+
279
+
280
class ActorEndpoint(Endpoint[P, R]):
    """Endpoint bound to one method of the actors in an actor mesh.

    Calls are serialized and delivered through the mesh ref's `cast`. When the
    arguments contain monarch references, they are handed to `actor_send` from
    the tensor-engine shim instead.
    """

    def __init__(
        self,
        actor_mesh_ref: _ActorMeshRefImpl,
        name: str,
        impl: Callable[Concatenate[Any, P], Awaitable[R]],
        mailbox: Mailbox,
        propagator: Propagator = None,
    ) -> None:
        super().__init__(propagator)
        self._actor_mesh = actor_mesh_ref
        self._name = name
        # Used to validate call arguments locally before anything is sent.
        self._signature: inspect.Signature = inspect.signature(impl)
        self._mailbox = mailbox

    def _supervise(self, r: HyPortReceiver | OncePortReceiver) -> Any:
        # Only a live hyperactor mesh can supervise a receiver; refs without
        # one (slices, unpickled refs) return the receiver unchanged.
        mesh = self._actor_mesh._actor_mesh
        return r if mesh is None else mesh.supervise(r)

    def _call_name(self) -> Any:
        return self._name

    def _send(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        port: "Optional[Port]" = None,
        selection: Selection = "all",
    ) -> Extent:
        """
        Fire-and-forget invocation of the endpoint on the actors picked by `selection`.

        Does not wait for any result. If `port` is given, its port ref is
        attached to the message so that replies can be delivered to it.

        Returns:
            The extent (labels and sizes) of the mesh the call was addressed to.

        Raises:
            TypeError: if `args`/`kwargs` do not match the endpoint's signature.
        """
        # `None` stands in for the actor's `self` argument.
        self._signature.bind(None, *args, **kwargs)
        objects, payload = flatten((args, kwargs), _is_ref_or_mailbox)
        refs = [obj for obj in objects if hasattr(obj, "__monarch_ref__")]
        if not refs:
            message = PythonMessage(
                PythonMessageKind.CallMethod(
                    self._name, None if port is None else port._port_ref
                ),
                payload,
            )
            self._actor_mesh.cast(message, selection)
        else:
            # Arguments holding monarch references go through the tensor-engine shim.
            actor_send(self, payload, refs, port, selection)
        shape = self._actor_mesh._shape
        return Extent(shape.labels, shape.ndslice.sizes)

    def _port(self, once: bool = False) -> "PortTuple[R]":
        """Open a reply port whose receiver is supervised by the mesh when possible."""
        p, r = PortTuple.create(self._mailbox, once)
        if TYPE_CHECKING:
            assert isinstance(
                r._receiver, (HyPortReceiver | OncePortReceiver)
            ), "unexpected receiver type"
        return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
337
+
338
+
339
class Accumulator(Generic[P, R, A]):
    """Fold the streamed responses of an endpoint into a single value.

    Each response is combined into a running value that starts at
    `identity`, using `combine(running, response)`.
    """

    def __init__(
        self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
    ) -> None:
        self._endpoint: Endpoint[P, R] = endpoint
        self._identity: A = identity
        self._combine: Callable[[A, R], A] = combine

    def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
        """Call the endpoint via `stream` and return a Future of the folded result."""
        responses: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs)

        async def fold() -> A:
            running = self._identity
            async for response in responses:
                running = self._combine(running, response)
            return running

        return Future(impl=fold)
357
+
358
+
359
class ValueMesh(MeshTrait, Generic[R]):
    """
    Container of return values, indexed by rank.

    `values` is indexed by global rank, while `shape` selects which ranks are
    visible. Slicing therefore yields a new `ValueMesh` over the same list.
    """

    def __init__(self, shape: Shape, values: List[R]) -> None:
        self._shape = shape
        self._values = values

    def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
        return ValueMesh(shape, self._values)

    def item(self, **kwargs) -> R:
        """Return the value at the given coordinates, one keyword per dimension.

        Raises:
            KeyError: if a dimension of the mesh is not given, or an unknown
                dimension is passed.
        """
        missing = [label for label in self._labels if label not in kwargs]
        if missing:
            raise KeyError(f"item is missing dimensions: {missing}")
        coordinates = [kwargs.pop(label) for label in self._labels]
        if kwargs:
            raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}")

        return self._values[self._ndslice.nditem(coordinates)]

    def items(self) -> Iterable[Tuple[Point, R]]:
        """Yield `(point, value)` for every rank in the shape."""
        for rank in self._shape.ranks():
            yield Point(rank, self._shape), self._values[rank]

    def __iter__(self) -> Iterator[Tuple[Point, R]]:
        return iter(self.items())

    def __len__(self) -> int:
        return len(self._shape)

    def __repr__(self) -> str:
        return f"ValueMesh({self._shape})"

    @property
    def _ndslice(self) -> NDSlice:
        return self._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._shape.labels
398
+
399
+
400
def send(
    endpoint: Endpoint[P, R],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    port: "Optional[Port]" = None,
    selection: Selection = "all",
) -> None:
    """
    Fire-and-forget invocation of `endpoint` on the actors picked by `selection`.

    Delegates to `endpoint._send` without waiting for any result; the value
    that `_send` returns is discarded.
    """
    dispatch = endpoint._send
    dispatch(args, kwargs, port, selection)
413
+
414
+
415
class Port(Generic[R]):
    """Sending half of a port: delivers results or exceptions to its receiver.

    `rank`, when not None, is attached to every message. This lets a
    `RankedPortReceiver` report which actor produced each value.
    """

    def __init__(
        self,
        port_ref: PortRef | OncePortRef,
        mailbox: Mailbox,
        rank: Optional[int],
    ) -> None:
        self._port_ref = port_ref
        self._mailbox = mailbox
        self._rank = rank

    def send(self, obj: R) -> None:
        """Deliver `obj` to the receiver as a successful result."""
        self._port_ref.send(
            self._mailbox,
            PythonMessage(PythonMessageKind.Result(self._rank), _pickle(obj)),
        )

    def exception(self, obj: Exception) -> None:
        """Deliver `obj` to the receiver as an error.

        `PortReceiver` raises the exception from `recv()` on the receiving side.
        """
        self._port_ref.send(
            self._mailbox,
            PythonMessage(PythonMessageKind.Exception(self._rank), _pickle(obj)),
        )
439
+
440
+
441
class DroppingPort:
    """
    Stand-in for a real port when the message carries no response port.

    Results are discarded. Exceptions are re-raised in the current actor, so
    the failure is reported instead of being lost.
    """

    def __init__(self):
        """A dropping port holds no state."""

    def send(self, obj: Any) -> None:
        """Discard `obj`: there is nobody to deliver it to."""
        return None

    def exception(self, obj: Exception) -> None:
        """Raise `obj` in the caller, since it cannot be delivered anywhere."""
        # Each error is delivered exactly once; `from None` hides whatever
        # exception was being handled when this was called.
        raise obj from None
457
+
458
+
459
# NOTE: `R` and `T` are declared once, as module-level TypeVars, near the top
# of this file. They used to be redeclared here, which bound the names to
# second, distinct TypeVar objects. Generics defined below this point were
# then parameterized by a different `R`/`T` than those defined above it.
462
+
463
if TYPE_CHECKING:
    # Python <= 3.10 cannot inherit from Generic[R] and NamedTuple at the same
    # time. Only type checkers need the generic version, so a non-generic twin
    # is used at runtime until we can require 3.11.
    class PortTuple(NamedTuple, Generic[R]):
        """The (sender, receiver) pair of a newly opened port."""

        sender: "Port[R]"
        receiver: "PortReceiver[R]"

        @staticmethod
        def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
            """Open a port on `mailbox` (a once-port when `once`) and bind it."""
            opener = mailbox.open_once_port if once else mailbox.open_port
            handle, raw_receiver = opener()
            bound_ref = handle.bind()
            sender = Port(bound_ref, mailbox, rank=None)
            return PortTuple(sender, PortReceiver(mailbox, raw_receiver))

else:

    class PortTuple(NamedTuple):
        """The (sender, receiver) pair of a newly opened port."""

        sender: "Port[Any]"
        receiver: "PortReceiver[Any]"

        @staticmethod
        def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
            """Open a port on `mailbox` (a once-port when `once`) and bind it."""
            opener = mailbox.open_once_port if once else mailbox.open_port
            handle, raw_receiver = opener()
            bound_ref = handle.bind()
            sender = Port(bound_ref, mailbox, rank=None)
            return PortTuple(sender, PortReceiver(mailbox, raw_receiver))
492
+
493
+
494
# Advanced, lower-level API for sending messages. This is intentionally
# not part of the Endpoint API, because the way it accepts arguments
# and handles concerns is different.
def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]":
    """Open a reply port for `endpoint` (`once` requests a single-use port)."""
    pair = endpoint._port(once)
    return pair
499
+
500
+
501
def ranked_port(
    endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
    """Like `port`, but the receiver yields `(rank, value)` pairs."""
    sender, plain_receiver = port(endpoint, once)
    ranked = RankedPortReceiver[R](plain_receiver._mailbox, plain_receiver._receiver)
    return sender, ranked
506
+
507
+
508
class PortReceiver(Generic[R]):
    """Receiving half of a port: decodes incoming messages into values."""

    def __init__(self, mailbox: Mailbox, receiver: "PortReceiverBase") -> None:
        self._mailbox: Mailbox = mailbox
        self._receiver = receiver

    async def _recv(self) -> R:
        raw = await self._receiver.recv_task()
        return self._process(raw)

    def _process(self, msg: PythonMessage) -> R:
        """Decode `msg` into a value, or raise the exception it carries.

        Raises:
            Exception: the decoded payload, for `Exception` messages.
            ValueError: for any message kind other than Result/Exception.
        """
        # TODO: Try to do something more structured than a cast here
        payload = cast(R, unflatten(msg.message, itertools.repeat(self._mailbox)))
        kind = msg.kind
        if isinstance(kind, PythonMessageKind.Result):
            return payload
        if isinstance(kind, PythonMessageKind.Exception):
            raise cast(Exception, payload)
        raise ValueError(f"Unexpected message kind: {kind}")

    def recv(self) -> "Future[R]":
        """Return a Future that resolves to the next value on this port."""
        return Future(requires_loop=False, impl=lambda: self._recv())
533
+
534
+
535
class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
    """Port receiver that pairs each value with the rank that sent it."""

    def _process(self, msg: PythonMessage) -> Tuple[int, R]:
        sender_rank = getattr(msg.kind, "rank", None)
        if sender_rank is not None:
            return sender_rank, super()._process(msg)
        raise ValueError(
            f"RankedPort receiver got a message without a rank {msg}",
        )
543
+
544
+
545
# Zero-dimensional shape (no labels, empty sizes/strides, offset 0). Used by
# `_ActorMeshRefImpl.from_actor_id` to describe a ref to exactly one actor.
singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[]))
546
+
547
+
548
+ # Currently the synchronous functions of actors are run on a Python thread that has an active event loop.
549
+ # Technically it is unsafe for them to block at all, because they will block the loop for other
550
+ # calls, so all calls to .get() should be failing.
551
+ # In the meantime, to implement get() by reusing async functions,
552
+ # we need to signal to the consumer of the PythonTask object that the thread really isn't in an async context.
553
+ # We do this by blanking out the running event loop during the call to the synchronous actor function.
554
+
555
+
556
class _Actor:
    """
    This is the message handling implementation of a Python actor.

    The layering goes:
        Rust `PythonActor` -> `_Actor` -> user-provided `Actor` instance

    Messages are received from the Rust backend, and forwarded to the `handle`
    methods on this class.

    This class wraps the actual `Actor` instance provided by the user, and
    routes messages to it, managing argument serialization/deserialization and
    error handling.
    """

    def __init__(self) -> None:
        # The user's actor object, created when the "__init__" message arrives.
        self.instance: object | None = None
        # Error from a failed "__init__", mentioned in errors from later calls.
        # TODO: (@pzhang) remove this with T229200522
        self._saved_error: ActorError | None = None

    async def handle(
        self,
        mailbox: Mailbox,
        rank: int,
        shape: Shape,
        method: str,
        message: bytes,
        panic_flag: PanicFlag,
        local_state: Iterable[Any],
        port: "PortProtocol",
    ) -> None:
        """Handle one message by dispatching it to the user's actor instance.

        Args:
            mailbox: This actor's mailbox; also recorded in the MonarchContext.
            rank: Rank of this actor within `shape`.
            shape: Shape of the mesh the message was addressed to.
            method: Endpoint name, or "__init__" to construct the actor instance.
            message: Serialized `(args, kwargs)`.
            panic_flag: Used to report a BaseException (e.g. a Rust panic).
            local_state: Passed to `unflatten` together with `message`;
                presumably the objects that were flattened out of the
                arguments on the sending side.
            port: Receives the result, or the error wrapped in `ActorError`.

        Ordinary exceptions are sent on `port` wrapped in `ActorError`. Any
        other BaseException is signalled via `panic_flag` and re-raised.
        """
        # response_port can be None. If so, then sending to port will drop the response,
        # and raise any exceptions to the caller.
        try:
            ctx: MonarchContext = MonarchContext(
                mailbox, mailbox.actor_id.proc_id, Point(rank, shape)
            )
            _context.set(ctx)

            DebugContext.set(DebugContext())

            args, kwargs = unflatten(message, local_state)

            if method == "__init__":
                Class, *args = args
                try:
                    self.instance = Class(*args, **kwargs)
                except Exception as e:
                    self._saved_error = ActorError(
                        e, f"Remote actor {Class}.__init__ call failed."
                    )
                    raise
                port.send(None)
                return None

            if self.instance is None:
                # This could happen because of the following reasons. Both
                # indicates a possible bug in the framework:
                # 1. the execution of the previous message for "__init__" failed,
                #    but that error is not surfaced to the caller.
                #      - TODO(T229200522): there is a known bug. fix it.
                # 2. this message is delivered to this actor before the previous
                #    message of "__init__" is delivered. Out-of-order delivery
                #    should never happen. It indicates either a bug in the
                #    message delivery mechanism, or the framework accidentally
                #    mixed the usage of cast and direct send.
                error_message = f"Actor object is missing when executing method {method} on actor {mailbox.actor_id}."
                if self._saved_error is not None:
                    error_message += (
                        f" This is likely due to an earlier error: {self._saved_error}"
                    )
                raise AssertionError(error_message)
            the_method = getattr(self.instance, method)._method

            if inspect.iscoroutinefunction(the_method):

                async def instrumented():
                    enter_span(
                        the_method.__module__,
                        method,
                        str(ctx.mailbox.actor_id),
                    )
                    try:
                        result = await the_method(self.instance, *args, **kwargs)
                        self._maybe_exit_debugger()
                    except Exception as e:
                        logging.critical(
                            "Unhandled exception in actor endpoint",
                            exc_info=e,
                        )
                        raise
                    finally:
                        # Close the span on failure too, so spans stay balanced.
                        exit_span()
                    return result

                result = await instrumented()
            else:
                enter_span(the_method.__module__, method, str(ctx.mailbox.actor_id))
                try:
                    with fake_sync_state():
                        result = the_method(self.instance, *args, **kwargs)
                    self._maybe_exit_debugger()
                finally:
                    # Close the span on failure too, so spans stay balanced.
                    exit_span()

            port.send(result)
        except Exception as e:
            self._post_mortem_debug(e.__traceback__)
            traceback.print_exc()
            port.exception(ActorError(e))
        except BaseException as e:
            self._post_mortem_debug(e.__traceback__)
            # A BaseException can be thrown in the case of a Rust panic.
            # In this case, we need a way to signal the panic to the Rust side.
            # See [Panics in async endpoints]
            try:
                panic_flag.signal_panic(e)
            except Exception:
                # The channel might be closed if the Rust side has already detected the error
                pass
            raise

    def _maybe_exit_debugger(self, do_continue=True) -> None:
        """End the pdb session of the current call (if any) and reset the debug context.

        With `do_continue`, breakpoints are cleared and execution is continued
        before the session ends.
        """
        if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
            if do_continue:
                pdb_wrapper.clear_all_breaks()
                pdb_wrapper.do_continue("")
            pdb_wrapper.end_debug_session()
        DebugContext.set(DebugContext())

    def _post_mortem_debug(self, exc_tb) -> None:
        """Run a post-mortem pdb session on `exc_tb`, but only when already debugging.

        Does nothing unless the current call already has a pdb session
        (`DebugContext.pdb_wrapper` is set). In that case, a new `PdbWrapper`
        attached to the debug client from `DebugManager` runs the post-mortem,
        and the debug context is then reset.
        """
        from monarch._src.actor.debugger import DebugManager

        try:
            debug_context = DebugContext.get()
        except LookupError:
            # The failure happened before `handle` installed a debug context
            # (e.g. while building the MonarchContext). There is no session to
            # attach to, and raising here would mask the original error.
            return

        # NOTE(review): the existing session only gates post-mortem entry; a
        # fresh wrapper replaces it. This looks intentional (only debug when the
        # user was already debugging) — confirm.
        if debug_context.pdb_wrapper is not None:
            with fake_sync_state():
                ctx = MonarchContext.get()
                pdb_wrapper = PdbWrapper(
                    ctx.point.rank,
                    ctx.point.shape.coordinates(ctx.point.rank),
                    ctx.mailbox.actor_id,
                    DebugManager.ref().get_debug_client.call_one().get(),
                )
                DebugContext.set(DebugContext(pdb_wrapper))
                pdb_wrapper.post_mortem(exc_tb)
                self._maybe_exit_debugger(do_continue=False)
698
+
699
+
700
def _is_mailbox(x: object) -> bool:
    """Flatten predicate for port payloads: selects `Mailbox` objects.

    Raises:
        NotImplementedError: if `x` is a monarch reference, which cannot be
            sent directly to a port.
    """
    is_ref = hasattr(x, "__monarch_ref__")
    if not is_ref:
        return isinstance(x, Mailbox)
    raise NotImplementedError(
        "Sending monarch tensor references directly to a port."
    )
706
+
707
+
708
def _is_ref_or_mailbox(x: object) -> bool:
    """Flatten predicate for call arguments: selects monarch references and mailboxes."""
    if hasattr(x, "__monarch_ref__"):
        return True
    return isinstance(x, Mailbox)
710
+
711
+
712
def _pickle(obj: object) -> bytes:
    """Serialize `obj` for a port message, keeping only the byte payload."""
    return flatten(obj, _is_mailbox)[1]
715
+
716
+
717
class Actor(MeshTrait):
    """Base class for user actor implementations.

    It only derives from `MeshTrait` to satisfy the type checker. The
    mesh-related members below always raise, because an actor instance is not
    itself a mesh.
    """

    @functools.cached_property
    def logger(self) -> logging.Logger:
        """Logger named after the concrete actor class, with its level set to DEBUG."""
        lgr = logging.getLogger(self.__class__.__name__)
        lgr.setLevel(logging.DEBUG)
        return lgr

    @property
    def _ndslice(self) -> NDSlice:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @property
    def _labels(self) -> Tuple[str, ...]:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )
740
+
741
+
742
class ActorMeshRef(MeshTrait):
    """Client-side reference to a mesh of actors of class `Class`.

    Every `EndpointProperty` defined on `Class` is exposed as an
    `ActorEndpoint` attribute with the same name.
    """

    def __init__(
        self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
    ) -> None:
        self.__name__: str = Class.__name__
        self._class: Type[T] = Class
        self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref
        self._mailbox: Mailbox = mailbox
        for attr_name in dir(self._class):
            attr_value = getattr(self._class, attr_name, None)
            if isinstance(attr_value, EndpointProperty):
                setattr(
                    self,
                    attr_name,
                    ActorEndpoint(
                        self._actor_mesh_ref,
                        attr_name,
                        attr_value._method,
                        self._mailbox,
                        # Forward the property's propagator, as __getattr__
                        # does, so every endpoint honours it.
                        propagator=attr_value._propagator,
                    ),
                )

    def __getattr__(self, name: str) -> Any:
        # This method is called when an attribute is not found
        # For linting purposes, we need to tell the type checker that any attribute
        # could be an endpoint that's dynamically added at runtime
        # At runtime, we still want to raise AttributeError for truly missing attributes

        # Read `_class` from __dict__: on a partially initialized instance,
        # `self._class` would itself re-enter __getattr__ and recurse forever.
        cls = self.__dict__.get("_class")
        if cls is not None:
            # Check if this is a method on the underlying class
            attr = getattr(cls, name, None)
            if isinstance(attr, EndpointProperty):
                # Dynamically create the endpoint
                endpoint = ActorEndpoint(
                    self._actor_mesh_ref,
                    name,
                    attr._method,
                    self._mailbox,
                    propagator=attr._propagator,
                )
                # Cache it for future use
                setattr(self, name, endpoint)
                return endpoint

        # If we get here, it's truly not found
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def _create(
        self,
        args: Iterable[Any],
        kwargs: Dict[str, Any],
    ) -> None:
        """Send the "__init__" message that constructs `Class(*args, **kwargs)` on each actor."""

        async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
            return None

        ep = ActorEndpoint(
            self._actor_mesh_ref,
            "__init__",
            null_func,
            self._mailbox,
        )
        send(ep, (self._class, *args), kwargs)

    def __reduce_ex__(
        self, protocol: ...
    ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]":
        # Rebuild through __init__ so endpoints are recreated on unpickling.
        return ActorMeshRef, (
            self._class,
            self._actor_mesh_ref,
            self._mailbox,
        )

    @property
    def _ndslice(self) -> NDSlice:
        return self._actor_mesh_ref._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._actor_mesh_ref._shape.labels

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        return ActorMeshRef(
            self._class,
            _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape),
            self._mailbox,
        )

    def __repr__(self) -> str:
        return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})"

    async def stop(self):
        """Stop the actor mesh backing this reference."""
        await self._actor_mesh_ref.stop()
836
+
837
+
838
+ class ActorError(Exception):
839
+ """
840
+ Deterministic problem with the user's code.
841
+ For example, an OOM resulting in trying to allocate too much GPU memory, or violating
842
+ some invariant enforced by the various APIs.
843
+ """
844
+
845
+ def __init__(
846
+ self,
847
+ exception: Exception,
848
+ message: str = "A remote actor call has failed.",
849
+ ) -> None:
850
+ self.exception = exception
851
+ self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__)
852
+ self.message = message
853
+
854
+ def __str__(self) -> str:
855
+ exe = str(self.exception)
856
+ actor_mesh_ref_tb = "".join(traceback.format_list(self.actor_mesh_ref_frames))
857
+ return (
858
+ f"{self.message}\n"
859
+ f"Traceback of where the remote call failed (most recent call last):\n{actor_mesh_ref_tb}{type(self.exception).__name__}: {exe}"
860
+ )
861
+
862
+
863
def current_actor_name() -> str:
    """Return the id of the actor handling the current message, as a string."""
    actor_id = MonarchContext.get().mailbox.actor_id
    return str(actor_id)
865
+
866
+
867
def current_rank() -> Point:
    """Return the point (rank and shape) of the actor handling the current message."""
    return MonarchContext.get().point
870
+
871
+
872
def current_size() -> Dict[str, int]:
    """Map each dimension label of the current actor's shape to its size."""
    shape = MonarchContext.get().point.shape
    return {label: size for label, size in zip(shape.labels, shape.ndslice.sizes)}