torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/actor_mesh.py ADDED
@@ -0,0 +1,692 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import collections
9
+ import contextvars
10
+ import inspect
11
+
12
+ import itertools
13
+ import logging
14
+ import random
15
+ import traceback
16
+
17
+ from dataclasses import dataclass
18
+ from traceback import extract_tb, StackSummary
19
+ from typing import (
20
+ Any,
21
+ AsyncGenerator,
22
+ Callable,
23
+ cast,
24
+ Concatenate,
25
+ Coroutine,
26
+ Dict,
27
+ Generator,
28
+ Generic,
29
+ Iterable,
30
+ List,
31
+ Literal,
32
+ Optional,
33
+ ParamSpec,
34
+ Tuple,
35
+ Type,
36
+ TypeVar,
37
+ )
38
+
39
+ import monarch
40
+ from monarch import ActorFuture as Future
41
+
42
+ from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
43
+ from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
44
+ from monarch._rust_bindings.monarch_hyperactor.mailbox import (
45
+ Mailbox,
46
+ OncePortReceiver,
47
+ PortId,
48
+ PortReceiver as HyPortReceiver,
49
+ )
50
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
51
+ from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
52
+ from monarch.common.pickle_flatten import flatten, unflatten
53
+ from monarch.common.shape import MeshTrait, NDSlice, Shape
54
+
55
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Either of the two allocator types exposed by the top-level ``monarch`` package.
Allocator = monarch.ProcessAllocator | monarch.LocalAllocator

# ``IN_PAR`` is True exactly when ``__manifest__.fbmake`` can be imported.
# NOTE(review): presumably this signals a packaged (PAR) build -- confirm.
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = True
65
+
66
+ T1 = TypeVar("T1")
67
+ T2 = TypeVar("T2")
68
+
69
+
70
class Point(HyPoint, collections.abc.Mapping):
    """``HyPoint`` that also derives from ``collections.abc.Mapping``.

    No members are added here.
    NOTE(review): the abstract ``Mapping`` methods are presumably supplied by
    ``HyPoint`` -- confirm against the binding.
    """
72
+
73
+
74
@dataclass
class MonarchContext:
    """Per-call context made available to actor code through ``_context``.

    An instance is built and stored in the ``_context`` context variable by
    ``_Actor.handle_cast`` before the requested method is invoked.
    """

    # Mailbox of the actor handling the current message.
    mailbox: Mailbox
    # ``mailbox.actor_id.proc_id`` as captured when the context was created.
    proc_id: str
    # Position of the handling actor: ``Point(rank, shape)`` of the cast.
    point: Point

    @staticmethod
    def get() -> "MonarchContext":
        """Return the context stored in ``_context``.

        ``ContextVar.get`` raises ``LookupError`` when nothing has been set,
        because the variable is created without a default.
        """
        return _context.get()


# Holds the ``MonarchContext`` of the message currently being handled.
_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
    "monarch.service._context"
)
88
+
89
+
90
# Eager task start exists in Python 3.12 as a task argument; this is a
# backport of the idea for 3.10/3.11.
def create_eager_task(coro: Coroutine[Any, None, Any]) -> asyncio.Future:
    """Start ``coro`` immediately and return a future for its result.

    The coroutine is stepped once synchronously. If it finishes without
    suspending, an already-resolved ``asyncio.Future`` carrying its return
    value is returned. Otherwise the remainder is wrapped in
    ``RestOfCoroutine`` and scheduled with ``asyncio.create_task``.

    An exception raised by ``coro`` during that first step is not captured in
    the future; it propagates to the caller of this function.
    """
    stepper = coro.__await__()
    try:
        first_yield = next(stepper)
    except StopIteration as finished:
        # Completed on the very first step: no task is needed.
        resolved = asyncio.Future()
        resolved.set_result(finished.value)
        return resolved
    return asyncio.create_task(RestOfCoroutine(first_yield, stepper).run())
