torchmonarch-nightly 2025.7.1__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.26__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +878 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +303 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +508 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +59 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +53 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +21 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/gradient/_gradient_generator.so +0 -0
  37. monarch/mesh_controller.py +263 -139
  38. monarch/monarch_controller +0 -0
  39. monarch/opaque_module.py +4 -6
  40. monarch/opaque_object.py +3 -3
  41. monarch/proc_mesh.py +6 -309
  42. monarch/python_local_mesh.py +1 -1
  43. monarch/rust_backend_mesh.py +2 -1
  44. monarch/rust_local_mesh.py +4 -2
  45. monarch/sim_mesh.py +10 -19
  46. monarch/simulator/command_history.py +1 -1
  47. monarch/simulator/interface.py +2 -1
  48. monarch/simulator/mock_controller.py +1 -1
  49. monarch/simulator/simulator.py +1 -1
  50. monarch/tensor_engine/__init__.py +23 -0
  51. monarch/tensor_worker_main.py +3 -1
  52. monarch/tools/cli.py +3 -1
  53. monarch/tools/commands.py +129 -47
  54. monarch/tools/components/hyperactor.py +5 -3
  55. monarch/tools/config/__init__.py +18 -1
  56. monarch/tools/config/defaults.py +2 -2
  57. monarch/tools/mesh_spec.py +59 -1
  58. monarch/tools/utils.py +38 -0
  59. monarch/worker/worker.py +1 -1
  60. monarch/world_mesh.py +2 -1
  61. monarch_supervisor/python_executable.py +6 -3
  62. tests/error_test_binary.py +48 -10
  63. tests/test_actor_error.py +370 -21
  64. tests/test_alloc.py +1 -1
  65. tests/test_allocator.py +369 -17
  66. tests/test_controller.py +2 -0
  67. tests/test_debugger.py +416 -0
  68. tests/test_env_before_cuda.py +161 -0
  69. tests/test_python_actors.py +184 -333
  70. tests/test_rdma.py +198 -0
  71. tests/test_remote_functions.py +40 -12
  72. tests/test_rust_backend.py +7 -5
  73. tests/test_sim_backend.py +1 -4
  74. tests/test_tensor_engine.py +81 -1
  75. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/METADATA +39 -1
  76. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/RECORD +84 -72
  77. torchmonarch_nightly-2025.7.26.dist-info/entry_points.txt +3 -0
  78. monarch/_monarch/hyperactor/__init__.py +0 -58
  79. monarch/_monarch/worker/debugger.py +0 -117
  80. monarch/_monarch/worker/logging.py +0 -107
  81. monarch/debugger.py +0 -379
  82. monarch/future.py +0 -76
  83. monarch/rdma.py +0 -162
  84. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  85. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  86. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  87. /monarch/{common → _src/actor}/shape.py +0 -0
  88. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  89. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/WHEEL +0 -0
  90. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/licenses/LICENSE +0 -0
  91. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,303 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import functools
10
+ from abc import ABC, abstractmethod
11
+ from operator import mul
12
+ from typing import (
13
+ Any,
14
+ AsyncGenerator,
15
+ Awaitable,
16
+ Callable,
17
+ cast,
18
+ Concatenate,
19
+ Dict,
20
+ Generic,
21
+ List,
22
+ Literal,
23
+ Optional,
24
+ overload,
25
+ ParamSpec,
26
+ Sequence,
27
+ Tuple,
28
+ TYPE_CHECKING,
29
+ TypeVar,
30
+ )
31
+
32
+ from monarch._src.actor.future import Future
33
+ from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
34
+
35
+ if TYPE_CHECKING:
36
+ from monarch._src.actor.actor_mesh import (
37
+ ActorMeshRef,
38
+ HyPortReceiver,
39
+ OncePortReceiver,
40
+ Port,
41
+ PortTuple,
42
+ ValueMesh,
43
+ )
44
+
45
+ P = ParamSpec("P")
46
+ R = TypeVar("R")
47
+
48
+ Selection = Literal["all", "choose"] | int
49
+
50
+
51
class Extent:
    """A named multi-dimensional size: one label paired with each dimension size."""

    def __init__(self, labels: Sequence[str], sizes: Sequence[int]) -> None:
        self.sizes = sizes
        self.labels = labels

    @property
    def nelements(self) -> int:
        """Product of all dimension sizes (1 when there are no dimensions)."""
        total = 1
        for dim in self.sizes:
            total = total * dim
        return total

    def __str__(self) -> str:
        """Render as the ``str`` of a ``{label: size}`` dictionary."""
        pairs = zip(self.labels, self.sizes)
        return str({label: size for label, size in pairs})
