torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/actor_mesh.py
ADDED
@@ -0,0 +1,692 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import collections
|
9
|
+
import contextvars
|
10
|
+
import inspect
|
11
|
+
|
12
|
+
import itertools
|
13
|
+
import logging
|
14
|
+
import random
|
15
|
+
import traceback
|
16
|
+
|
17
|
+
from dataclasses import dataclass
|
18
|
+
from traceback import extract_tb, StackSummary
|
19
|
+
from typing import (
|
20
|
+
Any,
|
21
|
+
AsyncGenerator,
|
22
|
+
Callable,
|
23
|
+
cast,
|
24
|
+
Concatenate,
|
25
|
+
Coroutine,
|
26
|
+
Dict,
|
27
|
+
Generator,
|
28
|
+
Generic,
|
29
|
+
Iterable,
|
30
|
+
List,
|
31
|
+
Literal,
|
32
|
+
Optional,
|
33
|
+
ParamSpec,
|
34
|
+
Tuple,
|
35
|
+
Type,
|
36
|
+
TypeVar,
|
37
|
+
)
|
38
|
+
|
39
|
+
import monarch
|
40
|
+
from monarch import ActorFuture as Future
|
41
|
+
|
42
|
+
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
|
43
|
+
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
|
44
|
+
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
|
45
|
+
Mailbox,
|
46
|
+
OncePortReceiver,
|
47
|
+
PortId,
|
48
|
+
PortReceiver as HyPortReceiver,
|
49
|
+
)
|
50
|
+
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
|
51
|
+
from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
|
52
|
+
from monarch.common.pickle_flatten import flatten, unflatten
|
53
|
+
from monarch.common.shape import MeshTrait, NDSlice, Shape
|
54
|
+
|
55
|
+
logger = logging.getLogger(__name__)

# Union of the allocator types exported by the top-level ``monarch`` package.
Allocator = monarch.ProcessAllocator | monarch.LocalAllocator

# IN_PAR records whether the optional ``__manifest__`` module is importable
# (presumably only the case inside PAR-packaged builds -- TODO confirm).
try:
    from __manifest__ import fbmake  # noqa
except ImportError:
    IN_PAR = False
else:
    IN_PAR = True

# Type variables parameterizing RestOfCoroutine.
T1 = TypeVar("T1")
T2 = TypeVar("T2")
|
68
|
+
|
69
|
+
|
70
|
+
class Point(HyPoint, collections.abc.Mapping):
    """A rank within a mesh shape, also usable as a read-only mapping.

    All state and core behavior come from the Rust ``HyPoint`` base (built as
    ``Point(rank, shape)`` and read via ``.rank`` / ``.shape`` elsewhere in
    this module). Mixing in ``collections.abc.Mapping`` registers the type as
    a mapping and supplies the derived mapping methods (``keys``, ``items``,
    ``get``, ...). These presumably rest on ``__getitem__``/``__iter__``/
    ``__len__`` implemented by ``HyPoint`` -- TODO confirm.
    """
|
72
|
+
|
73
|
+
|
74
|
+
@dataclass
class MonarchContext:
    """Execution context of the actor currently handling a message.

    ``_Actor.handle_cast`` builds one for every incoming message and installs
    it in the ``_context`` context variable; ``MonarchContext.get()`` and the
    ``current_*`` helpers read it back.
    """

    mailbox: Mailbox  # mailbox the message was delivered through
    proc_id: str  # proc id taken from that mailbox's actor id
    point: Point  # rank of this actor within the shape of the call

    @staticmethod
    def get() -> "MonarchContext":
        """Return the active context (``ContextVar.get`` raises LookupError if unset)."""
        active = _context.get()
        return active


