torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/actor_mesh.py
ADDED
@@ -0,0 +1,761 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import collections
|
10
|
+
import contextvars
|
11
|
+
import functools
|
12
|
+
import inspect
|
13
|
+
|
14
|
+
import itertools
|
15
|
+
import logging
|
16
|
+
import random
|
17
|
+
import sys
|
18
|
+
import traceback
|
19
|
+
|
20
|
+
from dataclasses import dataclass
|
21
|
+
from traceback import extract_tb, StackSummary
|
22
|
+
from typing import (
|
23
|
+
Any,
|
24
|
+
AsyncGenerator,
|
25
|
+
Awaitable,
|
26
|
+
Callable,
|
27
|
+
cast,
|
28
|
+
Concatenate,
|
29
|
+
Dict,
|
30
|
+
Generic,
|
31
|
+
Iterable,
|
32
|
+
List,
|
33
|
+
Literal,
|
34
|
+
Optional,
|
35
|
+
ParamSpec,
|
36
|
+
Tuple,
|
37
|
+
Type,
|
38
|
+
TYPE_CHECKING,
|
39
|
+
TypeVar,
|
40
|
+
)
|
41
|
+
|
42
|
+
import monarch
|
43
|
+
from monarch import ActorFuture as Future
|
44
|
+
from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span
|
45
|
+
|
46
|
+
from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage
|
47
|
+
from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
|
48
|
+
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
|
49
|
+
Mailbox,
|
50
|
+
OncePortReceiver,
|
51
|
+
OncePortRef,
|
52
|
+
PortReceiver as HyPortReceiver,
|
53
|
+
PortRef,
|
54
|
+
)
|
55
|
+
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
|
56
|
+
from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
|
57
|
+
|
58
|
+
from monarch.common.pickle_flatten import flatten, unflatten
|
59
|
+
from monarch.common.shape import MeshTrait, NDSlice
|
60
|
+
from monarch.pdb_wrapper import remote_breakpointhook
|
61
|
+
|
62
|
+
if TYPE_CHECKING:
|
63
|
+
from monarch.debugger import DebugClient
|
64
|
+
|
65
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
66
|
+
|
67
|
+
Allocator = monarch.ProcessAllocator | monarch.LocalAllocator
|
68
|
+
|
69
|
+
try:
|
70
|
+
from __manifest__ import fbmake # noqa
|
71
|
+
|
72
|
+
IN_PAR = True
|
73
|
+
except ImportError:
|
74
|
+
IN_PAR = False
|
75
|
+
|
76
|
+
T1 = TypeVar("T1")
|
77
|
+
T2 = TypeVar("T2")
|
78
|
+
|
79
|
+
|
80
|
+
class Point(HyPoint, collections.abc.Mapping):
    """A hyperactor ``Point`` that is also registered as a ``Mapping``.

    The class adds no behavior of its own: it only combines the ``Point``
    type exported by the Rust bindings with the ``collections.abc.Mapping``
    base class.
    """
|
82
|
+
|
83
|
+
|
84
|
+
@dataclass
class MonarchContext:
    """Context describing the actor invocation that is currently running.

    An instance is stored in the module-level ``_context`` context variable
    by ``_Actor.handle_cast`` before an endpoint is invoked.
    """

    # Mailbox passed to the message handler for the current message.
    mailbox: Mailbox
    # ``mailbox.actor_id.proc_id`` of the handling actor.
    proc_id: str
    # Rank of this actor together with the shape the message was cast over.
    point: Point

    @staticmethod
    def get() -> "MonarchContext":
        """Return the context stored in ``_context``.

        Raises:
            LookupError: if no context has been set in the current
                ``contextvars`` context (``_context`` has no default).
        """
        current = _context.get()
        return current
