torchmonarch-nightly 2025.6.30__cp310-cp310-manylinux2014_x86_64.whl → 2025.7.25__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +13 -9
- monarch/_rust_bindings.so +0 -0
- monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
- monarch/_src/actor/actor_mesh.py +874 -0
- monarch/{allocator.py → _src/actor/allocator.py} +26 -17
- monarch/_src/actor/bootstrap_main.py +73 -0
- monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
- monarch/_src/actor/code_sync/auto_reload.py +223 -0
- monarch/_src/actor/debugger.py +565 -0
- monarch/_src/actor/endpoint.py +270 -0
- monarch/_src/actor/event_loop.py +97 -0
- monarch/_src/actor/future.py +100 -0
- monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
- monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
- monarch/_src/actor/proc_mesh.py +500 -0
- monarch/_src/actor/sync_state.py +18 -0
- monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
- monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
- monarch/_src/actor/tensor_engine_shim.py +56 -0
- monarch/_src/tensor_engine/rdma.py +180 -0
- monarch/_testing.py +3 -2
- monarch/actor/__init__.py +51 -0
- monarch/actor_mesh.py +6 -752
- monarch/bootstrap_main.py +8 -47
- monarch/common/client.py +1 -1
- monarch/common/controller_api.py +2 -1
- monarch/common/device_mesh.py +12 -2
- monarch/common/messages.py +12 -1
- monarch/common/recording.py +4 -3
- monarch/common/remote.py +135 -52
- monarch/common/tensor.py +2 -1
- monarch/controller/backend.py +2 -2
- monarch/controller/controller.py +2 -1
- monarch/controller/rust_backend/controller.py +2 -1
- monarch/fetch.py +3 -5
- monarch/mesh_controller.py +201 -139
- monarch/monarch_controller +0 -0
- monarch/opaque_module.py +4 -6
- monarch/opaque_object.py +3 -3
- monarch/proc_mesh.py +6 -309
- monarch/python_local_mesh.py +1 -1
- monarch/rust_backend_mesh.py +2 -1
- monarch/rust_local_mesh.py +4 -2
- monarch/sim_mesh.py +10 -19
- monarch/simulator/command_history.py +1 -1
- monarch/simulator/interface.py +2 -1
- monarch/simulator/mock_controller.py +1 -1
- monarch/simulator/simulator.py +1 -1
- monarch/tensor_engine/__init__.py +23 -0
- monarch/tensor_worker_main.py +3 -1
- monarch/tools/cli.py +3 -1
- monarch/tools/commands.py +95 -35
- monarch/tools/mesh_spec.py +55 -0
- monarch/tools/utils.py +38 -0
- monarch/worker/worker.py +1 -1
- monarch/world_mesh.py +2 -1
- monarch_supervisor/python_executable.py +6 -3
- tests/error_test_binary.py +75 -9
- tests/test_actor_error.py +370 -21
- tests/test_alloc.py +1 -1
- tests/test_allocator.py +373 -17
- tests/test_controller.py +2 -0
- tests/test_debugger.py +416 -0
- tests/test_env_before_cuda.py +162 -0
- tests/test_python_actors.py +184 -332
- tests/test_rdma.py +198 -0
- tests/test_remote_functions.py +40 -12
- tests/test_rust_backend.py +7 -5
- tests/test_sim_backend.py +1 -4
- tests/test_tensor_engine.py +55 -1
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
- torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
- monarch/_monarch/hyperactor/__init__.py +0 -58
- monarch/_monarch/worker/debugger.py +0 -117
- monarch/_monarch/worker/logging.py +0 -107
- monarch/debugger.py +0 -379
- monarch/future.py +0 -76
- monarch/rdma.py +0 -162
- torchmonarch_nightly-2025.6.30.dist-info/entry_points.txt +0 -3
- /monarch/{_monarch/worker → _src}/__init__.py +0 -0
- /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
- /monarch/{common → _src/actor}/shape.py +0 -0
- /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,270 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
|
9
|
+
import functools
|
10
|
+
from abc import ABC, abstractmethod
|
11
|
+
from operator import mul
|
12
|
+
from typing import (
|
13
|
+
Any,
|
14
|
+
AsyncGenerator,
|
15
|
+
Awaitable,
|
16
|
+
Callable,
|
17
|
+
cast,
|
18
|
+
Concatenate,
|
19
|
+
Dict,
|
20
|
+
Generic,
|
21
|
+
List,
|
22
|
+
Literal,
|
23
|
+
Optional,
|
24
|
+
overload,
|
25
|
+
ParamSpec,
|
26
|
+
Sequence,
|
27
|
+
Tuple,
|
28
|
+
TYPE_CHECKING,
|
29
|
+
TypeVar,
|
30
|
+
)
|
31
|
+
|
32
|
+
from monarch._src.actor.future import Future
|
33
|
+
from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
|
34
|
+
|
35
|
+
if TYPE_CHECKING:
|
36
|
+
from monarch._src.actor.actor_mesh import (
|
37
|
+
HyPortReceiver,
|
38
|
+
OncePortReceiver,
|
39
|
+
Port,
|
40
|
+
PortTuple,
|
41
|
+
ValueMesh,
|
42
|
+
)
|
43
|
+
|
44
|
+
P = ParamSpec("P")
|
45
|
+
R = TypeVar("R")
|
46
|
+
|
47
|
+
Selection = Literal["all", "choose"] | int
|
48
|
+
|
49
|
+
|
50
|
+
class Extent:
    """
    Labeled multi-dimensional size: ``labels[i]`` names the dimension whose
    length is ``sizes[i]``.

    Both sequences are stored exactly as passed in (no copy, no validation).
    """

    def __init__(self, labels: Sequence[str], sizes: Sequence[int]) -> None:
        self.labels = labels
        self.sizes = sizes

    @property
    def nelements(self) -> int:
        """Return the product of all sizes (1 when there are no dimensions)."""
        return functools.reduce(mul, self.sizes, 1)

    def __str__(self) -> str:
        """Return the ``{label: size}`` mapping rendered as a dict string."""
        # zip() stops at the shorter sequence, so any surplus labels or sizes
        # are silently left out of the rendering.
        return str(dict(zip(self.labels, self.sizes)))

    def __repr__(self) -> str:
        """Return an unambiguous ``Extent(labels=..., sizes=...)`` form for debugging."""
        return f"{type(self).__name__}(labels={self.labels!r}, sizes={self.sizes!r})"