# Context variable holding the MonarchContext of the message being handled.
# Tasks created while it is set inherit a copy of the current value.
_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
    "monarch.service._context"
)
|
88
|
+
|
89
|
+
|
90
|
+
# Python 3.12 added eager execution as an option of task creation, but we
# still need it on 3.10/3.11, so it is backported here.
def create_eager_task(coro: Coroutine[Any, None, Any]) -> asyncio.Future:
    """Run ``coro`` synchronously until it first suspends, then continue it in a task.

    If the coroutine finishes without ever suspending, no task is created.
    Instead a future already resolved with its return value is returned.
    Otherwise the value produced by the first step is handed to
    ``RestOfCoroutine``, and the rest of the coroutine runs in a task created
    with ``asyncio.create_task``, which requires a running event loop.

    Raises:
        Exception: anything ``coro`` raises before its first suspension
            propagates directly to the caller instead of being stored in the
            returned future.
    """
    steps = coro.__await__()
    try:
        first_yield = next(steps)
    except StopIteration as e:
        done: asyncio.Future = asyncio.Future()
        done.set_result(e.value)
        return done
    return asyncio.create_task(RestOfCoroutine(first_yield, steps).run())


class RestOfCoroutine(Generic[T1, T2]):
    """Awaitable that resumes a coroutine which has already been stepped once.

    ``first_yield`` is what the coroutine produced on that first step. It is
    re-yielded to whatever awaits this object (normally an asyncio task), so
    the task blocks on it just as if it had run the first step itself.

    After that, values sent and exceptions thrown in by the task are forwarded
    to ``iter``, mirroring ``yield from`` delegation. This includes
    ``CancelledError`` raised when the task is cancelled. Closing this awaitable
    also closes ``iter``.
    """

    def __init__(self, first_yield: T1, iter: Generator[T2, None, T2]) -> None:
        self.first_yield: T1 | None = first_yield
        self.iter: Generator[T2, None, T2] = iter
        # Guards against awaiting twice. The check cannot be based on
        # `first_yield`, which may legitimately be None: a bare `yield` (as
        # used by `asyncio.sleep(0)`) produces None.
        self._awaited = False

    def __await__(self) -> Generator[Any, Any, T2]:
        if self._awaited:
            raise RuntimeError("RestOfCoroutine can only be awaited once")
        self._awaited = True
        to_yield = self.first_yield
        self.first_yield = None
        while True:
            try:
                sent = yield to_yield
            except GeneratorExit:
                # Our awaiter is being closed; close the wrapped coroutine too.
                self.iter.close()
                raise
            except BaseException as exc:
                # Forward thrown exceptions (e.g. task cancellation) into the
                # wrapped coroutine so it can handle them or clean up.
                try:
                    to_yield = self.iter.throw(exc)
                except StopIteration as stop:
                    return stop.value
            else:
                try:
                    to_yield = self.iter.send(sent)
                except StopIteration as stop:
                    return stop.value

    async def run(self) -> T1 | T2:
        """Coroutine form of this awaitable, suitable for ``asyncio.create_task``."""
        return await self