|
93
|
+
|
94
|
+
|
95
|
+
_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
|
96
|
+
"monarch.actor_mesh._context"
|
97
|
+
)
|
98
|
+
|
99
|
+
|
100
|
+
T = TypeVar("T")
|
101
|
+
P = ParamSpec("P")
|
102
|
+
R = TypeVar("R")
|
103
|
+
A = TypeVar("A")
|
104
|
+
|
105
|
+
# keep this load balancing deterministic, but
|
106
|
+
# equally distributed.
|
107
|
+
_load_balancing_seed = random.Random(4)
|
108
|
+
|
109
|
+
|
110
|
+
Selection = Literal["all", "choose"] # TODO: replace with real selection objects
|
111
|
+
|
112
|
+
|
113
|
+
# standin class for whatever is the serializable python object we use
|
114
|
+
# to name an actor mesh. Hacked up today because ActorMesh
|
115
|
+
# isn't plumbed to non-clients
|
116
|
+
class _ActorMeshRefImpl:
    """Picklable stand-in for a reference to a mesh of actors.

    It keeps the mailbox used for posting, an optional hyperactor mesh
    handle, the shape of the reference, and a flat list of actor ids that is
    indexed by the ranks produced by that shape.
    """

    def __init__(
        self,
        mailbox: Mailbox,
        hy_actor_mesh: Optional[PythonActorMesh],
        shape: Shape,
        actor_ids: List[ActorId],
    ) -> None:
        self._mailbox = mailbox
        self._actor_mesh = hy_actor_mesh
        self._shape = shape
        self._please_replace_me_actor_ids = actor_ids

    @staticmethod
    def from_hyperactor_mesh(
        mailbox: Mailbox, hy_actor_mesh: PythonActorMesh
    ) -> "_ActorMeshRefImpl":
        """Build a reference covering every actor of ``hy_actor_mesh``."""
        rank_count = len(hy_actor_mesh.shape)
        mesh_shape = hy_actor_mesh.shape
        # ``cast`` here is ``typing.cast``: method bodies do not see the
        # class-level ``cast`` method defined further down.
        actor_ids = [
            cast(ActorId, hy_actor_mesh.get(rank)) for rank in range(rank_count)
        ]
        return _ActorMeshRefImpl(mailbox, hy_actor_mesh, mesh_shape, actor_ids)

    @staticmethod
    def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl":
        """Build a reference to a single actor using ``singleton_shape``."""
        return _ActorMeshRefImpl(
            mailbox,
            None,
            singleton_shape,
            [actor_id],
        )

    @staticmethod
    def from_actor_ref_with_shape(
        ref: "_ActorMeshRefImpl", shape: Shape
    ) -> "_ActorMeshRefImpl":
        """Return a new reference sharing ``ref``'s mailbox and actor ids
        but viewed through ``shape``; the hyperactor mesh handle is dropped.
        """
        shared_ids = ref._please_replace_me_actor_ids
        return _ActorMeshRefImpl(ref._mailbox, None, shape, shared_ids)

    def __getstate__(
        self,
    ) -> Tuple[Shape, List[ActorId], Mailbox]:
        # The hyperactor mesh handle is intentionally left out of the state.
        return (
            self._shape,
            self._please_replace_me_actor_ids,
            self._mailbox,
        )

    def __setstate__(
        self,
        state: Tuple[Shape, List[ActorId], Mailbox],
    ) -> None:
        shape, actor_ids, mailbox = state
        self._actor_mesh = None
        self._shape = shape
        self._please_replace_me_actor_ids = actor_ids
        self._mailbox = mailbox

    def send(self, rank: int, message: PythonMessage) -> None:
        """Post ``message`` to the actor stored at index ``rank``."""
        self._mailbox.post(self._please_replace_me_actor_ids[rank], message)

    def cast(
        self,
        message: PythonMessage,
        selection: Selection,
    ) -> None:
        """Deliver ``message`` to one ("choose") or every ("all") actor.

        Raises:
            ValueError: if ``selection`` is neither "choose" nor "all".
        """
        # TODO: use the actual actor mesh when available. We cannot currently use it
        # directly because we risk bifurcating the message delivery paths from the same
        # client, since slicing the mesh will produce a reference, which calls actors
        # directly. The reason these paths are bifurcated is that actor meshes will
        # use multicasting, while direct actor comms do not. Separately we need to decide
        # whether actor meshes are ordered with actor references.
        #
        # The fix is to provide a first-class reference into Python, and always call "cast"
        # on it, including for load balanced requests.
        if selection == "choose":
            # Pick a position inside the shape using the module-level seeded
            # RNG, then translate it to an index into the actor id list.
            position = _load_balancing_seed.randrange(len(self._shape))
            chosen_rank = self._shape.ndslice[position]
            self._mailbox.post(self._please_replace_me_actor_ids[chosen_rank], message)
            return

        if selection != "all":
            raise ValueError(f"invalid selection: {selection}")

        # replace me with actual remote actor mesh
        dense_slice = NDSlice.new_row_major(self._shape.ndslice.sizes)
        call_shape = Shape(self._shape.labels, dense_slice)
        for position, actor_rank in enumerate(self._shape.ranks()):
            target = self._please_replace_me_actor_ids[actor_rank]
            self._mailbox.post_cast(target, position, call_shape, message)

    def __len__(self) -> int:
        return len(self._shape)
