torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/actor_mesh.py ADDED
@@ -0,0 +1,761 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import collections
10
+ import contextvars
11
+ import functools
12
+ import inspect
13
+
14
+ import itertools
15
+ import logging
16
+ import random
17
+ import sys
18
+ import traceback
19
+
20
+ from dataclasses import dataclass
21
+ from traceback import extract_tb, StackSummary
22
+ from typing import (
23
+ Any,
24
+ AsyncGenerator,
25
+ Awaitable,
26
+ Callable,
27
+ cast,
28
+ Concatenate,
29
+ Dict,
30
+ Generic,
31
+ Iterable,
32
+ List,
33
+ Literal,
34
+ Optional,
35
+ ParamSpec,
36
+ Tuple,
37
+ Type,
38
+ TYPE_CHECKING,
39
+ TypeVar,
40
+ )
41
+
42
+ import monarch
43
+ from monarch import ActorFuture as Future
44
+ from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span
45
+
46
+ from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage
47
+ from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
48
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
49
+ Mailbox,
50
+ OncePortReceiver,
51
+ OncePortRef,
52
+ PortReceiver as HyPortReceiver,
53
+ PortRef,
54
+ )
55
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
56
+ from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
57
+
58
+ from monarch.common.pickle_flatten import flatten, unflatten
59
+ from monarch.common.shape import MeshTrait, NDSlice
60
+ from monarch.pdb_wrapper import remote_breakpointhook
61
+
62
+ if TYPE_CHECKING:
63
+ from monarch.debugger import DebugClient
64
+
65
logger: logging.Logger = logging.getLogger(__name__)

# Type alias covering the two allocator kinds exposed by the `monarch` package.
Allocator = monarch.ProcessAllocator | monarch.LocalAllocator

# `__manifest__` is only importable in certain packaged builds; its presence
# decides IN_PAR (presumably "running from a PAR archive" — confirm).
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = True

T1 = TypeVar("T1")
T2 = TypeVar("T2")
78
+
79
+
80
class Point(HyPoint, collections.abc.Mapping):
    """A rank within a shape that can also be read as a ``Mapping``.

    Constructed as ``Point(rank, shape)``. The ``collections.abc.Mapping``
    mixin contributes ``keys``/``items``/``values``/``get``/``__contains__``;
    it presumably relies on ``HyPoint`` implementing ``__getitem__``,
    ``__iter__`` and ``__len__`` — confirm against the Rust binding.
    """
82
+
83
+
84
@dataclass
class MonarchContext:
    """Context of the actor that is currently handling a message.

    ``_Actor.handle_cast`` stores one of these in a ``ContextVar`` before it
    dispatches to user code; ``get`` and the ``current_*`` helpers read it back.
    """

    # Mailbox of the actor handling the current message.
    mailbox: Mailbox
    # Proc id of that actor (taken from ``mailbox.actor_id.proc_id``).
    proc_id: str
    # Position (rank + shape) of the actor within the mesh the message targeted.
    point: Point

    @staticmethod
    def get() -> "MonarchContext":
        """Return the context bound in the current ``contextvars`` context.

        Raises:
            LookupError: when no context has been set, e.g. when called
                outside of actor message handling.
        """
        current = _context.get()
        return current
93
+
94
+
95
# Holds the MonarchContext of the message currently being handled.
_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
    "monarch.actor_mesh._context"
)

# Type parameters shared by the endpoint / port machinery below.
T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")
A = TypeVar("A")

# Private, fixed-seed RNG: "choose" load balancing stays reproducible from run
# to run while still spreading requests evenly across actors.
_load_balancing_seed = random.Random(4)