|
121
|
+
|
122
|
+
|
123
|
+
T = TypeVar("T")
|
124
|
+
P = ParamSpec("P")
|
125
|
+
R = TypeVar("R")
|
126
|
+
A = TypeVar("A")
|
127
|
+
|
128
|
+
# keep this load balancing deterministic, but
|
129
|
+
# equally distributed.
|
130
|
+
_load_balancing_seed = random.Random(4)
|
131
|
+
|
132
|
+
|
133
|
+
Selection = Literal["all", "choose"] # TODO: replace with real selection objects
|
134
|
+
|
135
|
+
|
136
|
+
# Stand-in for whatever serializable Python object we end up using to name an
# actor mesh. It is hacked up today because ActorMesh isn't plumbed through to
# non-clients.
class _ActorMeshRefImpl:
    """Addressing information for a (possibly sliced) group of actors.

    Holds:
      * the mailbox used to post messages;
      * the hyperactor mesh, when this reference was built from one (``None``
        after slicing or unpickling);
      * the shape selecting which ranks this reference covers;
      * the flat list of actor ids, indexed by rank.
    """

    def __init__(
        self,
        mailbox: Mailbox,
        hy_actor_mesh: Optional[PythonActorMesh],
        shape: Shape,
        actor_ids: List[ActorId],
    ) -> None:
        self._mailbox = mailbox
        self._actor_mesh = hy_actor_mesh
        self._shape = shape
        self._please_replace_me_actor_ids = actor_ids

    @staticmethod
    def from_hyperactor_mesh(
        mailbox: Mailbox, hy_actor_mesh: PythonActorMesh
    ) -> "_ActorMeshRefImpl":
        """Reference every actor of ``hy_actor_mesh``, looking up each id by rank."""
        shape: Shape = hy_actor_mesh.shape
        actor_ids = [
            cast(ActorId, hy_actor_mesh.get(rank))
            for rank in range(len(shape.ndslice))
        ]
        return _ActorMeshRefImpl(mailbox, hy_actor_mesh, shape, actor_ids)

    @staticmethod
    def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl":
        """Reference a single actor under the zero-dimensional singleton shape."""
        return _ActorMeshRefImpl(mailbox, None, singleton_shape, [actor_id])

    @staticmethod
    def from_actor_ref_with_shape(
        ref: "_ActorMeshRefImpl", shape: Shape
    ) -> "_ActorMeshRefImpl":
        """Re-slice ``ref`` to ``shape``; the actor id list is shared, not copied."""
        return _ActorMeshRefImpl(
            ref._mailbox, None, shape, ref._please_replace_me_actor_ids
        )

    def __getstate__(self) -> Tuple[Shape, List[ActorId], Mailbox]:
        # Only the addressing information is pickled, never the hyperactor mesh.
        return self._shape, self._please_replace_me_actor_ids, self._mailbox

    def __setstate__(self, state: Tuple[Shape, List[ActorId], Mailbox]) -> None:
        self._actor_mesh = None
        shape, actor_ids, mailbox = state
        self._shape = shape
        self._please_replace_me_actor_ids = actor_ids
        self._mailbox = mailbox

    def send(self, rank: int, message: PythonMessage) -> None:
        """Post ``message`` directly to the actor at ``rank``."""
        target = self._please_replace_me_actor_ids[rank]
        self._mailbox.post(target, message)

    def cast(
        self,
        message: PythonMessage,
        selection: Selection,
    ) -> None:
        """Deliver ``message`` according to ``selection``.

        ``"choose"`` posts to one actor picked pseudo-randomly from this
        reference's ranks. ``"all"`` posts to every covered actor, telling
        each one its index within a dense, row-major version of this shape.

        Raises:
            ValueError: for any other selection.
        """
        # TODO: use the actual actor mesh when available. We cannot currently use it
        # directly because we risk bifurcating the message delivery paths from the same
        # client, since slicing the mesh will produce a reference, which calls actors
        # directly. The reason these paths are bifurcated is that actor meshes will
        # use multicasting, while direct actor comms do not. Separately we need to decide
        # whether actor meshes are ordered with actor references.
        #
        # The fix is to provide a first-class reference into Python, and always call "cast"
        # on it, including for load balanced requests.
        if selection == "choose":
            pick = _load_balancing_seed.randrange(len(self._shape.ndslice))
            target_rank = self._shape.ndslice[pick]
            self._mailbox.post(self._please_replace_me_actor_ids[target_rank], message)
            return
        if selection != "all":
            raise ValueError(f"invalid selection: {selection}")

        # replace me with actual remote actor mesh
        dense_shape = Shape(
            self._shape.labels, NDSlice.new_row_major(self._shape.ndslice.sizes)
        )
        for index, global_rank in enumerate(self._shape.ranks()):
            self._mailbox.post_cast(
                self._please_replace_me_actor_ids[global_rank],
                index,
                dense_shape,
                message,
            )

    @property
    def len(self) -> int:
        """Number of actors covered by this reference's shape."""
        return len(self._shape.ndslice)