62
+
63
+
64
+ Propagator = Any
65
+
66
+
67
class Endpoint(ABC, Generic[P, R]):
    """
    Abstract base for something that can be messaged with arguments of
    signature ``P`` and replies with values of type ``R``.

    Subclasses provide the transport (``_send``, ``_port``, ``_call_name``,
    ``_rref``). This class builds the public "adverbs" (``choose``,
    ``call_one``, ``call``, ``stream``, ``broadcast``, ``rref``) and the
    propagation helpers on top of them.
    """

    def __init__(self, propagator: Propagator) -> None:
        # Selects the branch taken in _propagate: None / "cached", "inspect",
        # "mocked", or any other value, which is passed to fake_call.
        self._propagator_arg = propagator
        # Created lazily by _propagate the first time cached propagation runs.
        self._cache: Optional[dict] = None

    @abstractmethod
    def _send(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        port: "Optional[Port]" = None,
        selection: Selection = "all",
    ) -> Extent:
        """
        Implements sending a message to the endpoint. The return value of the endpoint will
        be sent to port if provided. If port is not provided, the return will be dropped,
        and any exception will cause the actor to fail.

        The return value is the (multi-dimension) size of the actors that were sent a message.
        For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
        this will be the size of the currently active proc_mesh.
        """
        pass

    @abstractmethod
    def _port(self, once: bool = False) -> "PortTuple[R]":
        """
        Return a ``PortTuple`` for this endpoint.

        NOTE(review): presumably this backs the ``port`` / ``ranked_port``
        helpers imported from ``actor_mesh`` below, with ``once`` requesting a
        single-use port -- confirm against actor_mesh.
        """
        pass

    @abstractmethod
    def _call_name(self) -> Any:
        """
        Something to use in InputChecker to represent calling this thingy.
        """
        pass

    def _supervise(self, r: "HyPortReceiver | OncePortReceiver") -> Any:
        """
        Return the receiver unchanged.

        NOTE(review): looks like an override point for subclasses that want to
        wrap the receiver -- confirm against the subclasses.
        """
        return r

    # the following are all 'adverbs' or different ways to handle the
    # return values of this endpoint. Adverbs should only ever take *args, **kwargs
    # of the original call. If we want to add syntax sugar for something that needs additional
    # arguments, it should be implemented as function independent of endpoint like `send`
    # and `Accumulator`
    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Load balanced sends a message to one chosen actor and awaits a result.

        Load balanced RPC-style entrypoint for request/response messaging.
        """
        # Imported lazily; actor_mesh is only imported under TYPE_CHECKING at
        # module level.
        from monarch._src.actor.actor_mesh import port

        p, r = port(self, once=True)
        # pyre-ignore
        self._send(args, kwargs, port=p, selection="choose")
        return r.recv()

    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Send to an endpoint that is expected to contain exactly one actor and
        return a Future of its reply.

        Raises:
            ValueError: if the extent reported by ``_send`` does not contain
                exactly one element. The check runs after ``_send`` has
                already been called.
        """
        from monarch._src.actor.actor_mesh import port

        p, r = port(self, once=True)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p, selection="choose")
        if extent.nelements != 1:
            raise ValueError(
                f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
            )
        return r.recv()

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
        """
        Send with the default selection and return a Future that resolves to a
        ``ValueMesh`` holding one reply per element of the returned extent.
        """
        from monarch._src.actor.actor_mesh import ranked_port, ValueMesh

        p, r = ranked_port(self)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p)

        async def process() -> "ValueMesh[R]":
            from monarch._rust_bindings.monarch_hyperactor.shape import Shape
            from monarch._src.actor.shape import NDSlice

            # Replies arrive as (rank, value) pairs in any order; slot each
            # value into its rank's position.
            results: List[R] = [None] * extent.nelements  # pyre-fixme[9]
            for _ in range(extent.nelements):
                rank, value = await r.recv()
                results[rank] = value
            call_shape = Shape(
                extent.labels,
                NDSlice.new_row_major(extent.sizes),
            )
            return ValueMesh(call_shape, results)

        # process() awaits only r.recv(), so no asyncio loop is requested.
        return Future(impl=process, requires_loop=False)

    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, None]:
        """
        Broadcasts to all actors and yields their responses as a stream / generator.

        This enables processing results from multiple actors incrementally as
        they become available. Returns an async generator of response values.
        """
        from monarch._src.actor.actor_mesh import port

        p, r = port(self)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p)
        # One reply is awaited per element of the extent that was messaged.
        for _ in range(extent.nelements):
            yield await r.recv()

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """
        Fire-and-forget broadcast to all actors without waiting for actors to
        acknowledge receipt.

        In other words, the return of this method does not guarantee the
        delivery of the message.
        """
        from monarch._src.actor.actor_mesh import send

        # pyre-ignore
        send(self, args, kwargs)

    @abstractmethod
    def _rref(self, args, kwargs) -> Any: ...

    def rref(self, *args: P.args, **kwargs: P.kwargs) -> R:
        """Forward the call arguments to ``_rref`` and return whatever it returns."""
        return self._rref(args, kwargs)

    def _propagate(self, args, kwargs, fake_args, fake_kwargs):
        """
        Dispatch on the propagator this endpoint was constructed with.

        - ``None`` or ``"cached"``: run ``_cached_propagation`` against
          ``self._resolvable``; raises ``NotImplementedError`` when that
          attribute is missing or ``None``.
        - ``"inspect"``: return ``None``.
        - ``"mocked"``: raise ``NotImplementedError``.
        - anything else: call it through ``fake_call`` with the fake arguments.
        """
        if self._propagator_arg is None or self._propagator_arg == "cached":
            if self._cache is None:
                self._cache = {}
            resolvable = getattr(self, "_resolvable", None)
            if resolvable is None:
                raise NotImplementedError(
                    "Cached propagation is not implemented for actor endpoints."
                )
            return _cached_propagation(self._cache, resolvable, args, kwargs)
        elif self._propagator_arg == "inspect":
            return None
        elif self._propagator_arg == "mocked":
            raise NotImplementedError("mocked propagation")
        else:
            return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)

    def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
        """Like ``_propagate``, but return ``None`` when no propagator was given."""
        if self._propagator_arg is None:
            return  # no propagator provided, so we just assume no mutations
        return self._propagate(args, kwargs, fake_args, fake_kwargs)

    def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
        """
        Like ``_propagate``, but only a callable propagator is accepted.

        Raises:
            ValueError: if the propagator is not callable.
        """
        if not callable(self._propagator_arg):
            raise ValueError("Must specify explicit callable for pipe")
        return self._propagate(args, kwargs, fake_args, fake_kwargs)