# Routing mode for a message: "all" reaches every actor in the mesh, "choose"
# reaches a single actor picked by the load-balancing RNG.
Selection = Literal["all", "choose"]  # TODO: replace with real selection objects
111
+
112
+
113
# Stand-in for whatever serializable Python object we end up using to name an
# actor mesh. It is hacked together today because ActorMesh is not plumbed
# through to non-client processes, so we keep the raw actor-id list instead.
class _ActorMeshRefImpl:
    """Addressable set of actors laid out according to a ``Shape``.

    Delivery always posts to individual actor ids through the mailbox, even
    when a ``PythonActorMesh`` handle is available (see ``cast``).
    """

    def __init__(
        self,
        mailbox: Mailbox,
        hy_actor_mesh: Optional[PythonActorMesh],
        shape: Shape,
        actor_ids: List[ActorId],
    ) -> None:
        self._mailbox = mailbox
        self._actor_mesh = hy_actor_mesh
        self._shape = shape
        # Indexed by the ranks that ``shape`` produces.
        self._please_replace_me_actor_ids = actor_ids

    @staticmethod
    def from_hyperactor_mesh(
        mailbox: Mailbox, hy_actor_mesh: PythonActorMesh
    ) -> "_ActorMeshRefImpl":
        """Build a reference backed by a live ``PythonActorMesh``."""
        actor_count = len(hy_actor_mesh.shape)
        return _ActorMeshRefImpl(
            mailbox,
            hy_actor_mesh,
            hy_actor_mesh.shape,
            [cast(ActorId, hy_actor_mesh.get(idx)) for idx in range(actor_count)],
        )

    @staticmethod
    def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl":
        """Wrap a single actor id using the zero-dimensional ``singleton_shape``."""
        return _ActorMeshRefImpl(
            mailbox=mailbox,
            hy_actor_mesh=None,
            shape=singleton_shape,
            actor_ids=[actor_id],
        )

    @staticmethod
    def from_actor_ref_with_shape(
        ref: "_ActorMeshRefImpl", shape: Shape
    ) -> "_ActorMeshRefImpl":
        """View ``ref``'s actors through another shape (e.g. a slice of it).

        The full actor-id list is shared with ``ref``; ``shape`` decides which
        ranks are addressed. The hyperactor mesh handle is not carried over.
        """
        return _ActorMeshRefImpl(
            mailbox=ref._mailbox,
            hy_actor_mesh=None,
            shape=shape,
            actor_ids=ref._please_replace_me_actor_ids,
        )

    def __getstate__(
        self,
    ) -> Tuple[Shape, List[ActorId], Mailbox]:
        # The live PythonActorMesh handle is deliberately left out.
        return (self._shape, self._please_replace_me_actor_ids, self._mailbox)

    def __setstate__(
        self,
        state: Tuple[Shape, List[ActorId], Mailbox],
    ) -> None:
        self._actor_mesh = None
        shape, actor_ids, mailbox = state
        self._shape = shape
        self._please_replace_me_actor_ids = actor_ids
        self._mailbox = mailbox

    def send(self, rank: int, message: PythonMessage) -> None:
        """Post ``message`` straight to the actor at ``rank``."""
        target = self._please_replace_me_actor_ids[rank]
        self._mailbox.post(target, message)

    def cast(
        self,
        message: PythonMessage,
        selection: Selection,
    ) -> None:
        """Deliver ``message`` to the actors picked by ``selection``.

        ``"all"`` posts to every rank of the shape, tagging each delivery with
        its position and a dense row-major copy of the shape. ``"choose"``
        posts to one rank picked by the load-balancing RNG.

        Raises:
            ValueError: if ``selection`` is neither ``"all"`` nor ``"choose"``.
        """
        # TODO: use the actual actor mesh when available. We cannot use it
        # directly yet because that would bifurcate the message delivery paths
        # from the same client: slicing the mesh produces a reference, which
        # calls actors directly, and actor meshes will multicast while direct
        # actor comms do not. Separately we need to decide whether actor meshes
        # are ordered with actor references.
        #
        # The fix is to provide a first-class reference into Python, and always
        # call "cast" on it, including for load balanced requests.
        if selection == "all":
            # replace me with actual remote actor mesh
            dense_shape = Shape(
                self._shape.labels, NDSlice.new_row_major(self._shape.ndslice.sizes)
            )
            for position, rank in enumerate(self._shape.ranks()):
                self._mailbox.post_cast(
                    self._please_replace_me_actor_ids[rank],
                    position,
                    dense_shape,
                    message,
                )
            return
        if selection == "choose":
            pick = _load_balancing_seed.randrange(len(self._shape))
            chosen_rank = self._shape.ndslice[pick]
            self._mailbox.post(self._please_replace_me_actor_ids[chosen_rank], message)
            return
        raise ValueError(f"invalid selection: {selection}")

    def __len__(self) -> int:
        """Number of actors addressed, i.e. the size of the shape."""
        return len(self._shape)