|
229
|
+
|
230
|
+
|
231
|
+
class Endpoint(Generic[P, R]):
    """Client-side handle for invoking one ``@endpoint`` method across an actor mesh.

    ``ActorMeshRef`` creates one per endpoint declared on the actor class. The
    public methods are "adverbs": different ways of sending the same call and
    handling its return values.
    """

    def __init__(
        self,
        actor_mesh_ref: _ActorMeshRefImpl,
        name: str,
        impl: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]],
        mailbox: Mailbox,
    ) -> None:
        self._actor_mesh = actor_mesh_ref
        self._name = name
        # Used by `send` to validate call arguments locally before posting.
        self._signature: inspect.Signature = inspect.signature(impl)
        self._mailbox = mailbox

    # The following are all 'adverbs', i.e. different ways to handle the
    # return values of this endpoint. Adverbs should only ever take the *args,
    # **kwargs of the original call. If we want to add syntax sugar for
    # something that needs additional arguments, it should be implemented as a
    # function independent of Endpoint, like `send` and `Accumulator`.
    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Send the call to one load-balanced actor and return a Future for its result.

        Load balanced RPC-style entrypoint for request/response messaging. If
        the chosen actor fails, resolving the Future raises its exception.
        """
        p, r = port(self, once=True)
        # pyre-ignore
        send(self, args, kwargs, port=p, selection="choose")
        return r.recv()

    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Like `choose`, but only allowed when the mesh contains exactly one actor.

        Raises:
            ValueError: if this endpoint's mesh does not have exactly one actor.
        """
        if self._actor_mesh.len != 1:
            raise ValueError(
                f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}"
            )
        return self.choose(*args, **kwargs)

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
        """
        Send the call to every actor and return a Future for a ValueMesh of results.

        Replies arrive on a RankedPort, so each value is stored at its sender's
        rank regardless of arrival order. The first failure received is raised
        when the Future is resolved.
        """
        p, r = port(self, kind=RankedPort)
        # pyre-ignore
        send(self, args, kwargs, port=p)

        async def process():
            results = [None] * self._actor_mesh.len
            for _ in range(self._actor_mesh.len):
                rank, value = await r.recv()
                results[rank] = value
            call_shape = Shape(
                self._actor_mesh._shape.labels,
                NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
            )
            return ValueMesh(call_shape, results)

        return Future(process)

    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, None]:
        """
        Broadcast to all actors and yield their responses in arrival order.

        This enables processing results from multiple actors incrementally as
        they become available. Because this is an async generator, the message
        is only sent once iteration begins.
        """
        p, r = port(self)
        # pyre-ignore
        send(self, args, kwargs, port=p)
        for _ in range(self._actor_mesh.len):
            yield await r.recv()

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """
        Fire-and-forget: send the call to all actors without waiting for anything.

        No reply port is attached, so neither results nor acknowledgements come
        back. A failure on an actor is not reported to this caller.
        """
        # pyre-ignore
        send(self, args, kwargs)
|
308
|
+
|
309
|
+
|
310
|
+
class Accumulator(Generic[P, R, A]):
    """Fold the streamed responses of an endpoint into a single value.

    ``combine(acc, response)`` is applied to each response in arrival order,
    starting from ``identity``.
    """

    def __init__(
        self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
    ):
        self._endpoint = endpoint
        self._identity = identity
        self._combine = combine

    def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
        """Return a Future for the folded result of calling the endpoint on all actors.

        The endpoint's ``stream`` is an async generator, so no message is sent
        until the fold below starts iterating it.
        """
        responses = self._endpoint.stream(*args, **kwargs)

        async def fold():
            acc = self._identity
            async for response in responses:
                acc = self._combine(acc, response)
            return acc

        return Future(fold)
|
328
|
+
|
329
|
+
|
330
|
+
class ValueMesh(MeshTrait, Generic[R]):
    """Mesh-shaped collection of values, e.g. the result of ``Endpoint.call``.

    ``values`` is indexed by the ranks produced by ``shape``'s ndslice.
    ``_new_with_shape`` reuses the same backing list under a new shape.
    """

    def __init__(self, shape: Shape, values: List[R]) -> None:
        self._shape = shape
        self._values = values

    def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
        return ValueMesh(shape, self._values)

    def item(self, **kwargs):
        """Return the value at the coordinates given as ``label=index`` keywords.

        Raises:
            KeyError: if a label of the mesh is missing, or extra keywords are given.
        """
        coordinates = []
        for label in self._labels:
            coordinates.append(kwargs.pop(label))
        extra = list(kwargs.keys())
        if extra:
            raise KeyError(f"item has extra dimensions: {extra}")

        flat_index = self._ndslice.nditem(coordinates)
        return self._values[flat_index]

    def __iter__(self):
        """Yield ``(Point, value)`` pairs for every rank of this mesh's shape."""
        for rank in self._shape.ranks():
            point = Point(rank, self._shape)
            yield point, self._values[rank]

    @property
    def _ndslice(self) -> NDSlice:
        return self._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._shape.labels