|
61
|
+
|
62
|
+
|
63
|
+
Propagator = Any
|
64
|
+
|
65
|
+
|
66
|
+
class Endpoint(ABC, Generic[P, R]):
    """
    Abstract handle for an endpoint taking ``P`` arguments and producing ``R``.

    Subclasses supply the primitives (``_send``, ``_port``, ``_call_name``);
    the "adverb" methods defined here (``choose``, ``call_one``, ``call``,
    ``stream``, ``broadcast``) are built on top of those primitives.
    """

    def __init__(self, propagator: Propagator) -> None:
        # Stored as given. _propagate() treats None and the strings "cached",
        # "inspect" and "mocked" specially; anything else goes to fake_call.
        self._propagator_arg = propagator
        # Created on first use by _propagate() for the cached mode.
        self._cache: Optional[dict] = None

    @abstractmethod
    def _send(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        port: "Optional[Port]" = None,
        selection: Selection = "all",
    ) -> Extent:
        """
        Implements sending a message to the endpoint. The return value of the endpoint will
        be sent to port if provided. If port is not provided, the return will be dropped,
        and any exception will cause the actor to fail.

        The return value is the (multi-dimension) size of the actors that were sent a message.
        For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
        this will be the size of the currently active proc_mesh.
        """
        pass

    @abstractmethod
    def _port(self, once: bool = False) -> "PortTuple[R]":
        """Create the port pair used to receive this endpoint's responses."""
        pass

    @abstractmethod
    def _call_name(self) -> Any:
        """
        Something to use in InputChecker to represent calling this thingy.
        """
        pass

    def _supervise(self, r: "HyPortReceiver | OncePortReceiver") -> Any:
        """Return the receiver unchanged; this base implementation adds nothing."""
        return r

    # The following are all 'adverbs', i.e. different ways to handle the
    # return values of this endpoint. Adverbs should only ever take the
    # *args, **kwargs of the original call. If we want to add syntax sugar
    # for something that needs additional arguments, it should be implemented
    # as a function independent of endpoint, like `send` and `Accumulator`.
    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Load balanced sends a message to one chosen actor and awaits a result.

        Load balanced RPC-style entrypoint for request/response messaging.
        """
        # actor_mesh is only imported under TYPE_CHECKING at module level, so
        # its runtime helpers are imported at call time.
        from monarch._src.actor.actor_mesh import port

        p, r = port(self, once=True)
        # pyre-ignore
        self._send(args, kwargs, port=p, selection="choose")
        return r.recv()

    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Send to an endpoint that must consist of exactly one actor.

        Raises:
            ValueError: if the extent reported by ``_send`` does not contain
                exactly one element.
        """
        from monarch._src.actor.actor_mesh import port

        p, r = port(self, once=True)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p, selection="choose")
        # NOTE: the extent is only known from _send's return value, so by the
        # time this check runs _send has already been invoked.
        if extent.nelements != 1:
            raise ValueError(
                f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
            )
        return r.recv()

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
        """
        Send with the default selection and gather every response.

        Returns a Future resolving to a ``ValueMesh`` whose values are placed
        at the rank reported alongside each response.
        """
        from monarch._src.actor.actor_mesh import ranked_port, ValueMesh

        p, r = ranked_port(self)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p)

        async def process() -> "ValueMesh[R]":
            from monarch._rust_bindings.monarch_hyperactor.shape import Shape
            from monarch._src.actor.shape import NDSlice

            # Pre-sized because responses are stored by rank, not arrival order.
            results: List[R] = [None] * extent.nelements  # pyre-fixme[9]
            for _ in range(extent.nelements):
                rank, value = await r.recv()
                results[rank] = value
            call_shape = Shape(
                extent.labels,
                NDSlice.new_row_major(extent.sizes),
            )
            return ValueMesh(call_shape, results)

        return Future(impl=process, requires_loop=False)

    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, None]:
        """
        Broadcasts to all actors and yields their responses as a stream / generator.

        This enables processing results from multiple actors incrementally as
        they become available. Returns an async generator of response values.
        """
        from monarch._src.actor.actor_mesh import port

        p, r = port(self)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p)
        # One response is awaited per element of the extent.
        for _ in range(extent.nelements):
            yield await r.recv()

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """
        Fire-and-forget broadcast to all actors without waiting for actors to
        acknowledge receipt.

        In other words, the return of this method does not guarantee the
        delivery of the message.
        """
        from monarch._src.actor.actor_mesh import send

        # pyre-ignore
        send(self, args, kwargs)

    def _propagate(self, args, kwargs, fake_args, fake_kwargs):
        """
        Dispatch on the configured propagator.

        * ``None`` or ``"cached"``: delegate to ``_cached_propagation``.
        * ``"inspect"``: return ``None``.
        * ``"mocked"``: raise ``NotImplementedError``.
        * anything else: call it through ``fake_call`` with the fake arguments.
        """
        if self._propagator_arg is None or self._propagator_arg == "cached":
            if self._cache is None:
                self._cache = {}
            # NOTE(review): `_resolvable` is not defined in this class;
            # presumably subclasses provide it — confirm.
            return _cached_propagation(self._cache, self._resolvable, args, kwargs)
        elif self._propagator_arg == "inspect":
            return None
        elif self._propagator_arg == "mocked":
            raise NotImplementedError("mocked propagation")
        else:
            return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)

    def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
        """Like ``_propagate`` but returns ``None`` when no propagator was given."""
        if self._propagator_arg is None:
            return  # no propagator provided, so we just assume no mutations
        return self._propagate(args, kwargs, fake_args, fake_kwargs)

    def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
        """
        Like ``_propagate`` but requires an explicit callable propagator.

        Raises:
            ValueError: if the configured propagator is not callable.
        """
        if not callable(self._propagator_arg):
            raise ValueError("Must specify explicit callable for pipe")
        return self._propagate(args, kwargs, fake_args, fake_kwargs)