|
205
|
+
|
206
|
+
|
207
|
+
class Endpoint(Generic[P, R]):
    """Client-side handle for one named endpoint on an actor mesh reference.

    The handle stores the mesh reference, the endpoint name, the signature of
    the implementing callable (used to validate call arguments in ``send``),
    and the mailbox used to open response ports.
    """

    def __init__(
        self,
        actor_mesh_ref: _ActorMeshRefImpl,
        name: str,
        impl: Callable[Concatenate[Any, P], Awaitable[R]],
        mailbox: Mailbox,
    ) -> None:
        self._actor_mesh = actor_mesh_ref
        self._name = name
        self._signature: inspect.Signature = inspect.signature(impl)
        self._mailbox = mailbox

    # The following are all 'adverbs', i.e. different ways to handle the
    # return values of this endpoint. Adverbs should only ever take the
    # *args, **kwargs of the original call. If we want to add syntax sugar
    # for something that needs additional arguments, it should be implemented
    # as a function independent of the endpoint, like `send` and `Accumulator`.
    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Send the call to a single actor selected by the "choose" selection
        and return a future for its response.
        """
        reply_port: Port[R]
        receiver: PortReceiver[R]
        reply_port, receiver = port(self, once=True)
        # pyre-ignore
        send(self, args, kwargs, port=reply_port, selection="choose")
        return receiver.recv()

    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Call the endpoint on a reference that contains exactly one actor.

        Raises ValueError when the reference does not have length 1.
        """
        if len(self._actor_mesh) == 1:
            return self.choose(*args, **kwargs)
        raise ValueError(
            f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}"
        )

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
        """
        Send the call to every actor and return a future resolving to a
        ValueMesh holding one response per actor, placed by the rank that
        arrives with each response.
        """
        reply_port: Port[R]
        receiver: RankedPortReceiver[R]
        reply_port, receiver = ranked_port(self)
        # pyre-ignore
        send(self, args, kwargs, port=reply_port)

        def as_value_mesh(collected: List[R]) -> "ValueMesh[R]":
            # Results are indexed 0..n-1, so describe them with a dense
            # row-major slice over the same labels and sizes.
            dense_shape = Shape(
                self._actor_mesh._shape.labels,
                NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
            )
            return ValueMesh(dense_shape, collected)

        async def process() -> "ValueMesh[R]":
            collected: List[R] = [None] * len(self._actor_mesh)  # pyre-fixme[9]
            for _ in range(len(self._actor_mesh)):
                rank, value = await receiver.recv()
                collected[rank] = value
            return as_value_mesh(collected)

        def process_blocking() -> "ValueMesh[R]":
            collected: List[R] = [None] * len(self._actor_mesh)  # pyre-fixme[9]
            for _ in range(len(self._actor_mesh)):
                rank, value = receiver.recv().get()
                collected[rank] = value
            return as_value_mesh(collected)

        return Future(process, process_blocking)

    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
        """
        Send the call to every actor and yield each response as it is
        received from the response port.

        Exactly ``len(actor mesh)`` values are yielded.
        """
        reply_port, receiver = port(self)
        # pyre-ignore
        send(self, args, kwargs, port=reply_port)
        outstanding = len(self._actor_mesh)
        while outstanding > 0:
            yield await receiver.recv()
            outstanding -= 1

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """
        Fire-and-forget send to all actors: no response port is attached, so
        nothing is awaited.

        Returning from this method therefore does not guarantee that the
        message has been delivered.
        """
        # pyre-ignore
        send(self, args, kwargs)