101
+
102
+
103
class RestOfCoroutine(Generic[T1, T2]):
    """Awaitable that resumes a coroutine which has already been stepped once.

    ``first_yield`` is the value the coroutine produced on its first step and
    ``iter`` is its partially consumed ``__await__`` iterator. Awaiting this
    object re-emits ``first_yield`` and then relays every further value of
    ``iter`` until it is exhausted.
    """

    def __init__(self, first_yield: T1, iter: Generator[T2, None, T2]) -> None:
        self.first_yield: T1 | None = first_yield
        self.iter: Generator[T2, None, T2] = iter

    def __await__(self) -> Generator[T1, None, T1] | Generator[T2, None, T2]:
        pending = self.first_yield
        assert pending is not None
        # Replay the value that was consumed by the eager first step.
        yield pending
        self.first_yield = None
        # Relay the rest. Values are pulled with ``next`` only, so anything
        # sent or thrown into this generator is not forwarded to ``self.iter``.
        while True:
            try:
                yield next(self.iter)
            except StopIteration as finished:
                return finished.value

    async def run(self) -> T1 | T2:
        """Coroutine form of ``await self``, suitable for ``create_task``."""
        return await self
121
+
122
+
123
+ T = TypeVar("T")
124
+ P = ParamSpec("P")
125
+ R = TypeVar("R")
126
+ A = TypeVar("A")
127
+
128
# Random generator used for "choose" selection. A fixed seed keeps the
# load balancing deterministic while still spreading calls evenly.
_load_balancing_seed = random.Random(4)


# Ways a message can be routed to a mesh: to every actor or to one chosen actor.
Selection = Literal["all", "choose"]  # TODO: replace with real selection objects
134
+
135
+
136
# Stand-in for whatever serializable Python object will eventually name an
# actor mesh. Hacked up today because ActorMesh isn't plumbed to non-clients.
class _ActorMeshRefImpl:
    """Picklable handle on a set of actors addressed through a ``Shape``.

    Holds the mailbox used for posting, an optional ``PythonActorMesh``
    (dropped on pickling), the shape selecting ranks, and the list of actor
    ids that those ranks index into.
    """

    def __init__(
        self,
        mailbox: Mailbox,
        hy_actor_mesh: Optional[PythonActorMesh],
        shape: Shape,
        actor_ids: List[ActorId],
    ) -> None:
        self._mailbox = mailbox
        self._actor_mesh = hy_actor_mesh
        self._shape = shape
        self._please_replace_me_actor_ids = actor_ids

    @staticmethod
    def from_hyperactor_mesh(
        mailbox: Mailbox, hy_actor_mesh: PythonActorMesh
    ) -> "_ActorMeshRefImpl":
        """Build a reference covering every actor of ``hy_actor_mesh``."""
        shape: Shape = hy_actor_mesh.shape
        actor_ids = [
            cast(ActorId, hy_actor_mesh.get(index))
            for index in range(len(shape.ndslice))
        ]
        return _ActorMeshRefImpl(mailbox, hy_actor_mesh, shape, actor_ids)

    @staticmethod
    def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl":
        """Build a reference to a single actor using ``singleton_shape``."""
        return _ActorMeshRefImpl(mailbox, None, singleton_shape, [actor_id])

    @staticmethod
    def from_actor_ref_with_shape(
        ref: "_ActorMeshRefImpl", shape: Shape
    ) -> "_ActorMeshRefImpl":
        """Return a reference sharing ``ref``'s mailbox and actor ids under ``shape``."""
        return _ActorMeshRefImpl(
            ref._mailbox, None, shape, ref._please_replace_me_actor_ids
        )

    def __getstate__(
        self,
    ) -> Tuple[Shape, List[ActorId], Mailbox]:
        # The PythonActorMesh handle is intentionally left out of the state.
        return self._shape, self._please_replace_me_actor_ids, self._mailbox

    def __setstate__(
        self,
        state: Tuple[Shape, List[ActorId], Mailbox],
    ) -> None:
        self._actor_mesh = None
        self._shape, self._please_replace_me_actor_ids, self._mailbox = state

    def send(self, rank: int, message: PythonMessage) -> None:
        """Post ``message`` to the actor stored at index ``rank``."""
        target = self._please_replace_me_actor_ids[rank]
        self._mailbox.post(target, message)

    def cast(
        self,
        message: PythonMessage,
        selection: Selection,
    ) -> None:
        """Deliver ``message`` to one chosen actor or to all selected actors.

        Raises:
            ValueError: if ``selection`` is neither "choose" nor "all".
        """
        # TODO: use the actual actor mesh when available. We cannot currently use it
        # directly because we risk bifurcating the message delivery paths from the same
        # client, since slicing the mesh will produce a reference, which calls actors
        # directly. The reason these paths are bifurcated is that actor meshes will
        # use multicasting, while direct actor comms do not. Separately we need to decide
        # whether actor meshes are ordered with actor references.
        #
        # The fix is to provide a first-class reference into Python, and always call "cast"
        # on it, including for load balanced requests.
        if selection == "choose":
            ndslice = self._shape.ndslice
            position = _load_balancing_seed.randrange(len(ndslice))
            chosen_rank = ndslice[position]
            self._mailbox.post(self._please_replace_me_actor_ids[chosen_rank], message)
            return

        if selection != "all":
            raise ValueError(f"invalid selection: {selection}")

        # replace me with actual remote actor mesh
        # Receivers see a dense row-major shape with the same labels and sizes,
        # and their index within it, rather than the sliced shape held here.
        call_shape = Shape(
            self._shape.labels, NDSlice.new_row_major(self._shape.ndslice.sizes)
        )
        for position, rank in enumerate(self._shape.ranks()):
            self._mailbox.post_cast(
                self._please_replace_me_actor_ids[rank],
                position,
                call_shape,
                message,
            )

    @property
    def len(self) -> int:
        """Number of actors selected by the shape."""
        return len(self._shape.ndslice)