|
206
|
+
|
207
|
+
|
208
|
+
class EndpointProperty(Generic[P, R]):
    """
    Descriptor that records a method together with its propagator.

    The two ``__init__`` overloads only differ in whether the wrapped method
    returns ``Awaitable[R]`` or ``R`` directly.
    """

    @overload
    def __init__(
        self,
        method: Callable[Concatenate[Any, P], Awaitable[R]],
        propagator: Propagator,
    ) -> None: ...

    @overload
    def __init__(
        self,
        method: Callable[Concatenate[Any, P], R],
        propagator: Propagator,
    ) -> None: ...

    def __init__(self, method: Any, propagator: Propagator) -> None:
        self._propagator = propagator
        self._method = method

    def __get__(self, instance, owner) -> Endpoint[P, R]:
        # The declared return type is not truthful: the object handed back is
        # this EndpointProperty itself. It is kept so that the attribute can be
        # recognised as having been defined as an endpoint, and so the method
        # can be looked up from it.
        as_endpoint = cast(Endpoint[P, R], self)
        return as_endpoint
|
230
|
+
|
231
|
+
|
232
|
+
# A plain Callable cannot be used for this: it would not permit type
# arguments to appear in the return value.
class EndpointIfy:
    """Callable signature mapping a function to ``Endpoint[P, R]`` for type checkers."""

    @overload
    def __call__(self, function: Callable[P, Awaitable[R]]) -> Endpoint[P, R]: ...

    @overload
    def __call__(self, function: Callable[P, R]) -> Endpoint[P, R]: ...

    def __call__(self, function: Any):
        # No runtime behaviour: the implementation always returns None.
        return None