|
299
|
+
|
300
|
+
|
301
|
+
class Accumulator(Generic[P, R, A]):
    """Fold the streamed responses of an endpoint into a single value.

    ``identity`` is the starting value and ``combine(acc, response)``
    produces the next accumulated value.
    """

    def __init__(
        self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
    ) -> None:
        self._endpoint: Endpoint[P, R] = endpoint
        self._identity: A = identity
        self._combine: Callable[[A, R], A] = combine

    def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
        """Call ``endpoint.stream`` with the arguments and return a future
        for the result of folding every streamed response with ``combine``.
        """
        responses: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs)

        async def fold() -> A:
            accumulated = self._identity
            async for response in responses:
                accumulated = self._combine(accumulated, response)
            return accumulated

        return Future(fold)
|
319
|
+
|
320
|
+
|
321
|
+
class ValueMesh(MeshTrait, Generic[R]):
    """
    A list of returned values paired with the shape used to index into it.
    """

    def __init__(self, shape: Shape, values: List[R]) -> None:
        self._shape = shape
        self._values = values

    def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
        # The value list is shared, only the shape changes.
        return ValueMesh(shape, self._values)

    def item(self, **kwargs) -> R:
        """Return the value at the coordinates given as ``label=index``.

        Raises:
            KeyError: if a label of the shape is missing from ``kwargs`` or
                if ``kwargs`` contains names that are not labels.
        """
        coordinates = []
        for label in self._labels:
            coordinates.append(kwargs.pop(label))
        if kwargs:
            raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}")

        index = self._ndslice.nditem(coordinates)
        return self._values[index]

    def __iter__(self):
        """Yield ``(Point, value)`` pairs for every rank of the shape."""
        for rank in self._shape.ranks():
            location = Point(rank, self._shape)
            yield location, self._values[rank]

    def __len__(self) -> int:
        return len(self._shape)

    def __repr__(self) -> str:
        return f"ValueMesh({self._shape})"

    @property
    def _ndslice(self) -> NDSlice:
        return self._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._shape.labels
|
357
|
+
|
358
|
+
|
359
|
+
def send(
    endpoint: Endpoint[P, R],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    port: "Optional[Port]" = None,
    selection: Selection = "all",
) -> None:
    """
    Build the message for an endpoint invocation and cast it to the mesh.

    The arguments are first checked against the endpoint signature (with a
    placeholder standing in for ``self``), then pickled into a PythonMessage
    addressed to the endpoint name. When ``port`` is given, its port
    reference is attached as the response port. This function does not wait
    for any result.
    """
    # Raises TypeError early if the arguments do not fit the signature.
    endpoint._signature.bind(None, *args, **kwargs)

    payload = _pickle((args, kwargs))
    if port is None:
        response_ref = None
    else:
        response_ref = port._port_ref

    message = PythonMessage(endpoint._name, payload, response_ref, None)
    endpoint._actor_mesh.cast(message, selection)