|
356
|
+
|
357
|
+
|
358
|
+
def send(
    endpoint: Endpoint[P, R],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    port: "Optional[Port]" = None,
    selection: Selection = "all",
) -> None:
    """
    Fire-and-forget invocation of ``endpoint`` on the actors picked by ``selection``.

    With ``selection="all"`` (the default), the message goes to every actor in
    the endpoint's mesh. With ``"choose"``, it goes to a single load-balanced
    actor. This call never waits for a result.

    If ``port`` is given, it is shipped along with the message, and the
    receiving actor replies on it with a "result" or "exception" message.
    Otherwise nothing is sent back to the caller.

    Raises:
        TypeError: if ``args``/``kwargs`` do not match the endpoint's signature.
            This is checked locally, before anything is sent.
        ValueError: from the mesh reference, for an unknown ``selection``.
    """
    # Validate against the implementation's signature; `None` stands in for `self`.
    endpoint._signature.bind(None, *args, **kwargs)
    message = PythonMessage(endpoint._name, _pickle((args, kwargs, port)))
    endpoint._actor_mesh.cast(message, selection)
|
373
|
+
|
374
|
+
|
375
|
+
class EndpointProperty(Generic[P, R]):
    """Descriptor produced by ``@endpoint`` that marks and wraps an actor method.

    The wrapped function is kept in ``_method``. ``ActorMeshRef`` recognizes
    endpoints by this type, and ``_Actor`` invokes ``_method`` directly.
    """

    def __init__(self, method: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]]):
        self._method = method

    def __get__(self, instance, owner) -> Endpoint[P, R]:
        # Deliberately hands back the descriptor itself, typed as an Endpoint
        # only to satisfy the type checker. Callers need the EndpointProperty
        # (and its `_method`), not a bound method.
        as_endpoint = cast(Endpoint[P, R], self)
        return as_endpoint
|
384
|
+
|
385
|
+
|
386
|
+
def endpoint(
    method: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]],
) -> EndpointProperty[P, R]:
    """Decorator that marks an actor method as remotely callable."""
    prop: EndpointProperty[P, R] = EndpointProperty(method)
    return prop
|
390
|
+
|
391
|
+
|
392
|
+
class Port:
    """Sending half of a reply channel: a PortId plus the Mailbox used to post to it."""

    def __init__(self, port: PortId, mailbox: Mailbox) -> None:
        self._port = port
        self._mailbox = mailbox

    def send(self, method: str, obj: object) -> None:
        """Pickle ``obj`` and post it to this port as a message tagged ``method``."""
        payload = _pickle(obj)
        self._mailbox.post(self._port, PythonMessage(method, payload))
|
402
|
+
|
403
|
+
|
404
|
+
# Advanced, lower-level API for sending messages. This is intentionally not
# part of the Endpoint API because the way it accepts arguments and handles
# concerns is different.
def port(
    endpoint: Endpoint[P, R], once=False, kind=Port
) -> Tuple["Port", "PortReceiver[R]"]:
    """Open a reply port on ``endpoint``'s mailbox.

    Returns the sending half, constructed as ``kind`` (e.g. ``Port`` or
    ``RankedPort``) around the bound PortId, together with a ``PortReceiver``
    for the receiving half. With ``once=True`` the port is opened via
    ``open_once_port`` (presumably single-use); otherwise via ``open_port``.
    """
    mailbox = endpoint._mailbox
    if once:
        handle, receiver = mailbox.open_once_port()
    else:
        handle, receiver = mailbox.open_port()
    port_id: PortId = handle.bind()
    sender = kind(port_id, mailbox)
    return sender, PortReceiver(mailbox, receiver)