|
242
|
+
|
243
|
+
|
244
|
+
@overload
def endpoint(
    method: Callable[Concatenate[Any, P], Awaitable[R]],
    *,
    propagate: Propagator = None,
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    method: Callable[Concatenate[Any, P], R],
    *,
    propagate: Propagator = None,
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    *,
    propagate: Propagator = None,
) -> EndpointIfy: ...


def endpoint(method=None, *, propagate=None):
    """
    Wrap ``method`` in an ``EndpointProperty``.

    Usable directly (``endpoint(method)``) or with arguments
    (``endpoint(propagate=...)``); in the latter form a partial of this
    function is returned that is still waiting for the method.
    """
    if method is not None:
        return EndpointProperty(method, propagator=propagate)
    # Called with keyword arguments only: defer until the method arrives.
    return functools.partial(endpoint, propagate=propagate)
|
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
"""
|
8
|
+
Module for managing the event loop used by Monarch Python actors.
|
9
|
+
This provides a way to create a Python-aware thread from Rust that runs the worker event loop.
|
10
|
+
"""
|
11
|
+
|
12
|
+
import asyncio
|
13
|
+
import logging
|
14
|
+
import threading
|
15
|
+
from typing import Optional
|
16
|
+
|
17
|
+
from libfb.py.pyre import none_throws
|
18
|
+
|
19
|
+
logger = logging.getLogger(__name__)
|
20
|
+
|
21
|
+
_event_loop: Optional[asyncio.AbstractEventLoop] = None
|
22
|
+
_lock = threading.Lock()
|
23
|
+
_ready = threading.Event()
|
24
|
+
|
25
|
+
|
26
|
+
def _initialize_event_loop() -> None:
    """
    Internal function to initialize the event loop.

    Starts a daemon thread that creates a new asyncio event loop, publishes it
    in the module-level ``_event_loop`` and runs it forever. This call blocks
    until the thread signals ``_ready``, which it does either after publishing
    the loop or after failing. Does nothing if a loop is already published.

    Raises:
        RuntimeError: if the thread signalled readiness without publishing a loop.
    """
    global _event_loop
    if _event_loop is not None:
        return

    # A previous failed attempt leaves _ready set. Clear it so that the wait()
    # below blocks until *this* attempt's thread has reported in, instead of
    # returning immediately and raising a spurious RuntimeError.
    _ready.clear()

    # Create a new thread that will host our event loop
    def event_loop_thread():
        """Target function for the event loop thread."""
        global _event_loop
        try:
            # Create a new event loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            # Publish the loop before signalling, so the waiter sees it.
            _event_loop = loop
            _ready.set()

            logger.debug(
                "Python worker event loop thread started: %s",
                threading.current_thread().name,
            )
            try:
                # Run the event loop forever
                loop.run_forever()
            finally:
                # Clean up when the loop stops
                logger.debug("Python worker event loop stopped, closing...")
                loop.close()
        except Exception as e:
            logger.error("Error in event loop thread: %s", e)
            # Unblock the waiter even on failure; it then finds no loop.
            _ready.set()
            raise

    # Create and start the thread
    threading.Thread(
        target=event_loop_thread,
        name="asyncio-event-loop",
        daemon=True,  # Make it a daemon thread so it doesn't block program exit
    ).start()

    _ready.wait()  # Wait for the event loop to be ready

    if _event_loop is None:
        raise RuntimeError("Failed to initialize event loop")