|
379
|
+
|
380
|
+
|
381
|
+
class EndpointProperty(Generic[P, R]):
    """Descriptor that marks a method as an endpoint and keeps the raw
    callable in ``_method``.
    """

    def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None:
        self._method = method

    def __get__(self, instance, owner) -> Endpoint[P, R]:
        # The declared return type is not what is actually returned: the
        # descriptor hands back itself so that callers can both recognize
        # the attribute as an endpoint and look up the wrapped method.
        marker = cast(Endpoint[P, R], self)
        return marker
|
390
|
+
|
391
|
+
|
392
|
+
def endpoint(
    method: Callable[Concatenate[Any, P], Awaitable[R]],
) -> EndpointProperty[P, R]:
    """Decorator that wraps ``method`` in an ``EndpointProperty``."""
    wrapped = EndpointProperty(method)
    return wrapped
|
396
|
+
|
397
|
+
|
398
|
+
class Port(Generic[R]):
    """Sending half of a response channel.

    Wraps a bound port reference together with the mailbox used to send
    through it and an optional rank that is stamped on every message.
    """

    def __init__(
        self, port_ref: PortRef | OncePortRef, mailbox: Mailbox, rank: Optional[int]
    ) -> None:
        self._port_ref = port_ref
        self._mailbox = mailbox
        self._rank = rank

    def send(self, method: str, obj: R) -> None:
        """Pickle ``obj`` and send it through the port, tagged with
        ``method`` and this port's rank.
        """
        reply = PythonMessage(method, _pickle(obj), None, self._rank)
        self._port_ref.send(self._mailbox, reply)