|
415
|
+
|
416
|
+
|
417
|
+
class PortReceiver(Generic[R]):
    """Receiving half of a port: decodes replies into values or raised exceptions."""

    def __init__(
        self,
        mailbox: Mailbox,
        receiver: HyPortReceiver | OncePortReceiver,
    ):
        self._mailbox = mailbox
        self._receiver = receiver

    async def _recv(self) -> R:
        msg = await self._receiver.recv()
        return self._process(msg)

    def _blocking_recv(self) -> R:
        msg = self._receiver.blocking_recv()
        return self._process(msg)

    def _process(self, msg: PythonMessage):
        """Return the payload of a "result" reply; raise the payload of an "exception" reply."""
        # TODO: Try to do something more structured than a cast here
        payload = cast(R, _unpickle(msg.message, self._mailbox))
        if msg.method == "result":
            return payload
        assert msg.method == "exception"
        # A RankedPort wraps payloads as (rank, value); for failures the rank is
        # dropped and only the exception is raised.
        # pyre-ignore do something more structured here
        raise payload[1] if isinstance(payload, tuple) else payload

    def recv(self) -> "Future[R]":
        """Return a Future for the next reply, given both async and blocking receive paths."""
        return Future(lambda: self._recv(), self._blocking_recv)
|
449
|
+
|
450
|
+
|
451
|
+
# Zero-dimensional shape (no labels, no sizes) used for single-actor addressing
# and for direct, non-cast messages.
singleton_shape = Shape([], NDSlice(sizes=[], strides=[], offset=0))


class RankedPort(Port):
    """Port that tags every payload with the sender's rank, sending ``(rank, obj)``."""

    def send(self, method: str, obj: object) -> None:
        sender_rank = MonarchContext.get().point.rank
        super().send(method, (sender_rank, obj))
|
457
|
+
|
458
|
+
|
459
|
+
class _Actor:
    """Runtime wrapper that hosts one user actor instance and dispatches messages to it.

    Each message payload is the pickled ``(args, kwargs, port)`` triple built
    by ``send``. An ``"__init__"`` message constructs the user instance. Any
    other method name is resolved to an ``@endpoint`` on that instance and
    invoked. The outcome (result, or an ActorMeshRefCallFailedException) is
    sent on ``port`` when one was supplied.
    """

    def __init__(self) -> None:
        # The user-defined actor object, created by the "__init__" message.
        self.instance: object | None = None
        # Eagerly started tasks of in-flight async endpoint calls, awaited in
        # FIFO order by the `_complete` consumer.
        self.active_requests: asyncio.Queue[asyncio.Future[object]] = asyncio.Queue()
        # The single consumer task running `_complete`. It is created lazily on
        # the first async request, because creating a task needs a running loop.
        self.complete_task: asyncio.Task[None] | None = None

    def handle(
        self, mailbox: Mailbox, message: PythonMessage
    ) -> Optional[Coroutine[Any, Any, Any]]:
        """Handle a direct message as a cast at rank 0 of the singleton shape."""
        return self.handle_cast(mailbox, 0, singleton_shape, message)

    def handle_cast(
        self,
        mailbox: Mailbox,
        rank: int,
        shape: Shape,
        message: PythonMessage,
    ) -> Optional[Coroutine[Any, Any, Any]]:
        """Decode and dispatch ``message`` for the actor at ``rank`` within ``shape``.

        Returns None when the message has already been fully handled: the
        constructor ran, a non-coroutine endpoint returned, or a failure was
        reported on the reply port. For coroutine endpoints it returns a
        coroutine that enqueues the endpoint's eagerly started task.

        Raises:
            ActorMeshRefCallFailedException: if handling fails synchronously
                and there is no reply port to report the failure on.
        """
        port = None
        try:
            args, kwargs, port = _unpickle(message.message, mailbox)

            ctx = MonarchContext(mailbox, mailbox.actor_id.proc_id, Point(rank, shape))
            _context.set(ctx)

            if message.method == "__init__":
                Class, *args = args
                self.instance = Class(*args, **kwargs)
                return None
            else:
                # EndpointProperty.__get__ returns the descriptor itself, so
                # `._method` is the undecorated user function.
                the_method = getattr(self.instance, message.method)._method
                result = the_method(self.instance, *args, **kwargs)
                if not inspect.iscoroutinefunction(the_method):
                    if port is not None:
                        port.send("result", result)
                    return None

                return self.run_async(ctx, self.run_task(port, result))
        except Exception as e:
            traceback.print_exc()
            s = ActorMeshRefCallFailedException(e)

            # The exception is delivered to exactly one of:
            # (1) our caller, (2) our supervisor
            if port is not None:
                port.send("exception", s)
            else:
                raise s from None

    async def run_async(self, ctx, coroutine):
        """Start ``coroutine`` eagerly under ``ctx`` and queue its task for `_complete`."""
        _context.set(ctx)
        # Start the consumer once and keep a reference to it, rather than
        # spawning a new, never-ending `_complete` loop for every request. A
        # consumer that has finished (because an awaited request re-raised its
        # failure) is replaced, so the queue keeps draining.
        if self.complete_task is None or self.complete_task.done():
            self.complete_task = asyncio.create_task(self._complete())
        await self.active_requests.put(create_eager_task(coroutine))

    async def run_task(self, port, coroutine):
        """Await an endpoint coroutine and report its outcome on ``port``, if any."""
        try:
            result = await coroutine
            if port is not None:
                port.send("result", result)
        except Exception as e:
            traceback.print_exc()
            s = ActorMeshRefCallFailedException(e)

            # The exception is delivered to exactly one of:
            # (1) our caller, (2) our supervisor
            if port is not None:
                port.send("exception", s)
            else:
                raise s from None

    async def _complete(self) -> None:
        """Await queued request tasks one at a time, forever.

        If an awaited task raised, the exception propagates and ends this loop.
        `run_async` then starts a fresh consumer on the next request.
        """
        while True:
            task = await self.active_requests.get()
            await task