|
73
|
+
|
74
|
+
|
75
|
+
def get_event_loop() -> asyncio.AbstractEventLoop:
    """
    Return the Python worker event loop, starting one if none exists yet.

    Expected to be called from rust code.
    """
    if _event_loop is not None:
        return none_throws(_event_loop)
    # Initialization is serialized by _lock; _initialize_event_loop() itself
    # re-checks whether a loop has been published in the meantime.
    with _lock:
        _initialize_event_loop()
    return none_throws(_event_loop)
|
87
|
+
|
88
|
+
|
89
|
+
def stop_event_loop() -> None:
    """
    Ask the worker event loop to stop; a no-op when none has been created.
    """
    if _event_loop is None:
        return
    logger.debug("Stopping event loop...")
    target_loop = none_throws(_event_loop)
    # stop() is scheduled via call_soon_threadsafe rather than called directly.
    target_loop.call_soon_threadsafe(target_loop.stop)
|
@@ -0,0 +1,100 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import traceback
|
9
|
+
from functools import partial
|
10
|
+
from typing import Generator, Generic, Optional, TypeVar
|
11
|
+
|
12
|
+
R = TypeVar("R")
|
13
|
+
|
14
|
+
|
15
|
+
async def _aincomplete(impl, self):
    """
    Await ``impl()`` and record the outcome on ``self``.

    On success the value is passed to ``self._set_result`` and that call's
    return value is returned. On ``Exception`` the error is passed to
    ``self._set_exception`` and then re-raised.
    """
    try:
        outcome = await impl()
        return self._set_result(outcome)
    except Exception as err:
        self._set_exception(err)
        raise