218
+
219
+
220
class EndpointProperty(Generic[P, R]):
    """
    Records a method together with the propagator it was declared with.

    The two ``__init__`` overloads exist only for the type checker: they
    accept a method that returns either ``Awaitable[R]`` or ``R``, so both
    kinds of method are typed as producing ``R``.
    """

    @overload
    def __init__(
        self,
        method: Callable[Concatenate[Any, P], Awaitable[R]],
        propagator: Propagator,
    ) -> None: ...

    @overload
    def __init__(
        self, method: Callable[Concatenate[Any, P], R], propagator: Propagator
    ) -> None: ...

    def __init__(self, method: Any, propagator: Propagator) -> None:
        self._method = method
        self._propagator = propagator

    def __get__(self, instance, owner) -> Endpoint[P, R]:
        # this is a total lie, but we have to actually
        # recognize this was defined as an endpoint,
        # and also lookup the method
        # At runtime the descriptor returns itself; the cast only changes the
        # type the checker sees.
        return cast(Endpoint[P, R], self)
242
+
243
+
244
class NotAnEndpoint:
    """
    Used as the dynamic value of functions on an ActorMeshRef that were not marked as endpoints.
    This is used both to give a better error message (since we cannot prevent the type system from thinking they are methods),
    and to provide the opportunity for someone to do endpoint(x.foo) on something that wasn't marked as an endpoint.
    """

    def __init__(self, ref: "ActorMeshRef", name: str):
        self._name = name
        self._ref = ref

    def __call__(self, *args, **kwargs) -> None:
        """Always raise ``RuntimeError`` naming the actor class and the method."""
        actor_class = self._ref._class
        method_name = self._name
        message = f"Actor {actor_class}.{method_name} is not annotated as an endpoint. To call it as one, add a @endpoint decorator to it, or directly wrap it in one as_endpoint(obj.method).call(...)"
        raise RuntimeError(message)