|
411
|
+
|
412
|
+
|
413
|
+
# advanced lower-level API for sending messages. This is intentionally
|
414
|
+
# not part of the Endpoint API because the way it accepts arguments
|
415
|
+
# and handles concerns is different.
|
416
|
+
def port(
    endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["Port[R]", "PortReceiver[R]"]:
    """Open a response port on the endpoint's mailbox.

    Opens a once-port when ``once`` is true and a regular port otherwise,
    binds the handle, and returns the sending ``Port`` (with no rank) along
    with the ``PortReceiver`` that reads from it.
    """
    mailbox = endpoint._mailbox
    if once:
        handle, receiver = mailbox.open_once_port()
    else:
        handle, receiver = mailbox.open_port()

    port_ref: PortRef | OncePortRef = handle.bind()
    sender = Port(port_ref, endpoint._mailbox, rank=None)
    return sender, PortReceiver(endpoint._mailbox, receiver)
|
426
|
+
|
427
|
+
|
428
|
+
def ranked_port(
    endpoint: Endpoint[P, R], once: bool = False
) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
    """Like ``port``, but the receiver yields ``(rank, value)`` pairs.

    The plain receiver returned by ``port`` is re-wrapped in a
    ``RankedPortReceiver`` that shares its mailbox and underlying receiver.
    """
    sender, plain_receiver = port(endpoint, once)
    ranked_receiver = RankedPortReceiver[R](
        plain_receiver._mailbox, plain_receiver._receiver
    )
    return sender, ranked_receiver
|
433
|
+
|
434
|
+
|
435
|
+
class PortReceiver(Generic[R]):
    """Receiving half of a response channel.

    Each received ``PythonMessage`` is unpickled; a "result" message yields
    its payload and an "exception" message raises its payload.
    """

    def __init__(
        self,
        mailbox: Mailbox,
        receiver: HyPortReceiver | OncePortReceiver,
    ) -> None:
        self._mailbox: Mailbox = mailbox
        self._receiver: HyPortReceiver | OncePortReceiver = receiver

    async def _recv(self) -> R:
        incoming = await self._receiver.recv()
        return self._process(incoming)

    def _blocking_recv(self) -> R:
        incoming = self._receiver.blocking_recv()
        return self._process(incoming)

    def _process(self, msg: PythonMessage) -> R:
        """Decode ``msg``: return the payload of a result, raise the payload
        of an exception.
        """
        # TODO: Try to do something more structured than a cast here
        payload = cast(R, _unpickle(msg.message, self._mailbox))
        if msg.method == "result":
            return payload

        assert msg.method == "exception"
        # pyre-ignore
        raise payload

    def recv(self) -> "Future[R]":
        """Return a future for the next message, usable both by awaiting
        and by blocking.
        """

        def start_async():
            return self._recv()

        return Future(start_async, self._blocking_recv)
|
462
|
+
|
463
|
+
|
464
|
+
class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
    """Receiver that pairs every decoded value with the sender's rank."""

    def _process(self, msg: PythonMessage) -> Tuple[int, R]:
        """Return ``(msg.rank, decoded value)``.

        Raises:
            ValueError: if the message carries no rank.
        """
        sender_rank = msg.rank
        if sender_rank is None:
            raise ValueError("RankedPort receiver got a message without a rank")
        value = super()._process(msg)
        return sender_rank, value
|
469
|
+
|
470
|
+
|
471
|
+
singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[]))
|
472
|
+
|
473
|
+
|
474
|
+
class _Actor:
    """
    This is the message handling implementation of a Python actor.

    The layering goes:
        Rust `PythonActor` -> `_Actor` -> user-provided `Actor` instance

    Messages are received from the Rust backend, and forwarded to the `handle`
    methods on this class.

    This class wraps the actual `Actor` instance provided by the user, and
    routes messages to it, managing argument serialization/deserialization and
    error handling.
    """

    def __init__(self) -> None:
        # Set by the "__init__" message; stays None until that succeeds.
        self.instance: object | None = None

    async def handle(
        self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
    ) -> None:
        """Handle a directly addressed message as a cast with rank 0 over
        ``singleton_shape``.
        """
        return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag)

    async def handle_cast(
        self,
        mailbox: Mailbox,
        rank: int,
        shape: Shape,
        message: PythonMessage,
        panic_flag: PanicFlag,
    ) -> None:
        """Dispatch ``message`` to the wrapped actor instance.

        The "__init__" method constructs the instance from the class passed
        as the first positional argument. Any other method name is looked up
        on the instance and invoked (awaited when it is a coroutine
        function) inside a telemetry span. The result, or an ``ActorError``
        wrapping a raised exception, is sent to the response port when the
        message has one.

        Raises:
            ActorError: when the endpoint fails and there is no response
                port to deliver the failure to.
            BaseException: non-``Exception`` errors are re-raised after
                ``panic_flag.signal_panic`` has been attempted.
        """
        port = (
            Port(message.response_port, mailbox, rank)
            if message.response_port
            else None
        )
        try:
            ctx: MonarchContext = MonarchContext(
                mailbox, mailbox.actor_id.proc_id, Point(rank, shape)
            )
            _context.set(ctx)

            args, kwargs = _unpickle(message.message, mailbox)

            if message.method == "__init__":
                Class, *args = args
                self.instance = Class(*args, **kwargs)
                return None

            if self.instance is None:
                raise AssertionError(
                    "__init__ failed earlier and no Actor object is available"
                )
            the_method = getattr(self.instance, message.method)._method

            enter_span(
                the_method.__module__,
                message.method,
                str(ctx.mailbox.actor_id),
            )
            try:
                if inspect.iscoroutinefunction(the_method):
                    try:
                        result = await the_method(self.instance, *args, **kwargs)
                    except Exception as e:
                        logging.critical(
                            "Unhandled exception in actor endpoint",
                            exc_info=e,
                        )
                        raise
                else:
                    result = the_method(self.instance, *args, **kwargs)
            finally:
                # Always close the span opened above, including when the
                # endpoint raises; otherwise the span is never exited.
                exit_span()

            if port is not None:
                port.send("result", result)
        except Exception as e:
            traceback.print_exc()
            s = ActorError(e)

            # The exception is delivered to exactly one of:
            # (1) our caller, (2) our supervisor
            if port is not None:
                port.send("exception", s)
            else:
                raise s from None
        except BaseException as e:
            # A BaseException can be thrown in the case of a Rust panic.
            # In this case, we need a way to signal the panic to the Rust side.
            # See [Panics in async endpoints]
            try:
                panic_flag.signal_panic(e)
            except Exception:
                # The channel might be closed if the Rust side has already detected the error
                pass
            raise