|
21
|
+
|
22
|
+
|
23
|
+
# Future is our generic mechanism for providing both a synchronous and asynchronous API for
|
24
|
+
# Monarch Future objects.
|
25
|
+
|
26
|
+
# We treat all code as running in one of two contexts: synchronous (asyncio._get_running_loop() is None)
|
27
|
+
# or asynchronous.
|
28
|
+
|
29
|
+
# Inside of asynchronous code, clients of our API must use `await` to wait for monarch Futures to prevent
|
30
|
+
# blocking the surrounding event loop.
|
31
|
+
|
32
|
+
# In synchronous code users must call get() because the call is coming from a non-async function so
|
33
|
+
# await is not allowed.
|
34
|
+
|
35
|
+
# [avoiding async code duplication]
|
36
|
+
# Because we allow for two modes, it is tempting as developers of Monarch to start to write two copies of
|
37
|
+
# code for each mode. However, this results in a lot of confusing code duplication.
|
38
|
+
# To avoid this, we utilize the fact that synchronous code is allowed to start/complete an asyncio event loop
|
39
|
+
# via asyncio.run in order to complete the `get()` operation. So we can just write the async version and use
|
40
|
+
# it to implement the synchronous version.
|
41
|
+
|
42
|
+
# However, starting and running an event loop is somewhat expensive. For simple messages, using an event loop
|
43
|
+
# is about 4x slower than just directly waiting on the tokio result. To avoid this slow down we perform an
|
44
|
+
# optimization. For any case where the `impl` coroutine of a future calls `await` only on PythonFuture
|
45
|
+
# (a Tokio future returning a Python value) objects, we pass requires_loop=False to the Future. In this mode,
|
46
|
+
# the future will just run the coroutine manually, and the PythonFuture object will recognize it is being awaited
|
47
|
+
# without an event loop (search [avoiding code duplication]) and simply do a blocking wait. By avoiding the event
|
48
|
+
# loop machinery, this gives it the same throughput as if we ran it synchronously.
|
49
|
+
|
50
|
+
|
51
|
+
class Future(Generic[R]):
    """
    Result handle usable both synchronously (``get()``) and with ``await``.

    ``impl`` is a zero-argument callable returning an awaitable. Once it has
    completed, ``_aget`` is swapped for a trivial coroutine function that
    replays the stored result or exception.
    """

    def __init__(self, *, impl, requires_loop=True):
        self._requires_loop = requires_loop
        # Called as self._aget(self); initially this runs impl and records the outcome.
        self._aget = partial(_aincomplete, impl)

    def get(self, timeout: Optional[float] = None) -> R:
        """
        Block until the result is available and return it.

        Raises:
            RuntimeError: when called while an event loop is running, or when
                ``requires_loop`` is False but the coroutine suspends.
        """
        if asyncio._get_running_loop() is not None:
            raise RuntimeError("get() cannot be called from within an async context")
        if timeout is not None:
            return asyncio.run(asyncio.wait_for(self._aget(self), timeout))
        if self._requires_loop:
            return asyncio.run(self._aget(self))

        # No event loop needed: step the coroutine by hand. It is expected to
        # run to completion in one step, which surfaces as StopIteration.
        coro = self._aget(self)
        try:
            next(coro.__await__())
        except StopIteration as finished:
            return finished.value
        # The coroutine suspended instead of finishing.
        tb_str = "".join(traceback.format_stack(coro.cr_frame))
        raise RuntimeError(
            f"a coroutine paused with a future with requires_loop=False cannot block on a python asyncio.Future. Use requires_loop=True.\n{tb_str}"
        )

    def __await__(self) -> Generator[R, None, R]:
        return self._aget(self).__await__()

    def _set_result(self, result):
        """Store ``result`` for later ``get()``/``await`` calls and return it."""

        async def replay_result(self):
            return result

        self._aget = replay_result
        return result

    def _set_exception(self, e):
        """Store ``e`` so later ``get()``/``await`` calls raise it."""

        async def replay_exception(self):
            raise e

        self._aget = replay_exception

    # compatibility with old tensor engine Future objects
    # hopefully we do not need done(), add_callback because
    # they are harder to implement right.
    def result(self, timeout: Optional[float] = None) -> R:
        """Alias for ``get(timeout)``."""
        return self.get(timeout)

    def exception(self, timeout: Optional[float] = None):
        """Return the ``Exception`` raised by ``get(timeout)``, or None if it succeeds."""
        try:
            self.get(timeout)
        except Exception as err:
            return err
        return None