259
+
260
+
261
# This can't just be Callable because otherwise we are not
# allowed to use type arguments in the return value.
class EndpointIfy:
    """
    Typing-only description of a callable that takes a function and returns
    an ``Endpoint``.

    The overloads accept a function returning either ``Awaitable[R]`` or
    ``R``. The runtime ``__call__`` body does nothing and returns ``None``.
    """

    @overload
    def __call__(
        self, function: Callable[Concatenate[Any, P], Awaitable[R]]
    ) -> Endpoint[P, R]: ...
    @overload
    def __call__(
        self, function: Callable[Concatenate[Any, P], R]
    ) -> Endpoint[P, R]: ...

    def __call__(self, function: Any):
        pass
276
+
277
@overload
def endpoint(
    method: Callable[Concatenate[Any, P], Awaitable[R]],
    *,
    propagate: Propagator = None,
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    method: Callable[Concatenate[Any, P], R],
    *,
    propagate: Propagator = None,
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    *,
    propagate: Propagator = None,
) -> EndpointIfy: ...


def endpoint(method=None, *, propagate=None):
    """
    Wrap ``method`` in an ``EndpointProperty``.

    Works both as a bare decorator and as a decorator factory: when called
    without ``method`` it returns a partial of itself with ``propagate``
    bound, ready to receive the method.
    """
    if method is not None:
        return EndpointProperty(method, propagator=propagate)
    # Factory form: only the keyword option was supplied.
    return functools.partial(endpoint, propagate=propagate)