|
578
|
+
|
579
|
+
|
580
|
+
def _is_mailbox(x: object) -> bool:
    """Predicate passed to ``flatten``: true for ``Mailbox`` instances."""
    matches = isinstance(x, Mailbox)
    return matches
|
582
|
+
|
583
|
+
|
584
|
+
def _pickle(obj: object) -> bytes:
    """Serialize ``obj`` with ``flatten`` and return only the byte payload.

    ``flatten`` is called with ``_is_mailbox`` as its predicate; the first
    element of its result is discarded (presumably the Mailbox objects that
    were pulled out of ``obj`` -- confirm against ``pickle_flatten``).
    """
    _discarded, payload = flatten(obj, _is_mailbox)
    return payload
|
587
|
+
|
588
|
+
|
589
|
+
def _unpickle(data: bytes, mailbox: Mailbox) -> Any:
    """Inverse of ``_pickle``: rebuild the object from ``data``."""
    # Whatever mailboxes the remote objects originally carried, every one of
    # them is replaced by the local mailbox, hence the endless repeat.
    local_mailboxes = itertools.repeat(mailbox)
    return unflatten(data, local_mailboxes)
|
593
|
+
|
594
|
+
|
595
|
+
class Actor(MeshTrait):
    """Base class for user-defined actors.

    It derives from ``MeshTrait`` only for the benefit of the type checker;
    the mesh accessors all raise ``NotImplementedError``.
    """

    @functools.cached_property
    def logger(self) -> logging.Logger:
        """Logger named after the concrete actor class, set to DEBUG level
        and created on first access.
        """
        actor_logger = logging.getLogger(self.__class__.__name__)
        actor_logger.setLevel(logging.DEBUG)
        return actor_logger

    @property
    def _ndslice(self) -> NDSlice:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @property
    def _labels(self) -> Tuple[str, ...]:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @endpoint  # pyre-ignore
    def _set_debug_client(self, client: "DebugClient") -> None:
        """Install a ``sys.breakpointhook`` that forwards breakpoints to
        ``remote_breakpointhook`` together with this actor's rank,
        coordinates, actor id and the given debug client.
        """
        ctx = MonarchContext.get()
        point = ctx.point
        # For some reason, using a lambda instead of functools.partial
        # confuses the pdb wrapper implementation.
        hook = functools.partial(
            remote_breakpointhook,
            point.rank,
            point.shape.coordinates(point.rank),
            ctx.mailbox.actor_id,
            client,
        )
        sys.breakpointhook = hook  # pyre-ignore