|
534
|
+
|
535
|
+
|
536
|
+
def _is_mailbox(x: object) -> bool:
    """Filter passed to ``flatten`` by ``_pickle``: True for Mailbox objects."""
    return isinstance(x, Mailbox)


def _pickle(obj: object) -> bytes:
    """Serialize ``obj`` with ``flatten``, using ``_is_mailbox`` as the filter.

    ``flatten`` returns a pair. The first element (presumably the Mailbox
    objects pulled out of ``obj``) is discarded, and only the serialized bytes
    are returned. The receiver supplies its own mailbox in ``_unpickle``.
    """
    _extracted, data = flatten(obj, _is_mailbox)
    return data


def _unpickle(data: bytes, mailbox: Mailbox) -> Any:
    """Deserialize ``data``, supplying ``mailbox`` for every mailbox slot."""
    # Whatever mailbox a remote object originally referenced, it is rebuilt
    # against the local one.
    local_mailboxes = itertools.repeat(mailbox)
    return unflatten(data, local_mailboxes)
|
549
|
+
|
550
|
+
|
551
|
+
class Actor(MeshTrait):
    """Base class for user actor implementations.

    It derives from MeshTrait only so that type checkers accept mesh-style
    usage of actor classes. An actor instance is not a mesh, so every mesh
    hook below raises NotImplementedError.
    """

    @property
    def _ndslice(self) -> NDSlice:
        reason = "actor implementations are not meshes, but we can't convince the typechecker of it..."
        raise NotImplementedError(reason)

    @property
    def _labels(self) -> Tuple[str, ...]:
        reason = "actor implementations are not meshes, but we can't convince the typechecker of it..."
        raise NotImplementedError(reason)

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        reason = "actor implementations are not meshes, but we can't convince the typechecker of it..."
        raise NotImplementedError(reason)