205
+
206
+
207
class Endpoint(Generic[P, R]):
    """Handle to one named ``@endpoint`` method on every actor of a mesh.

    ``ActorMeshRef`` creates one per endpoint. Each public method is an
    "adverb" that picks how the call is routed and how replies are collected.
    """

    def __init__(
        self,
        actor_mesh_ref: _ActorMeshRefImpl,
        name: str,
        impl: Callable[Concatenate[Any, P], Awaitable[R]],
        mailbox: Mailbox,
    ) -> None:
        self._actor_mesh = actor_mesh_ref
        self._name = name
        # Signature of the user's method (including ``self``); ``send`` binds
        # call arguments against it to fail fast before anything is sent.
        self._signature: inspect.Signature = inspect.signature(impl)
        self._mailbox = mailbox

    # The following are all 'adverbs': different ways to handle the return
    # values of this endpoint. Adverbs should only ever take the *args, **kwargs
    # of the original call. Syntax sugar that needs additional arguments should
    # be implemented as a function independent of the endpoint, like `send`
    # and `Accumulator`.
    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Load balanced sends a message to one chosen actor and awaits a result.

        Load balanced RPC-style entrypoint for request/response messaging.
        """
        p: Port[R]
        r: PortReceiver[R]
        p, r = port(self, once=True)
        # pyre-ignore
        send(self, args, kwargs, port=p, selection="choose")
        return r.recv()

    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """Call the endpoint on a mesh that holds exactly one actor.

        Raises:
            ValueError: if the mesh does not contain exactly one actor.
        """
        if len(self._actor_mesh) != 1:
            raise ValueError(
                f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}"
            )
        return self.choose(*args, **kwargs)

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
        """Broadcast the call to every actor and gather all of the results.

        The returned future resolves to a ``ValueMesh`` with one value per
        actor, placed at the rank each reply was tagged with. If an actor
        replies with an exception, it is raised from the future.
        """
        p: Port[R]
        r: RankedPortReceiver[R]
        p, r = ranked_port(self)
        # pyre-ignore
        send(self, args, kwargs, port=p)

        def to_value_mesh(results: List[R]) -> ValueMesh[R]:
            # Replies are tagged with dense positions 0..n-1 (see
            # _ActorMeshRefImpl.cast), so use a dense row-major shape.
            call_shape = Shape(
                self._actor_mesh._shape.labels,
                NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
            )
            return ValueMesh(call_shape, results)

        async def process() -> ValueMesh[R]:
            results: List[R] = [None] * len(self._actor_mesh)  # pyre-fixme[9]
            for _ in range(len(self._actor_mesh)):
                rank, value = await r.recv()
                results[rank] = value
            return to_value_mesh(results)

        def process_blocking() -> ValueMesh[R]:
            results: List[R] = [None] * len(self._actor_mesh)  # pyre-fixme[9]
            for _ in range(len(self._actor_mesh)):
                rank, value = r.recv().get()
                results[rank] = value
            return to_value_mesh(results)

        return Future(process, process_blocking)

    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
        """
        Broadcasts to all actors and yields their responses as a stream / generator.

        This enables processing results from multiple actors incrementally as
        they become available. Returns an async generator of response values.

        Note: this is an async generator, so the message is only sent once the
        generator is first advanced, not when ``stream()`` is called.
        """
        p, r = port(self)
        # pyre-ignore
        send(self, args, kwargs, port=p)
        for _ in range(len(self._actor_mesh)):
            yield await r.recv()

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """
        Fire-and-forget broadcast to all actors without waiting for actors to
        acknowledge receipt.

        In other words, the return of this method does not guarantee the
        delivery of the message.
        """
        # pyre-ignore
        send(self, args, kwargs)
299
+
300
+
301
class Accumulator(Generic[P, R, A]):
    """Fold every actor's response from an endpoint into one value.

    Starting from ``identity``, ``combine(acc, response)`` is applied to each
    response in the order ``Endpoint.stream`` yields them.
    """

    def __init__(
        self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
    ) -> None:
        self._endpoint: Endpoint[P, R] = endpoint
        self._identity: A = identity
        self._combine: Callable[[A, R], A] = combine

    def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
        """Stream the call to all actors and return a future of the folded result.

        The message itself is only sent once the returned future is awaited or
        resolved, because ``Endpoint.stream`` is an async generator.
        """
        responses: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs)

        async def fold_responses() -> A:
            running = self._identity
            async for response in responses:
                running = self._combine(running, response)
            return running

        return Future(fold_responses)
319
+
320
+
321
class ValueMesh(MeshTrait, Generic[R]):
    """
    Container of return values, indexed by rank.
    """

    def __init__(self, shape: Shape, values: List[R]) -> None:
        self._shape = shape
        # Indexed by the ranks of ``shape`` (shared, not copied, by slices).
        self._values = values

    def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
        # Slicing only swaps the shape; the backing value list is reused.
        return ValueMesh(shape, self._values)

    def item(self, **kwargs) -> R:
        """Return the value at the coordinates given as ``label=index`` kwargs.

        Every dimension label of the mesh must be given exactly once.

        Raises:
            KeyError: if a dimension is missing or an unknown one is supplied.
        """
        labels = list(self._labels)
        missing = [label for label in labels if label not in kwargs]
        if missing:
            # Report all missing dimensions at once instead of a bare
            # KeyError('<label>') from dict.pop.
            raise KeyError(f"item is missing dimensions: {missing}")
        coordinates = [kwargs.pop(label) for label in labels]
        if kwargs:
            raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}")

        return self._values[self._ndslice.nditem(coordinates)]

    def __iter__(self):
        """Yield ``(Point, value)`` for every rank in the shape."""
        for rank in self._shape.ranks():
            yield Point(rank, self._shape), self._values[rank]

    def __len__(self) -> int:
        return len(self._shape)

    def __repr__(self) -> str:
        return f"ValueMesh({self._shape})"

    @property
    def _ndslice(self) -> NDSlice:
        return self._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._shape.labels
357
+
358
+
359
def send(
    endpoint: Endpoint[P, R],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    port: "Optional[Port]" = None,
    selection: Selection = "all",
) -> None:
    """
    Low-level, non-blocking invocation of ``endpoint`` on its mesh.

    The arguments are first bound against the endpoint's signature (raising
    ``TypeError`` locally on a mismatch), then pickled and cast according to
    ``selection`` — every actor for ``"all"``, a single load-balanced actor for
    ``"choose"``. When ``port`` is given, replies are directed to it; either
    way this function never waits for a result.
    """
    # ``None`` stands in for the actor's ``self`` during validation.
    endpoint._signature.bind(None, *args, **kwargs)
    payload = _pickle((args, kwargs))
    reply_to = port._port_ref if port is not None else None
    outgoing = PythonMessage(endpoint._name, payload, reply_to, None)
    endpoint._actor_mesh.cast(outgoing, selection)
379
+
380
+
381
class EndpointProperty(Generic[P, R]):
    """Descriptor produced by ``@endpoint`` that wraps an actor method.

    ``ActorMeshRef`` looks for these on actor classes and exposes a real
    ``Endpoint`` for each; ``_Actor`` reads ``_method`` to call the function.
    """

    def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None:
        self._method = method

    def __get__(self, instance, owner) -> Endpoint[P, R]:
        # Deliberately NOT an Endpoint: handing back the descriptor itself (for
        # class and instance access alike) lets callers recognise an @endpoint
        # attribute and reach the wrapped method. The cast only satisfies type
        # checkers.
        return cast(Endpoint[P, R], self)
390
+
391
+
392
def endpoint(
    method: Callable[Concatenate[Any, P], Awaitable[R]],
) -> EndpointProperty[P, R]:
    """Decorator that marks an actor method as remotely callable."""
    wrapped: EndpointProperty[P, R] = EndpointProperty(method)
    return wrapped
396
+
397
+
398
class Port(Generic[R]):
    """Sending end of a reply channel, optionally stamped with a sender rank."""

    def __init__(
        self, port_ref: PortRef | OncePortRef, mailbox: Mailbox, rank: Optional[int]
    ) -> None:
        self._port_ref = port_ref
        self._mailbox = mailbox
        # Copied onto every message so ranked receivers can tell senders apart.
        self._rank = rank

    def send(self, method: str, obj: R) -> None:
        """Pickle ``obj`` and send it through the port tagged with ``method``.

        Within this module ``method`` is either "result" or "exception".
        """
        outgoing = PythonMessage(method, _pickle(obj), None, self._rank)
        self._port_ref.send(self._mailbox, outgoing)
411
+
412
+
413
# Advanced, lower-level API for sending messages. This is intentionally not
# part of the Endpoint API because the way it accepts arguments and handles
# concerns is different.
def port(
    endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["Port[R]", "PortReceiver[R]"]:
    """Open a reply port on ``endpoint``'s mailbox.

    Returns ``(Port, PortReceiver)``: the unranked ``Port`` is handed to
    ``send`` so actors reply to it, and the receiver reads those replies.
    ``once=True`` uses ``open_once_port`` (presumably a single-reply port, as
    used by ``Endpoint.choose``) instead of ``open_port``.
    """
    mailbox = endpoint._mailbox
    if once:
        handle, receiver = mailbox.open_once_port()
    else:
        handle, receiver = mailbox.open_port()
    port_ref: PortRef | OncePortRef = handle.bind()
    sender = Port(port_ref, mailbox, rank=None)
    return sender, PortReceiver(mailbox, receiver)
426
+
427
+
428
def ranked_port(
    endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
    """Like ``port``, but the receiver yields ``(rank, value)`` pairs."""
    sender, plain_receiver = port(endpoint, once)
    ranked_receiver = RankedPortReceiver[R](
        plain_receiver._mailbox, plain_receiver._receiver
    )
    return sender, ranked_receiver
433
+
434
+
435
class PortReceiver(Generic[R]):
    """Receiving end of a port; decodes replies produced by ``Port.send``."""

    def __init__(
        self,
        mailbox: Mailbox,
        receiver: HyPortReceiver | OncePortReceiver,
    ) -> None:
        self._mailbox: Mailbox = mailbox
        self._receiver: HyPortReceiver | OncePortReceiver = receiver

    async def _recv(self) -> R:
        return self._process(await self._receiver.recv())

    def _blocking_recv(self) -> R:
        return self._process(self._receiver.blocking_recv())

    def _process(self, msg: PythonMessage) -> R:
        """Unpickle ``msg``; return it for "result", raise it for "exception".

        Raises:
            The unpickled payload when ``msg.method == "exception"``.
            ValueError: if ``msg.method`` is neither "result" nor "exception".
        """
        # TODO: Try to do something more structured than a cast here
        payload = cast(R, _unpickle(msg.message, self._mailbox))
        if msg.method == "result":
            return payload
        if msg.method == "exception":
            # pyre-ignore
            raise payload
        # Explicit check instead of ``assert``: asserts are stripped under -O,
        # which would turn a protocol error into raising an arbitrary payload.
        raise ValueError(f"unexpected port message method: {msg.method!r}")

    def recv(self) -> "Future[R]":
        """Return a future for the next reply (await it, or call ``.get()``)."""
        return Future(lambda: self._recv(), self._blocking_recv)
462
+
463
+
464
class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
    """``PortReceiver`` that yields ``(sender rank, value)`` tuples."""

    def _process(self, msg: PythonMessage) -> Tuple[int, R]:
        sender_rank = msg.rank
        if sender_rank is None:
            raise ValueError("RankedPort receiver got a message without a rank")
        value = super()._process(msg)
        return sender_rank, value
469
+
470
+
471
# Zero-dimensional shape (no labels, empty slice at offset 0). Used to address
# a lone actor (``from_actor_id``) and for directly posted, non-cast messages.
singleton_shape = Shape([], NDSlice(sizes=[], strides=[], offset=0))
472
+
473
+
474
class _Actor:
    """
    This is the message handling implementation of a Python actor.

    The layering goes:
        Rust `PythonActor` -> `_Actor` -> user-provided `Actor` instance

    Messages are received from the Rust backend, and forwarded to the `handle`
    methods on this class.

    This class wraps the actual `Actor` instance provided by the user, and
    routes messages to it, managing argument serialization/deserialization and
    error handling.
    """

    def __init__(self) -> None:
        # The user's Actor instance; remains None until an "__init__" message
        # constructs it successfully.
        self.instance: object | None = None

    async def handle(
        self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
    ) -> None:
        """Handle a directly posted message as rank 0 of ``singleton_shape``."""
        return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag)

    async def handle_cast(
        self,
        mailbox: Mailbox,
        rank: int,
        shape: Shape,
        message: PythonMessage,
        panic_flag: PanicFlag,
    ) -> None:
        """Dispatch ``message`` to the user actor as ``rank`` within ``shape``.

        An "__init__" message constructs the user actor; any other method name
        is resolved as an ``@endpoint`` on it and run inside a telemetry span.
        The result, or an ``ActorError`` wrapping a failure, goes to the
        message's response port when there is one. Without a port, failures
        are raised so they reach the supervisor.
        """
        port = (
            Port(message.response_port, mailbox, rank)
            if message.response_port
            else None
        )
        try:
            ctx: MonarchContext = MonarchContext(
                mailbox, mailbox.actor_id.proc_id, Point(rank, shape)
            )
            _context.set(ctx)

            args, kwargs = _unpickle(message.message, mailbox)

            if message.method == "__init__":
                # The actor class travels as the first positional argument.
                Class, *args = args
                self.instance = Class(*args, **kwargs)
                return None

            if self.instance is None:
                raise AssertionError(
                    "__init__ failed earlier and no Actor object is available"
                )
            the_method = getattr(self.instance, message.method)._method

            if inspect.iscoroutinefunction(the_method):

                async def instrumented():
                    enter_span(
                        the_method.__module__,
                        message.method,
                        str(ctx.mailbox.actor_id),
                    )
                    try:
                        return await the_method(self.instance, *args, **kwargs)
                    except Exception:
                        logger.critical(
                            "Unhandled exception in actor endpoint",
                            exc_info=True,
                        )
                        raise
                    finally:
                        # Close the span on every exit path; previously a
                        # raising endpoint left its span open.
                        exit_span()

                result = await instrumented()
            else:
                enter_span(
                    the_method.__module__, message.method, str(ctx.mailbox.actor_id)
                )
                try:
                    result = the_method(self.instance, *args, **kwargs)
                finally:
                    exit_span()

            if port is not None:
                port.send("result", result)
        except Exception as e:
            traceback.print_exc()
            s = ActorError(e)

            # The exception is delivered to exactly one of:
            # (1) our caller, (2) our supervisor
            if port is not None:
                port.send("exception", s)
            else:
                raise s from None
        except BaseException as e:
            # A BaseException can be thrown in the case of a Rust panic.
            # In this case, we need a way to signal the panic to the Rust side.
            # See [Panics in async endpoints]
            try:
                panic_flag.signal_panic(e)
            except Exception:
                # The channel might be closed if the Rust side has already detected the error
                pass
            raise
578
+
579
+
580
def _is_mailbox(x: object) -> bool:
    """Predicate given to ``flatten``: selects ``Mailbox`` objects."""
    matches = isinstance(x, Mailbox)
    return matches
582
+
583
+
584
def _pickle(obj: object) -> bytes:
    """Serialize ``obj`` with mailboxes factored out by ``flatten``.

    The first element ``flatten`` returns (presumably the extracted mailboxes)
    is discarded; only the serialized bytes are kept.
    """
    _extracted, serialized = flatten(obj, _is_mailbox)
    return serialized
587
+
588
+
589
def _unpickle(data: bytes, mailbox: Mailbox) -> Any:
    """Deserialize ``data``, filling every mailbox slot with ``mailbox``.

    Whatever mailboxes the remote objects originally carried, they all become
    the local one.
    """
    local_mailboxes = itertools.repeat(mailbox)
    return unflatten(data, local_mailboxes)
593
+
594
+
595
class Actor(MeshTrait):
    """Base class for user-defined actors.

    Subclasses expose remotely callable methods with ``@endpoint``. ``MeshTrait``
    is inherited only to satisfy the typechecker: the mesh accessors below
    always raise ``NotImplementedError``.
    """

    @functools.cached_property
    def logger(self) -> logging.Logger:
        """Logger named after the concrete actor class, set to DEBUG level."""
        actor_logger = logging.getLogger(self.__class__.__name__)
        actor_logger.setLevel(logging.DEBUG)
        return actor_logger

    @property
    def _ndslice(self) -> NDSlice:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @property
    def _labels(self) -> Tuple[str, ...]:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @endpoint  # pyre-ignore
    def _set_debug_client(self, client: "DebugClient") -> None:
        """Route ``breakpoint()`` calls in this actor to ``client``.

        Installs ``remote_breakpointhook`` as ``sys.breakpointhook``, bound to
        this actor's rank, its coordinates and its actor id.
        """
        ctx = MonarchContext.get()
        point = ctx.point
        # For some reason, using a lambda instead of functools.partial
        # confuses the pdb wrapper implementation.
        sys.breakpointhook = functools.partial(  # pyre-ignore
            remote_breakpointhook,
            point.rank,
            point.shape.coordinates(point.rank),
            ctx.mailbox.actor_id,
            client,
        )
631
+
632
+
633
class ActorMeshRef(MeshTrait, Generic[T]):
    """Client-side handle to a mesh of actors of class ``Class``.

    Every ``@endpoint`` declared on ``Class`` is exposed as an ``Endpoint``
    attribute with the same name.
    """

    def __init__(
        self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
    ) -> None:
        self.__name__: str = Class.__name__
        self._class: Type[T] = Class
        self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref
        self._mailbox: Mailbox = mailbox
        # Eagerly materialize an Endpoint for every @endpoint on the class.
        for attr_name in dir(self._class):
            attr_value = getattr(self._class, attr_name, None)
            if isinstance(attr_value, EndpointProperty):
                setattr(
                    self,
                    attr_name,
                    Endpoint(
                        self._actor_mesh_ref,
                        attr_name,
                        attr_value._method,
                        self._mailbox,
                    ),
                )

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails. It lazily creates
        # (and caches) an Endpoint for any @endpoint on the class; everything
        # else raises AttributeError. Type checkers also read this as "any
        # attribute may be a dynamically added endpoint".
        #
        # Read ``_class`` from the instance dict, not via ``self._class``: on an
        # instance whose __init__ has not run, ``self._class`` would re-enter
        # __getattr__ and recurse until RecursionError.
        actor_class = self.__dict__.get("_class")
        if actor_class is not None:
            attr = getattr(actor_class, name, None)
            if isinstance(attr, EndpointProperty):
                ep = Endpoint(
                    self._actor_mesh_ref,
                    name,
                    attr._method,
                    self._mailbox,
                )
                # Cache it for future use
                setattr(self, name, ep)
                return ep

        # If we get here, it's truly not found
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def _create(
        self,
        args: Iterable[Any],
        kwargs: Dict[str, Any],
    ) -> None:
        """Send the "__init__" message that builds ``Class(*args, **kwargs)``.

        No reply port is attached, so constructor failures are not returned
        here; ``_Actor.handle_cast`` raises them towards the supervisor.
        """

        # Accepts anything: argument checking is left to the real constructor.
        async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
            return None

        ep = Endpoint(
            self._actor_mesh_ref,
            "__init__",
            null_func,
            self._mailbox,
        )
        send(ep, (self._class, *args), kwargs)

    def __reduce_ex__(
        self, protocol: ...
    ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]":
        # Re-run __init__ on unpickle so endpoint attributes are rebuilt.
        return ActorMeshRef, (
            self._class,
            self._actor_mesh_ref,
            self._mailbox,
        )

    @property
    def _ndslice(self) -> NDSlice:
        return self._actor_mesh_ref._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._actor_mesh_ref._shape.labels

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        return ActorMeshRef(
            self._class,
            _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape),
            self._mailbox,
        )

    def __repr__(self) -> str:
        return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})"
723
+
724
+
725
class ActorError(Exception):
    """
    Failure raised by user code running inside a remote actor.

    Meant for deterministic problems with the user's code — e.g. an OOM from
    trying to allocate too much GPU memory, or violating an invariant enforced
    by the various APIs. The frames of the original exception are captured
    when this is constructed, so they can be displayed on the caller's side.
    """

    def __init__(
        self,
        exception: Exception,
        message: str = "A remote actor call has failed asynchronously.",
    ) -> None:
        self.exception = exception
        # Extract the frames now, while ``exception.__traceback__`` is present.
        self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__)
        self.message = message

    def __str__(self) -> str:
        exc_text = str(self.exception)
        frames_text = "".join(traceback.format_list(self.actor_mesh_ref_frames))
        exc_name = type(self.exception).__name__
        return (
            self.message
            + "\n"
            + "Traceback of where the remote call failed (most recent call last):\n"
            + f"{frames_text}{exc_name}: {exc_text}"
        )
748
+
749
+
750
def current_actor_name() -> str:
    """Return the id, as a string, of the actor handling the current message."""
    actor_id = MonarchContext.get().mailbox.actor_id
    return str(actor_id)
752
+
753
+
754
def current_rank() -> Point:
    """Return the ``Point`` (rank within the targeted shape) of the current actor."""
    return MonarchContext.get().point
757
+
758
+
759
def current_size() -> Dict[str, int]:
    """Map each dimension label of the current actor's shape to its size."""
    shape = MonarchContext.get().point.shape
    return {label: size for label, size in zip(shape.labels, shape.ndslice.sizes)}