229
+
230
+
231
class Endpoint(Generic[P, R]):
    """Client-side handle for invoking one named method on an actor mesh."""

    def __init__(
        self,
        actor_mesh_ref: _ActorMeshRefImpl,
        name: str,
        impl: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]],
        mailbox: Mailbox,
    ) -> None:
        self._actor_mesh = actor_mesh_ref
        self._name = name
        # Kept so that ``send`` can check call arguments before posting.
        self._signature: inspect.Signature = inspect.signature(impl)
        self._mailbox = mailbox

    # The following are all 'adverbs': different ways to handle the return
    # values of this endpoint. Adverbs should only ever take the *args, **kwargs
    # of the original call. Syntax sugar that needs additional arguments should
    # be implemented as a function independent of the endpoint, like `send`
    # and `Accumulator`.
    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Load balanced sends a message to one chosen actor and awaits a result.

        Load balanced RPC-style entrypoint for request/response messaging.
        """
        reply_port, receiver = port(self, once=True)
        # pyre-ignore
        send(self, args, kwargs, port=reply_port, selection="choose")
        return receiver.recv()

    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """Call the endpoint on a mesh that contains exactly one actor.

        Raises:
            ValueError: if the referenced mesh does not have length 1.
        """
        if self._actor_mesh.len == 1:
            return self.choose(*args, **kwargs)
        raise ValueError(
            f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}"
        )

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
        """Call the endpoint on every actor and gather the replies by rank.

        The future resolves to a ``ValueMesh`` over a row-major shape that has
        the labels and sizes of the referenced mesh.
        """
        ranked_port, receiver = port(self, kind=RankedPort)
        # pyre-ignore
        send(self, args, kwargs, port=ranked_port)

        async def process():
            expected = self._actor_mesh.len
            results = [None] * expected
            # Replies arrive in any order; each carries the sender's rank.
            for _ in range(expected):
                rank, value = await receiver.recv()
                results[rank] = value
            mesh_shape = self._actor_mesh._shape
            call_shape = Shape(
                mesh_shape.labels,
                NDSlice.new_row_major(mesh_shape.ndslice.sizes),
            )
            return ValueMesh(call_shape, results)

        return Future(process)

    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
        """
        Broadcasts to all actors and yields their responses as a stream / generator.

        This enables processing results from multiple actors incrementally as
        they become available. Returns an async generator of response values.
        """
        reply_port, receiver = port(self)
        # pyre-ignore
        send(self, args, kwargs, port=reply_port)
        for _ in range(self._actor_mesh.len):
            yield await receiver.recv()

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """
        Send the call to all actors without a reply port.

        The message is posted via `send` with its default "all" selection and
        no port, so no response is awaited and no result is returned.
        """
        # pyre-ignore
        send(self, args, kwargs)
308
+
309
+
310
class Accumulator(Generic[P, R, A]):
    """Folds the streamed responses of an endpoint into a single value."""

    def __init__(
        self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
    ):
        self._endpoint = endpoint
        # Starting value of the fold.
        self._identity = identity
        # Called as ``combine(running_value, response)`` for every response.
        self._combine = combine

    def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
        """Stream the endpoint with the given arguments and reduce the replies."""
        responses = self._endpoint.stream(*args, **kwargs)

        async def impl():
            running = self._identity
            async for response in responses:
                running = self._combine(running, response)
            return running

        return Future(impl)
328
+
329
+
330
class ValueMesh(MeshTrait, Generic[R]):
    """A list of values addressed through a ``Shape``."""

    def __init__(self, shape: Shape, values: List[R]) -> None:
        self._shape = shape
        self._values = values

    def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
        # The same value list is shared; only the addressing shape changes.
        return ValueMesh(shape, self._values)

    def item(self, **kwargs):
        """Return the value at the coordinate given as ``label=index`` keywords.

        Raises:
            KeyError: if a label of this mesh is missing from ``kwargs`` or if
                ``kwargs`` contains names that are not labels of this mesh.
        """
        coordinates = []
        for label in self._labels:
            coordinates.append(kwargs.pop(label))
        if kwargs:
            raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}")

        return self._values[self._ndslice.nditem(coordinates)]

    def __iter__(self):
        """Yield ``(Point, value)`` pairs for every rank of the shape."""
        shape = self._shape
        for rank in shape.ranks():
            yield Point(rank, shape), self._values[rank]

    @property
    def _ndslice(self) -> NDSlice:
        return self._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._shape.labels
356
+
357
+
358
def send(
    endpoint: Endpoint[P, R],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    port: "Optional[Port]" = None,
    selection: Selection = "all",
) -> None:
    """
    Fire-and-forget invocation of the endpoint on the actors picked by `selection`.

    The call is posted to all actors ("all", the default) or to one chosen
    actor ("choose"). This function does not wait for any result; replies, if
    any, are delivered to `port`.
    """
    # Fail early on a bad call: ``bind`` raises TypeError when the arguments
    # do not match the endpoint's signature. ``None`` stands in for ``self``.
    endpoint._signature.bind(None, *args, **kwargs)
    payload = _pickle((args, kwargs, port))
    endpoint._actor_mesh.cast(PythonMessage(endpoint._name, payload), selection)
373
+
374
+
375
class EndpointProperty(Generic[P, R]):
    """Marker descriptor wrapping a method that was declared as an endpoint.

    The wrapped function stays available as ``_method``.
    """

    def __init__(self, method: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]]):
        self._method = method

    def __get__(self, instance, owner) -> Endpoint[P, R]:
        # The annotated return type is a lie for the type checker's benefit:
        # attribute access hands back this marker itself, so callers can
        # recognize the endpoint declaration and look up ``_method``.
        return cast(Endpoint[P, R], self)
384
+
385
+
386
def endpoint(
    method: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]],
) -> EndpointProperty[P, R]:
    """Decorator that marks ``method`` as an endpoint by wrapping it in ``EndpointProperty``."""
    wrapped = EndpointProperty(method)
    return wrapped
390
+
391
+
392
class Port:
    """Sending half of a reply channel: a ``PortId`` plus the mailbox used to post to it."""

    def __init__(self, port: PortId, mailbox: Mailbox) -> None:
        self._port = port
        self._mailbox = mailbox

    def send(self, method: str, obj: object) -> None:
        """Pickle ``obj`` and post it to the port, tagged with ``method``.

        Within this file ``method`` is either "result" or "exception".
        """
        reply = PythonMessage(method, _pickle(obj))
        self._mailbox.post(self._port, reply)
402
+
403
+
404
# Advanced lower-level API for sending messages. This is intentionally
# not part of the Endpoint API because the way it accepts arguments
# and handles concerns is different.
def port(
    endpoint: Endpoint[P, R], once=False, kind=Port
) -> Tuple["Port", "PortReceiver[R]"]:
    """Open a reply port on the endpoint's mailbox.

    Args:
        endpoint: endpoint whose mailbox hosts the port.
        once: open the port with ``open_once_port`` instead of ``open_port``.
        kind: class used to build the sending half (``Port`` or a subclass).

    Returns:
        The sending half, and a ``PortReceiver`` for reading the replies.
    """
    mailbox = endpoint._mailbox
    if once:
        handle, receiver = mailbox.open_once_port()
    else:
        handle, receiver = mailbox.open_port()
    port_id: PortId = handle.bind()
    return kind(port_id, mailbox), PortReceiver(mailbox, receiver)
415
+
416
+
417
class PortReceiver(Generic[R]):
    """Receiving half of a reply channel; decodes replies and re-raises remote errors."""

    def __init__(
        self,
        mailbox: Mailbox,
        receiver: HyPortReceiver | OncePortReceiver,
    ):
        self._mailbox = mailbox
        self._receiver = receiver

    async def _recv(self) -> R:
        raw = await self._receiver.recv()
        return self._process(raw)

    def _blocking_recv(self) -> R:
        raw = self._receiver.blocking_recv()
        return self._process(raw)

    def _process(self, msg: PythonMessage):
        """Unpickle ``msg``; return a "result" payload or raise an "exception" one."""
        # TODO: Try to do something more structured than a cast here
        payload = cast(R, _unpickle(msg.message, self._mailbox))
        if msg.method == "result":
            return payload

        assert msg.method == "exception"
        if isinstance(payload, tuple):
            # If we're receiving on a RankedPort, raise the exception and ignore the rank.
            # pyre-ignore do something more structured here
            raise payload[1]
        # pyre-ignore
        raise payload

    def recv(self) -> "Future[R]":
        """Return a ``Future`` built from the async and the blocking receive paths."""
        return Future(lambda: self._recv(), self._blocking_recv)
449
+
450
+
451
# Shape with no labels and no dimensions; used for single-actor references
# and for messages delivered through ``_Actor.handle``.
singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[]))
452
+
453
+
454
class RankedPort(Port):
    """``Port`` that tags every payload with the rank of the sending actor."""

    def send(self, method: str, obj: object) -> None:
        # The rank comes from the context of the message being handled, so
        # this must run while a ``MonarchContext`` is set.
        sender_rank = MonarchContext.get().point.rank
        super().send(method, (sender_rank, obj))
457
+
458
+
459
class _Actor:
    """Receiving-side dispatcher that turns ``PythonMessage``s into method calls.

    ``handle`` / ``handle_cast`` unpickle the call, store a ``MonarchContext``
    in ``_context`` and then either construct the user instance
    (``"__init__"``), run a synchronous method in place, or return a coroutine
    that schedules an async method on the running event loop.
    """

    def __init__(self) -> None:
        # User object created by the "__init__" message; None until then.
        self.instance: object | None = None
        # In-flight async requests, in submission order.
        self.active_requests: asyncio.Queue[asyncio.Future[object]] = asyncio.Queue()
        # Background task that drains ``active_requests``; started lazily by
        # ``run_async`` and kept here so that only one runs at a time and so
        # the event loop's weak reference is not the only one keeping it alive.
        self.complete_task: asyncio.Task[None] | None = None

    def handle(
        self, mailbox: Mailbox, message: PythonMessage
    ) -> Optional[Coroutine[Any, Any, Any]]:
        """Handle a directly addressed message as rank 0 of ``singleton_shape``."""
        return self.handle_cast(mailbox, 0, singleton_shape, message)

    def handle_cast(
        self,
        mailbox: Mailbox,
        rank: int,
        shape: Shape,
        message: PythonMessage,
    ) -> Optional[Coroutine[Any, Any, Any]]:
        """Dispatch ``message`` to the wrapped instance.

        Returns:
            ``None`` when the call was fully handled synchronously, otherwise
            the (not yet awaited) ``run_async`` coroutine for an async method.

        Raises:
            ActorMeshRefCallFailedException: if the call fails and the message
                carried no reply port to report the failure to.
        """
        port = None
        try:
            args, kwargs, port = _unpickle(message.message, mailbox)

            ctx = MonarchContext(mailbox, mailbox.actor_id.proc_id, Point(rank, shape))
            _context.set(ctx)

            if message.method == "__init__":
                # The class to instantiate travels as the first positional arg.
                Class, *args = args
                self.instance = Class(*args, **kwargs)
                return None
            else:
                # Attribute access yields the EndpointProperty marker; the
                # real function is its ``_method``.
                the_method = getattr(self.instance, message.method)._method
                result = the_method(self.instance, *args, **kwargs)
                if not inspect.iscoroutinefunction(the_method):
                    if port is not None:
                        port.send("result", result)
                    return None

                return self.run_async(ctx, self.run_task(port, result))
        except Exception as e:
            traceback.print_exc()
            s = ActorMeshRefCallFailedException(e)

            # The exception is delivered to exactly one of:
            # (1) our caller, (2) our supervisor
            if port is not None:
                port.send("exception", s)
            else:
                raise s from None

    async def run_async(self, ctx, coroutine):
        """Start ``coroutine`` eagerly and queue it for the completion task."""
        _context.set(ctx)
        # Start the draining task once and remember it. Previously the task
        # was never stored, so this guard was always true and every request
        # spawned another unreferenced consumer that never exits on its own.
        # A finished task (it only ends by failing or being cancelled) is
        # replaced so that queued requests keep being awaited.
        if self.complete_task is None or self.complete_task.done():
            self.complete_task = asyncio.create_task(self._complete())
        await self.active_requests.put(create_eager_task(coroutine))

    async def run_task(self, port, coroutine):
        """Await ``coroutine`` and report its outcome through ``port`` if given."""
        try:
            result = await coroutine
            if port is not None:
                port.send("result", result)
        except Exception as e:
            traceback.print_exc()
            s = ActorMeshRefCallFailedException(e)

            # The exception is delivered to exactly one of:
            # (1) our caller, (2) our supervisor
            if port is not None:
                port.send("exception", s)
            else:
                raise s from None

    async def _complete(self) -> None:
        """Await queued requests one after another, forever.

        An exception raised by an awaited request ends this task.
        """
        while True:
            task = await self.active_requests.get()
            await task
534
+
535
+
536
def _is_mailbox(x: object) -> bool:
    """Predicate handed to ``flatten`` by ``_pickle``: True for ``Mailbox`` instances."""
    return isinstance(x, Mailbox)
538
+
539
+
540
def _pickle(obj: object) -> bytes:
    """Serialize ``obj`` with ``flatten`` and return only the byte payload.

    The first element of ``flatten``'s result is discarded.
    NOTE(review): presumably those are the mailboxes matched by
    ``_is_mailbox`` -- confirm against ``pickle_flatten``.
    """
    flattened = flatten(obj, _is_mailbox)
    return flattened[1]
543
+
544
+
545
def _unpickle(data: bytes, mailbox: Mailbox) -> Any:
    """Inverse of ``_pickle``, substituting ``mailbox`` for every mailbox."""
    # Whatever mailboxes the remote objects had, they all become the local one.
    local_mailboxes = itertools.repeat(mailbox)
    return unflatten(data, local_mailboxes)
549
+
550
+
551
class Actor(MeshTrait):
    """Base class for user actor implementations.

    It derives from ``MeshTrait`` only for the type checker's sake. An actor
    implementation is not a mesh, so each mesh hook below raises
    ``NotImplementedError``.
    """

    @property
    def _ndslice(self) -> NDSlice:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @property
    def _labels(self) -> Tuple[str, ...]:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )
568
+
569
+
570
class ActorMeshRef(MeshTrait):
    """Client-side proxy for a mesh of actors of one class.

    Every attribute of the actor class that is an ``EndpointProperty`` is
    exposed on the proxy as an ``Endpoint`` bound to the referenced mesh.
    """

    def __init__(
        self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
    ) -> None:
        self.__name__ = Class.__name__
        self._class = Class
        self._actor_mesh_ref = actor_mesh_ref
        self._mailbox = mailbox
        # Eagerly materialize an Endpoint for each declared endpoint.
        for attr_name in dir(self._class):
            attr_value = getattr(self._class, attr_name, None)
            if not isinstance(attr_value, EndpointProperty):
                continue
            setattr(
                self,
                attr_name,
                Endpoint(
                    self._actor_mesh_ref,
                    attr_name,
                    attr_value._method,
                    self._mailbox,
                ),
            )

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails. It exists mostly so
        # that the type checker accepts any attribute as a possible endpoint
        # added at runtime; truly missing attributes still raise AttributeError.
        candidate = getattr(self._class, name, None)
        if isinstance(candidate, EndpointProperty):
            # Build the endpoint on demand and cache it on the instance.
            created = Endpoint(
                self._actor_mesh_ref,
                name,
                candidate._method,
                self._mailbox,
            )
            setattr(self, name, created)
            return created

        # If we get here, it's truly not found
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def _create(self, args: Iterable[Any], kwargs: Dict[str, Any]) -> None:
        """Send the "__init__" message that constructs the remote instances.

        The actor class is sent as the first positional argument, followed by
        ``args``; no reply port is attached.
        """

        # Placeholder implementation: it only gives the Endpoint a signature
        # that accepts any arguments.
        async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
            return None

        init_endpoint = Endpoint(
            self._actor_mesh_ref,
            "__init__",
            null_func,
            self._mailbox,
        )
        # pyre-ignore
        send(init_endpoint, (self._class, *args), kwargs)

    def __reduce_ex__(
        self, protocol: ...
    ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]":
        # Rebuild through the constructor so the endpoints are re-created.
        constructor_args = (
            self._class,
            self._actor_mesh_ref,
            self._mailbox,
        )
        return ActorMeshRef, constructor_args

    @property
    def _ndslice(self) -> NDSlice:
        return self._actor_mesh_ref._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._actor_mesh_ref._shape.labels

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        resliced = _ActorMeshRefImpl.from_actor_ref_with_shape(
            self._actor_mesh_ref, shape
        )
        return ActorMeshRef(self._class, resliced, self._mailbox)
654
+
655
+
656
class ActorMeshRefCallFailedException(Exception):
    """
    Deterministic problem with the user's code.
    For example, an OOM resulting in trying to allocate too much GPU memory, or violating
    some invariant enforced by the various APIs.

    Wraps the original exception together with the stack frames extracted
    from its traceback at construction time.
    """

    def __init__(
        self,
        exception: Exception,
        message: str = "A remote service call has failed asynchronously.",
    ) -> None:
        self.exception = exception
        # Snapshot of the failing call's frames; empty if there is no traceback.
        self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__)
        self.message = message

    def __str__(self) -> str:
        """Render the message, the captured frames and the original error line."""
        frame_lines = traceback.format_list(self.actor_mesh_ref_frames)
        cause = str(self.exception)
        pieces = [
            f"{self.message}\n",
            "Traceback of where the service call failed (most recent call last):\n",
            *frame_lines,
            f"{type(self.exception).__name__}: {cause}",
        ]
        return "".join(pieces)
679
+
680
+
681
def current_actor_name() -> str:
    """Return the string form of the actor id owning the current context's mailbox."""
    mailbox = MonarchContext.get().mailbox
    return str(mailbox.actor_id)
683
+
684
+
685
def current_rank() -> Point:
    """Return the ``Point`` stored in the current ``MonarchContext``."""
    return MonarchContext.get().point
688
+
689
+
690
def current_size() -> Dict[str, int]:
    """Map each dimension label of the current point's shape to its size."""
    shape = MonarchContext.get().point.shape
    return {label: size for label, size in zip(shape.labels, shape.ndslice.sizes)}