@@ -0,0 +1,97 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Module for managing the event loop used by Monarch Python actors.
9
+ This provides a way to create a Python-aware thread from Rust that runs the worker event loop.
10
+ """
11
+
12
+ import asyncio
13
+ import logging
14
+ import threading
15
+ from typing import Optional
16
+
17
+ from pyre_extensions import none_throws
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _event_loop: Optional[asyncio.AbstractEventLoop] = None
22
+ _lock = threading.Lock()
23
+ _ready = threading.Event()
24
+
25
+
26
def _initialize_event_loop() -> None:
    """
    Internal function to initialize the event loop.

    Does nothing if ``_event_loop`` is already set. Otherwise this starts a
    daemon thread that creates an asyncio event loop, publishes it in
    ``_event_loop`` and runs it forever. The caller blocks until that thread
    signals ``_ready``.

    ``get_event_loop`` calls this while holding ``_lock``.

    Raises:
        RuntimeError: if the thread signalled readiness without having
            published a loop (i.e. loop creation failed).
    """
    global _event_loop
    if _event_loop is not None:
        return

    # _ready is module-level and is also set on the failure path. Clear it so
    # that a retry after a failed attempt really waits for the new thread
    # instead of returning immediately and reporting a spurious failure.
    _ready.clear()

    # NOTE(review): _event_loop keeps referencing the loop after it has been
    # stopped and closed, so the early return above means no replacement loop
    # is started after stop_event_loop().

    # Create a new thread that will host our event loop
    def event_loop_thread():
        """Target function for the event loop thread."""
        global _event_loop
        try:
            # Create a new event loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            # Publish the loop before signalling so the waiter observes it.
            _event_loop = loop
            _ready.set()

            logger.debug(
                "Python worker event loop thread started: %s",
                threading.current_thread().name,
            )
            try:
                # Run the event loop forever
                loop.run_forever()
            finally:
                # Clean up when the loop stops
                logger.debug("Python worker event loop stopped, closing...")
                loop.close()
        except Exception as e:
            # logger.exception records the traceback along with the message.
            logger.exception("Error in event loop thread: %s", e)
            # Unblock the waiter even when loop creation itself failed.
            _ready.set()
            raise

    # Create and start the thread
    threading.Thread(
        target=event_loop_thread,
        name="asyncio-event-loop",
        daemon=True,  # Make it a daemon thread so it doesn't block program exit
    ).start()

    _ready.wait()  # Wait for the event loop to be ready

    if _event_loop is None:
        raise RuntimeError("Failed to initialize event loop")
73
+
74
+
75
def get_event_loop() -> asyncio.AbstractEventLoop:
    """
    Return the Python worker event loop, starting it first if it has not
    been created yet.

    Expected to be called from rust code.
    """
    global _event_loop
    needs_start = _event_loop is None
    if needs_start:
        # Initialization re-checks the global itself, so racing callers that
        # queue on the lock do not each start a thread.
        with _lock:
            _initialize_event_loop()
    return none_throws(_event_loop)
87
+
88
+
89
def stop_event_loop() -> None:
    """
    Stop the event loop gracefully.

    Schedules ``loop.stop()`` on the loop's own thread. Calling this when no
    loop was ever started, or after the loop has already been closed, is a
    no-op rather than an error.
    """
    # Read the global once so the check and the use see the same object.
    event_loop = _event_loop
    if event_loop is None or event_loop.is_closed():
        return

    logger.debug("Stopping event loop...")
    try:
        event_loop.call_soon_threadsafe(event_loop.stop)
    except RuntimeError:
        # call_soon_threadsafe raises RuntimeError on a closed loop; the loop
        # thread may have closed it between the check above and this call.
        logger.debug("Event loop was already closed")
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import traceback
9
+ from functools import partial
10
+ from typing import Generator, Generic, Optional, TypeVar
11
+
12
+ R = TypeVar("R")
13
+
14
+
15
async def _aincomplete(impl, self):
    """
    Run ``impl`` and record its outcome on the Future ``self``.

    On success the value is stored via ``self._set_result`` and returned. On
    an ``Exception`` the error is stored via ``self._set_exception`` and
    re-raised. Either way, later waits on the same Future replay the stored
    outcome instead of calling ``impl`` again.
    """
    try:
        return self._set_result(await impl())
    except Exception as e:
        self._set_exception(e)
        raise


# Future is our generic mechanism for providing both a synchronous and asynchronous API for
# Monarch Future objects.

# We treat all code as running in one of two contexts: synchronous (asyncio._get_running_loop() is None)
# or asynchronous.

# Inside of asynchronous code, clients of our API must use `await` to wait for monarch Futures to prevent
# blocking the surrounding event loop.

# In synchronous code users must call get() because the call is coming from a non-async function so
# await is not allowed.

# [avoiding async code duplication]
# Because we allow for two modes, it is tempting as developers of Monarch to start to write two copies
# of code for each mode. However, this results in a lot of confusing code duplication.
# To avoid this, we utilize the fact that synchronous code is allowed to start/complete an asyncio event loop
# via asyncio.run in order to complete the `get()` operation. So we can just write the async version and use
# it to implement the synchronous version.

# However, starting and running an event loop is somewhat expensive. For simple messages, using an event loop
# is about 4x slower than just directly waiting on the tokio result. To avoid this slow down we perform an
# optimization. For any case where the `impl` coroutine of a future calls `await` only on PythonFuture
# (a Tokio future returning a Python value) objects, we pass requires_loop=False to the Future. In this mode,
# the future will just run the coroutine manually, and the PythonFuture object will recognize it is being awaited
# without an event loop (search [avoiding async code duplication]) and simply do a blocking wait. By avoiding the event
# loop machinery, this gives it the same throughput as if we ran it synchronously.


class Future(Generic[R]):
    """
    A value that can be waited on with ``await`` (async code) or ``get()``
    (synchronous code).

    Args:
        impl: zero-argument callable returning the coroutine that produces
            the value.
        requires_loop: when False, ``get()`` without a timeout drives the
            coroutine by hand instead of starting an asyncio event loop.
    """

    def __init__(self, *, impl, requires_loop=True):
        # _aget(self) returns a fresh coroutine for the result. It is swapped
        # for a trivial replay coroutine once the outcome is known.
        self._aget = partial(_aincomplete, impl)
        self._requires_loop = requires_loop

    def get(self, timeout: Optional[float] = None) -> R:
        """
        Block until the result is available and return it.

        Args:
            timeout: optional number of seconds to wait; when given, the wait
                always runs on an asyncio event loop via ``asyncio.wait_for``.

        Raises:
            RuntimeError: if called while an asyncio loop is running in this
                thread, or if ``requires_loop=False`` was given but the
                coroutine paused on something that needs an event loop.
            Exception: whatever ``impl`` raised.
        """
        if asyncio._get_running_loop() is not None:
            raise RuntimeError("get() cannot be called from within an async context")
        if timeout is not None:
            return asyncio.run(asyncio.wait_for(self._aget(self), timeout))
        if not self._requires_loop:
            coro = self._aget(self)
            try:
                # Drive the coroutine one step by hand. When it only awaits
                # blocking primitives it runs to completion right here and
                # delivers its value through StopIteration.
                next(coro.__await__())
            except StopIteration as e:
                return e.value
            # The coroutine yielded, i.e. it is waiting on an event loop that
            # does not exist. Capture where it paused, then close it so it is
            # not left suspended (cr_frame is unavailable after close()).
            tb_str = "".join(traceback.format_stack(coro.cr_frame))
            coro.close()
            raise RuntimeError(
                f"a coroutine paused with a future with requires_loop=False cannot block on a python asyncio.Future. Use requires_loop=True.\n{tb_str}"
            )
        return asyncio.run(self._aget(self))

    def __await__(self) -> Generator[R, None, R]:
        return self._aget(self).__await__()

    def _set_result(self, result):
        """Store ``result`` so later waits return it immediately; returns ``result``."""

        async def af(self):
            return result

        self._aget = af
        return result

    def _set_exception(self, e):
        """Store ``e`` so later waits re-raise it immediately."""

        async def af(self):
            raise e

        self._aget = af

    # compatibility with old tensor engine Future objects
    # hopefully we do not need done(), add_callback because
    # they are harder to implement right.
    def result(self, timeout: Optional[float] = None) -> R:
        """Alias for ``get(timeout)``."""
        return self.get(timeout)

    def exception(self, timeout: Optional[float] = None):
        """
        Wait like ``get(timeout)`` and return the ``Exception`` it raised, or
        ``None`` on success.

        Any ``Exception`` raised by ``get`` is returned, not only the ones
        raised by ``impl``.
        """
        try:
            self.get(timeout)
            return None
        except Exception as e:
            return e
@@ -4,6 +4,7 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-unsafe
7
8
  import bdb
8
9
  import inspect
9
10
  import io
@@ -15,9 +16,10 @@ from dataclasses import dataclass
15
16
  from typing import Dict, TYPE_CHECKING
16
17
 
17
18
  from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
19
+ from monarch._src.actor.sync_state import fake_sync_state
18
20
 
19
21
  if TYPE_CHECKING:
20
- from monarch.debugger import DebugClient
22
+ from monarch._src.actor.debugger import DebugClient
21
23
 
22
24
 
23
25
  @dataclass
@@ -45,35 +47,38 @@ class PdbWrapper(pdb.Pdb):
45
47
  super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self))
46
48
  self._first = True
47
49
 
48
- def setup(self, *args, **kwargs):
49
- r = super().setup(*args, **kwargs)
50
- if self._first:
51
- self._first = False
52
- # when we enter the debugger, we want to present the user's stack frame
53
- # not the nested one inside session.run. This means that the local
54
- # variables are what gets printed, etc. To do this
55
- # we first execute up 2 to get to that frame.
56
- self.do_up(2)
57
- return r
58
-
59
- def set_continue(self) -> None:
60
- r = super().set_continue()
61
- if not self.breaks:
62
- # no more breakpoints so this debugger will not
63
- # be used again, and we detach from the controller io.
64
- self.client_ref.debugger_session_end.call_one(self.rank).get()
65
- # break cycle with itself before we exit
66
- self.stdin = sys.stdin
67
- self.stdout = sys.stdout
68
- return r
69
-
70
- def set_trace(self):
71
- self.client_ref.debugger_session_start.call_one(
50
+ def set_trace(self, frame):
51
+ self.client_ref.debugger_session_start.broadcast(
72
52
  self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id
73
- ).get()
53
+ )
74
54
  if self.header:
75
55
  self.message(self.header)
76
- super().set_trace()
56
+ super().set_trace(frame)
57
+
58
def do_clear(self, arg):
    """
    Handle the ``clear`` command.

    With no argument all breakpoints are cleared without prompting;
    otherwise the call is delegated to the parent implementation.
    """
    if not arg:
        # Sending `clear` without any argument specified will
        # request confirmation from the user using the `input` function,
        # which bypasses our ReadWrapper and causes a hang on the client.
        # To avoid this, we just clear all breakpoints instead without
        # confirmation.
        super().clear_all_breaks()
    else:
        super().do_clear(arg)
68
+
69
def end_debug_session(self):
    """
    Tell the debug client this rank's session is over, then point this
    debugger's stdin/stdout back at the process's own streams.
    """
    # broadcast is fire-and-forget: no reply is awaited here.
    self.client_ref.debugger_session_end.broadcast(self.rank)
    # Once the debug client actor is notified of the session being over,
    # we need to prevent any additional requests being sent for the session
    # by redirecting stdin and stdout.
    self.stdin = sys.stdin
    self.stdout = sys.stdout
76
+
77
def post_mortem(self, exc_tb):
    """Start an interactive post-mortem session on the traceback ``exc_tb``."""
    self._first = False
    # See builtin implementation of pdb.post_mortem() for reference.
    self.reset()
    self.interaction(None, exc_tb)
77
82
 
78
83
 
79
84
  class ReadWrapper(io.RawIOBase):
@@ -81,16 +86,19 @@ class ReadWrapper(io.RawIOBase):
81
86
  self.session = session
82
87
 
83
88
  def readinto(self, b):
84
- response = self.session.client_ref.debugger_read.call_one(
85
- self.session.rank, len(b)
86
- ).get()
87
- if response == "detach":
88
- # this gets injected by the worker event loop to
89
- # get the worker thread to exit on an Exit command.
90
- raise bdb.BdbQuit
91
- assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(b)
92
- b[: len(response.payload)] = response.payload
93
- return len(response.payload)
89
+ with fake_sync_state():
90
+ response = self.session.client_ref.debugger_read.call_one(
91
+ self.session.rank, len(b)
92
+ ).get()
93
+ if response == "detach":
94
+ # this gets injected by the worker event loop to
95
+ # get the worker thread to exit on an Exit command.
96
+ raise bdb.BdbQuit
97
+ assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(
98
+ b
99
+ )
100
+ b[: len(response.payload)] = response.payload
101
+ return len(response.payload)
94
102
 
95
103
  def readable(self) -> bool:
96
104
  return True
@@ -115,21 +123,14 @@ class WriteWrapper:
115
123
  function = f"{inspect.getmodulename(self.session.curframe.f_code.co_filename)}.{self.session.curframe.f_code.co_name}"
116
124
  # pyre-ignore
117
125
  lineno = self.session.curframe.f_lineno
118
- self.session.client_ref.debugger_write.call_one(
126
+ self.session.client_ref.debugger_write.broadcast(
119
127
  self.session.rank,
120
128
  DebuggerWrite(
121
129
  s.encode(),
122
130
  function,
123
131
  lineno,
124
132
  ),
125
- ).get()
133
+ )
126
134
 
127
135
  def flush(self):
128
136
  pass
129
-
130
-
131
- def remote_breakpointhook(
132
- rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient"
133
- ):
134
- ds = PdbWrapper(rank, coords, actor_id, client_ref)
135
- ds.set_trace()
@@ -6,10 +6,16 @@
6
6
 
7
7
  import io
8
8
  import pickle
9
+ from contextlib import contextmanager, ExitStack
9
10
  from typing import Any, Callable, Iterable, List, Tuple
10
11
 
11
12
  import cloudpickle
12
13
 
14
+ try:
15
+ import torch # @manual
16
+ except ImportError:
17
+ torch = None
18
+
13
19
 
14
20
  class _Pickler(cloudpickle.Pickler):
15
21
  def __init__(self, filter):
@@ -44,5 +50,23 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]:
44
50
 
45
51
 
46
52
  def unflatten(data: bytes, values: Iterable[Any]) -> Any:
47
- up = _Unpickler(data, values)
48
- return up.load()
53
+ with ExitStack() as stack:
54
+ if torch is not None:
55
+ stack.enter_context(load_tensors_on_cpu())
56
+ stack.enter_context(torch.utils._python_dispatch._disable_current_modes())
57
+ up = _Unpickler(data, values)
58
+ return up.load()
59
+
60
+
61
@contextmanager
def load_tensors_on_cpu():
    """
    Context manager that temporarily replaces ``torch.storage._load_from_bytes``
    with a version that calls ``torch.load(..., map_location="cpu")``.

    The original function is restored on exit, including when the body
    raises. The patch is applied to the ``torch.storage`` module attribute,
    so it is visible to all code in the process while the context is active.
    """
    # Ensure that any tensors are loaded onto the CPU by monkeypatching how
    # Storages are loaded.
    old = torch.storage._load_from_bytes
    try:
        # NOTE(review): weights_only=False permits arbitrary pickled objects;
        # only appropriate when the bytes come from a trusted peer.
        torch.storage._load_from_bytes = lambda b: torch.load(
            io.BytesIO(b), map_location="cpu", weights_only=False
        )
        yield
    finally:
        torch.storage._load_from_bytes = old