|
568
|
+
|
569
|
+
|
570
|
+
class ActorMeshRef(MeshTrait):
    """Client-side handle to a mesh of actors of class ``Class``.

    Every ``@endpoint`` declared on ``Class`` is exposed as an ``Endpoint``
    attribute of the same name, bound to ``actor_mesh_ref``.
    """

    def __init__(
        self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
    ) -> None:
        self.__name__ = Class.__name__
        self._class = Class
        self._actor_mesh_ref = actor_mesh_ref
        self._mailbox = mailbox
        # Eagerly attach an Endpoint for every EndpointProperty on the class.
        declared = (
            (attr_name, getattr(self._class, attr_name, None))
            for attr_name in dir(self._class)
        )
        for attr_name, candidate in declared:
            if not isinstance(candidate, EndpointProperty):
                continue
            setattr(
                self,
                attr_name,
                Endpoint(
                    self._actor_mesh_ref,
                    attr_name,
                    candidate._method,
                    self._mailbox,
                ),
            )

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails. Endpoints missing
        # from the instance are built on demand (and cached via setattr); this
        # also tells type checkers that arbitrary attributes may exist, while
        # truly unknown names still raise AttributeError at runtime.
        attr = getattr(self._class, name) if hasattr(self._class, name) else None
        if isinstance(attr, EndpointProperty):
            endpoint = Endpoint(
                self._actor_mesh_ref,
                name,
                attr._method,
                self._mailbox,
            )
            setattr(self, name, endpoint)
            return endpoint

        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def _create(self, args: Iterable[Any], kwargs: Dict[str, Any]) -> None:
        """Send the "__init__" message constructing ``Class(*args, **kwargs)`` on all actors."""

        async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
            return None

        # null_func accepts any arguments, so the signature check in `send` passes.
        init_endpoint = Endpoint(
            self._actor_mesh_ref,
            "__init__",
            null_func,
            self._mailbox,
        )
        # pyre-ignore
        send(init_endpoint, (self._class, *args), kwargs)

    def __reduce_ex__(
        self, protocol: ...
    ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]":
        # Pickle as a constructor call; __init__ re-attaches endpoints on load.
        ctor_args = (self._class, self._actor_mesh_ref, self._mailbox)
        return ActorMeshRef, ctor_args

    @property
    def _ndslice(self) -> NDSlice:
        shape = self._actor_mesh_ref._shape
        return shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        shape = self._actor_mesh_ref._shape
        return shape.labels

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        sliced = _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape)
        return ActorMeshRef(self._class, sliced, self._mailbox)
|
654
|
+
|
655
|
+
|
656
|
+
class ActorMeshRefCallFailedException(Exception):
    """
    Wraps an exception raised while an actor handled an endpoint call.

    Intended for deterministic problems with the user's code, for example an
    OOM from trying to allocate too much GPU memory, or a violated invariant
    enforced by the various APIs. The traceback frames of the wrapped
    exception are captured at construction time so that ``str()`` can show
    where the call failed.
    """

    def __init__(
        self,
        exception: Exception,
        message: str = "A remote service call has failed asynchronously.",
    ) -> None:
        self.exception = exception
        self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__)
        self.message = message

    def __str__(self) -> str:
        original = str(self.exception)
        frames_text = "".join(traceback.format_list(self.actor_mesh_ref_frames))
        kind = type(self.exception).__name__
        header = f"{self.message}\n"
        body = (
            "Traceback of where the service call failed (most recent call last):\n"
            f"{frames_text}{kind}: {original}"
        )
        return header + body
|
679
|
+
|
680
|
+
|
681
|
+
def current_actor_name() -> str:
    """Return the string form of the running actor's id, taken from its mailbox."""
    actor_id = MonarchContext.get().mailbox.actor_id
    return str(actor_id)
|
683
|
+
|
684
|
+
|
685
|
+
def current_rank() -> Point:
    """Return the running actor's Point: its rank within the shape of the current call."""
    return MonarchContext.get().point
|
688
|
+
|
689
|
+
|
690
|
+
def current_size() -> Dict[str, int]:
    """Map each dimension label of the current call's shape to that dimension's size."""
    shape = MonarchContext.get().point.shape
    return {label: size for label, size in zip(shape.labels, shape.ndslice.sizes)}
|