|
631
|
+
|
632
|
+
|
633
|
+
class ActorMeshRef(MeshTrait, Generic[T]):
    """Typed client-side reference to a mesh of actors of one class.

    Every attribute of the actor class that is an ``EndpointProperty`` is
    exposed on the reference as an ``Endpoint`` bound to the mesh.
    """

    def __init__(
        self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
    ) -> None:
        self.__name__: str = Class.__name__
        self._class: Type[T] = Class
        self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref
        self._mailbox: Mailbox = mailbox
        # Eagerly attach an Endpoint for every endpoint declared on the class.
        for attr_name in dir(self._class):
            candidate = getattr(self._class, attr_name, None)
            if not isinstance(candidate, EndpointProperty):
                continue
            bound = Endpoint(
                self._actor_mesh_ref,
                attr_name,
                candidate._method,
                self._mailbox,
            )
            setattr(self, attr_name, bound)

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails. Declaring it lets
        # the type checker treat any attribute as a possible endpoint added
        # at runtime, while truly missing attributes still raise
        # AttributeError below.
        candidate = getattr(self._class, name, None)
        if isinstance(candidate, EndpointProperty):
            # Create the endpoint on demand and cache it on the instance so
            # later lookups do not come back through __getattr__.
            bound = Endpoint(
                self._actor_mesh_ref,
                name,
                candidate._method,
                self._mailbox,
            )
            setattr(self, name, bound)
            return bound

        # Not an endpoint of the underlying class: report it as missing.
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def _create(
        self,
        args: Iterable[Any],
        kwargs: Dict[str, Any],
    ) -> None:
        """Send the "__init__" message, carrying the actor class followed by
        ``args`` and ``kwargs``, to every actor of the mesh.
        """

        async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
            # Placeholder whose permissive signature accepts any arguments.
            return None

        init_endpoint = Endpoint(
            self._actor_mesh_ref,
            "__init__",
            null_func,
            self._mailbox,
        )
        init_args = (self._class, *args)
        send(init_endpoint, init_args, kwargs)

    def __reduce_ex__(
        self, protocol: ...
    ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]":
        # Rebuild through the constructor so endpoints get re-attached.
        ctor_args = (self._class, self._actor_mesh_ref, self._mailbox)
        return ActorMeshRef, ctor_args

    @property
    def _ndslice(self) -> NDSlice:
        return self._actor_mesh_ref._shape.ndslice

    @property
    def _labels(self) -> Iterable[str]:
        return self._actor_mesh_ref._shape.labels

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        reshaped = _ActorMeshRefImpl.from_actor_ref_with_shape(
            self._actor_mesh_ref, shape
        )
        return ActorMeshRef(self._class, reshaped, self._mailbox)

    def __repr__(self) -> str:
        return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})"
|
723
|
+
|
724
|
+
|
725
|
+
class ActorError(Exception):
    """
    Deterministic problem with the user's code.
    For example, an OOM resulting in trying to allocate too much GPU memory, or violating
    some invariant enforced by the various APIs.

    The wrapped exception and the frames extracted from its traceback are
    kept so that ``str()`` can show where the remote call failed.
    """

    def __init__(
        self,
        exception: Exception,
        message: str = "A remote actor call has failed asynchronously.",
    ) -> None:
        self.exception = exception
        # Captured eagerly as a StackSummary rather than holding on to the
        # live traceback object.
        remote_frames: StackSummary = extract_tb(exception.__traceback__)
        self.actor_mesh_ref_frames = remote_frames
        self.message = message

    def __str__(self) -> str:
        """Return the message, the formatted remote frames, and the wrapped
        exception's type name and text.
        """
        formatted_frames = traceback.format_list(self.actor_mesh_ref_frames)
        actor_mesh_ref_tb = "".join(formatted_frames)
        cause_name = type(self.exception).__name__
        cause_text = str(self.exception)
        header = f"{self.message}\n"
        body = f"Traceback of where the remote call failed (most recent call last):\n{actor_mesh_ref_tb}{cause_name}: {cause_text}"
        return header + body
|
748
|
+
|
749
|
+
|
750
|
+
def current_actor_name() -> str:
    """Return the string form of the current context's mailbox actor id."""
    ctx = MonarchContext.get()
    return str(ctx.mailbox.actor_id)
|
752
|
+
|
753
|
+
|
754
|
+
def current_rank() -> Point:
    """Return the ``Point`` stored in the current ``MonarchContext``."""
    return MonarchContext.get().point
|
757
|
+
|
758
|
+
|
759
|
+
def current_size() -> Dict[str, int]:
    """Return a ``{label: size}`` mapping for the shape of the current
    context's point.
    """
    shape = MonarchContext.get().point.shape
    return {label: size for label, size in zip(shape.labels, shape.ndslice.sizes)}
|