|
@@ -4,6 +4,7 @@
|
|
4
4
|
# This source code is licensed under the BSD-style license found in the
|
5
5
|
# LICENSE file in the root directory of this source tree.
|
6
6
|
|
7
|
+
# pyre-unsafe
|
7
8
|
import bdb
|
8
9
|
import inspect
|
9
10
|
import io
|
@@ -15,9 +16,10 @@ from dataclasses import dataclass
|
|
15
16
|
from typing import Dict, TYPE_CHECKING
|
16
17
|
|
17
18
|
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
|
19
|
+
from monarch._src.actor.sync_state import fake_sync_state
|
18
20
|
|
19
21
|
if TYPE_CHECKING:
|
20
|
-
from monarch.debugger import DebugClient
|
22
|
+
from monarch._src.actor.debugger import DebugClient
|
21
23
|
|
22
24
|
|
23
25
|
@dataclass
|
@@ -45,35 +47,38 @@ class PdbWrapper(pdb.Pdb):
|
|
45
47
|
super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self))
|
46
48
|
self._first = True
|
47
49
|
|
48
|
-
def
|
49
|
-
|
50
|
-
if self._first:
|
51
|
-
self._first = False
|
52
|
-
# when we enter the debugger, we want to present the user's stack frame
|
53
|
-
# not the nested one inside session.run. This means that the local
|
54
|
-
# variables are what gets printed, etc. To do this
|
55
|
-
# we first execute up 2 to get to that frame.
|
56
|
-
self.do_up(2)
|
57
|
-
return r
|
58
|
-
|
59
|
-
def set_continue(self) -> None:
|
60
|
-
r = super().set_continue()
|
61
|
-
if not self.breaks:
|
62
|
-
# no more breakpoints so this debugger will not
|
63
|
-
# be used again, and we detach from the controller io.
|
64
|
-
self.client_ref.debugger_session_end.call_one(self.rank).get()
|
65
|
-
# break cycle with itself before we exit
|
66
|
-
self.stdin = sys.stdin
|
67
|
-
self.stdout = sys.stdout
|
68
|
-
return r
|
69
|
-
|
70
|
-
def set_trace(self):
|
71
|
-
self.client_ref.debugger_session_start.call_one(
|
50
|
+
def set_trace(self, frame):
|
51
|
+
self.client_ref.debugger_session_start.broadcast(
|
72
52
|
self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id
|
73
|
-
)
|
53
|
+
)
|
74
54
|
if self.header:
|
75
55
|
self.message(self.header)
|
76
|
-
super().set_trace()
|
56
|
+
super().set_trace(frame)
|
57
|
+
|
58
|
+
def do_clear(self, arg):
    """Handle ``clear``; with no argument, clear every breakpoint without asking."""
    if arg:
        super().do_clear(arg)
        return
    # A bare `clear` would ask the user for confirmation through the
    # `input` function. That bypasses our ReadWrapper and makes the client
    # hang, so all breakpoints are cleared directly, without confirmation.
    super().clear_all_breaks()
|
68
|
+
|
69
|
+
def end_debug_session(self):
    """Tell the debug client the session is over, then restore the standard streams."""
    self.client_ref.debugger_session_end.broadcast(self.rank)
    # After the debug client actor has been told the session ended, no
    # further requests may be sent for it; pointing stdin/stdout back at the
    # process streams prevents that.
    self.stdin, self.stdout = sys.stdin, sys.stdout
|
76
|
+
|
77
|
+
def post_mortem(self, exc_tb):
    """Start an interactive session on the traceback ``exc_tb``."""
    self._first = False
    # Same sequence as the builtin pdb.post_mortem(): reset the debugger
    # state, then interact with no frame and the given traceback.
    self.reset()
    self.interaction(None, exc_tb)
|
77
82
|
|
78
83
|
|
79
84
|
class ReadWrapper(io.RawIOBase):
|
@@ -81,16 +86,19 @@ class ReadWrapper(io.RawIOBase):
|
|
81
86
|
self.session = session
|
82
87
|
|
83
88
|
def readinto(self, b):
|
84
|
-
|
85
|
-
self.session.
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
89
|
+
with fake_sync_state():
|
90
|
+
response = self.session.client_ref.debugger_read.call_one(
|
91
|
+
self.session.rank, len(b)
|
92
|
+
).get()
|
93
|
+
if response == "detach":
|
94
|
+
# this gets injected by the worker event loop to
|
95
|
+
# get the worker thread to exit on an Exit command.
|
96
|
+
raise bdb.BdbQuit
|
97
|
+
assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(
|
98
|
+
b
|
99
|
+
)
|
100
|
+
b[: len(response.payload)] = response.payload
|
101
|
+
return len(response.payload)
|
94
102
|
|
95
103
|
def readable(self) -> bool:
|
96
104
|
return True
|
@@ -115,21 +123,14 @@ class WriteWrapper:
|
|
115
123
|
function = f"{inspect.getmodulename(self.session.curframe.f_code.co_filename)}.{self.session.curframe.f_code.co_name}"
|
116
124
|
# pyre-ignore
|
117
125
|
lineno = self.session.curframe.f_lineno
|
118
|
-
self.session.client_ref.debugger_write.
|
126
|
+
self.session.client_ref.debugger_write.broadcast(
|
119
127
|
self.session.rank,
|
120
128
|
DebuggerWrite(
|
121
129
|
s.encode(),
|
122
130
|
function,
|
123
131
|
lineno,
|
124
132
|
),
|
125
|
-
)
|
133
|
+
)
|
126
134
|
|
127
135
|
def flush(self):
|
128
136
|
pass
|
129
|
-
|
130
|
-
|
131
|
-
def remote_breakpointhook(
|
132
|
-
rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient"
|
133
|
-
):
|
134
|
-
ds = PdbWrapper(rank, coords, actor_id, client_ref)
|
135
|
-
ds.set_trace()
|
@@ -6,10 +6,16 @@
|
|
6
6
|
|
7
7
|
import io
|
8
8
|
import pickle
|
9
|
+
from contextlib import contextmanager, ExitStack
|
9
10
|
from typing import Any, Callable, Iterable, List, Tuple
|
10
11
|
|
11
12
|
import cloudpickle
|
12
13
|
|
14
|
+
try:
|
15
|
+
import torch # @manual
|
16
|
+
except ImportError:
|
17
|
+
torch = None
|
18
|
+
|
13
19
|
|
14
20
|
class _Pickler(cloudpickle.Pickler):
|
15
21
|
def __init__(self, filter):
|
@@ -44,5 +50,23 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]:
|
|
44
50
|
|
45
51
|
|
46
52
|
def unflatten(data: bytes, values: Iterable[Any]) -> Any:
|
47
|
-
|
48
|
-
|
53
|
+
with ExitStack() as stack:
|
54
|
+
if torch is not None:
|
55
|
+
stack.enter_context(load_tensors_on_cpu())
|
56
|
+
stack.enter_context(torch.utils._python_dispatch._disable_current_modes())
|
57
|
+
up = _Unpickler(data, values)
|
58
|
+
return up.load()
|
59
|
+
|
60
|
+
|
61
|
+
@contextmanager
def load_tensors_on_cpu():
    """
    Temporarily replace ``torch.storage._load_from_bytes`` with a CPU-mapping loader.

    Inside the context, storages loaded through that hook go through
    ``torch.load(..., map_location="cpu", weights_only=False)``. The previous
    hook is restored on exit, including when the body raises.
    """
    previous_loader = torch.storage._load_from_bytes

    def _load_on_cpu(b):
        return torch.load(io.BytesIO(b), map_location="cpu", weights_only=False)

    try:
        torch.storage._load_from_bytes = _load_on_cpu
        yield
    finally:
        torch.storage._load_from_bytes = previous_loader
|