torchmonarch-nightly 2025.7.1__cp311-cp311-manylinux2014_x86_64.whl → 2025.7.25__cp311-cp311-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. monarch/__init__.py +13 -9
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
  4. monarch/_src/actor/actor_mesh.py +874 -0
  5. monarch/{allocator.py → _src/actor/allocator.py} +26 -17
  6. monarch/_src/actor/bootstrap_main.py +73 -0
  7. monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
  8. monarch/_src/actor/code_sync/auto_reload.py +223 -0
  9. monarch/_src/actor/debugger.py +565 -0
  10. monarch/_src/actor/endpoint.py +270 -0
  11. monarch/_src/actor/event_loop.py +97 -0
  12. monarch/_src/actor/future.py +100 -0
  13. monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
  14. monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
  15. monarch/_src/actor/proc_mesh.py +500 -0
  16. monarch/_src/actor/sync_state.py +18 -0
  17. monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
  18. monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
  19. monarch/_src/actor/tensor_engine_shim.py +56 -0
  20. monarch/_src/tensor_engine/rdma.py +180 -0
  21. monarch/_testing.py +3 -2
  22. monarch/actor/__init__.py +51 -0
  23. monarch/actor_mesh.py +6 -765
  24. monarch/bootstrap_main.py +8 -47
  25. monarch/common/client.py +1 -1
  26. monarch/common/controller_api.py +2 -1
  27. monarch/common/device_mesh.py +12 -2
  28. monarch/common/messages.py +12 -1
  29. monarch/common/recording.py +4 -3
  30. monarch/common/remote.py +135 -52
  31. monarch/common/tensor.py +2 -1
  32. monarch/controller/backend.py +2 -2
  33. monarch/controller/controller.py +2 -1
  34. monarch/controller/rust_backend/controller.py +2 -1
  35. monarch/fetch.py +3 -5
  36. monarch/mesh_controller.py +201 -139
  37. monarch/monarch_controller +0 -0
  38. monarch/opaque_module.py +4 -6
  39. monarch/opaque_object.py +3 -3
  40. monarch/proc_mesh.py +6 -309
  41. monarch/python_local_mesh.py +1 -1
  42. monarch/rust_backend_mesh.py +2 -1
  43. monarch/rust_local_mesh.py +4 -2
  44. monarch/sim_mesh.py +10 -19
  45. monarch/simulator/command_history.py +1 -1
  46. monarch/simulator/interface.py +2 -1
  47. monarch/simulator/mock_controller.py +1 -1
  48. monarch/simulator/simulator.py +1 -1
  49. monarch/tensor_engine/__init__.py +23 -0
  50. monarch/tensor_worker_main.py +3 -1
  51. monarch/tools/cli.py +3 -1
  52. monarch/tools/commands.py +95 -35
  53. monarch/tools/mesh_spec.py +55 -0
  54. monarch/tools/utils.py +38 -0
  55. monarch/worker/worker.py +1 -1
  56. monarch/world_mesh.py +2 -1
  57. monarch_supervisor/python_executable.py +6 -3
  58. tests/error_test_binary.py +48 -10
  59. tests/test_actor_error.py +370 -21
  60. tests/test_alloc.py +1 -1
  61. tests/test_allocator.py +373 -17
  62. tests/test_controller.py +2 -0
  63. tests/test_debugger.py +416 -0
  64. tests/test_env_before_cuda.py +162 -0
  65. tests/test_python_actors.py +184 -333
  66. tests/test_rdma.py +198 -0
  67. tests/test_remote_functions.py +40 -12
  68. tests/test_rust_backend.py +7 -5
  69. tests/test_sim_backend.py +1 -4
  70. tests/test_tensor_engine.py +55 -1
  71. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
  72. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
  73. torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
  74. monarch/_monarch/hyperactor/__init__.py +0 -58
  75. monarch/_monarch/worker/debugger.py +0 -117
  76. monarch/_monarch/worker/logging.py +0 -107
  77. monarch/debugger.py +0 -379
  78. monarch/future.py +0 -76
  79. monarch/rdma.py +0 -162
  80. torchmonarch_nightly-2025.7.1.dist-info/entry_points.txt +0 -3
  81. /monarch/{_monarch/worker → _src}/__init__.py +0 -0
  82. /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
  83. /monarch/{common → _src/actor}/shape.py +0 -0
  84. /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
  85. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
  86. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
  87. {torchmonarch_nightly-2025.7.1.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,270 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+
9
+ import functools
10
+ from abc import ABC, abstractmethod
11
+ from operator import mul
12
+ from typing import (
13
+ Any,
14
+ AsyncGenerator,
15
+ Awaitable,
16
+ Callable,
17
+ cast,
18
+ Concatenate,
19
+ Dict,
20
+ Generic,
21
+ List,
22
+ Literal,
23
+ Optional,
24
+ overload,
25
+ ParamSpec,
26
+ Sequence,
27
+ Tuple,
28
+ TYPE_CHECKING,
29
+ TypeVar,
30
+ )
31
+
32
+ from monarch._src.actor.future import Future
33
+ from monarch._src.actor.tensor_engine_shim import _cached_propagation, fake_call
34
+
35
+ if TYPE_CHECKING:
36
+ from monarch._src.actor.actor_mesh import (
37
+ HyPortReceiver,
38
+ OncePortReceiver,
39
+ Port,
40
+ PortTuple,
41
+ ValueMesh,
42
+ )
43
+
44
+ P = ParamSpec("P")
45
+ R = TypeVar("R")
46
+
47
+ Selection = Literal["all", "choose"] | int
48
+
49
+
50
class Extent:
    """A labelled multi-dimensional size: one label and one size per dimension."""

    def __init__(self, labels: Sequence[str], sizes: Sequence[int]) -> None:
        # Both sequences are stored as given (no copy, no length validation).
        self.sizes = sizes
        self.labels = labels

    @property
    def nelements(self) -> int:
        """Return the product of all sizes (1 when there are no dimensions)."""
        total = 1
        for size in self.sizes:
            total = mul(total, size)
        return total

    def __str__(self) -> str:
        """Render as the string form of a ``{label: size}`` dict."""
        pairs = zip(self.labels, self.sizes)
        return str({label: size for label, size in pairs})
61
+
62
+
63
+ Propagator = Any
64
+
65
+
66
class Endpoint(ABC, Generic[P, R]):
    """
    Abstract base for something that can be messaged with arguments ``P`` and
    whose responses have type ``R``.

    Subclasses provide the transport (``_send``, ``_port``, ``_call_name``).
    This class layers the user-facing "adverbs" (``choose``, ``call_one``,
    ``call``, ``stream``, ``broadcast``) and the propagation helpers on top.
    """

    def __init__(self, propagator: Propagator) -> None:
        self._propagator_arg = propagator
        # Created lazily by _propagate() the first time cached propagation runs.
        self._cache: Optional[dict] = None

    @abstractmethod
    def _send(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        port: "Optional[Port]" = None,
        selection: Selection = "all",
    ) -> Extent:
        """
        Implements sending a message to the endpoint. The return value of the endpoint will
        be sent to port if provided. If port is not provided, the return will be dropped,
        and any exception will cause the actor to fail.

        The return value is the (multi-dimension) size of the actors that were sent a message.
        For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
        this will be the size of the currently active proc_mesh.
        """

    @abstractmethod
    def _port(self, once: bool = False) -> "PortTuple[R]":
        """
        Return the ``PortTuple`` for this endpoint.

        NOTE(review): presumably a (port, receiver) pair used to collect
        responses, with ``once`` selecting a single-use port -- confirm against
        the concrete implementations.
        """

    @abstractmethod
    def _call_name(self) -> Any:
        """
        Something to use in InputChecker to represent calling this thingy.
        """

    def _supervise(self, r: "HyPortReceiver | OncePortReceiver") -> Any:
        """Return the receiver unchanged; subclasses may override this."""
        return r

    # the following are all 'adverbs' or different ways to handle the
    # return values of this endpoint. Adverbs should only ever take *args, **kwargs
    # of the original call. If we want to add syntax sugar for something that needs additional
    # arguments, it should be implemented as function independent of endpoint like `send`
    # and `Accumulator`
    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Load balanced sends a message to one chosen actor and awaits a result.

        Load balanced RPC-style entrypoint for request/response messaging.
        """
        from monarch._src.actor.actor_mesh import port

        p, r = port(self, once=True)
        # pyre-ignore
        self._send(args, kwargs, port=p, selection="choose")
        return r.recv()

    def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """
        Send a message to an endpoint backed by exactly one actor and return a
        future for its response.

        Raises:
            ValueError: if the extent reported by ``_send`` does not contain
                exactly one element. The extent is only known after sending,
                so the message has already been sent when this is raised.
        """
        from monarch._src.actor.actor_mesh import port

        p, r = port(self, once=True)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p, selection="choose")
        if extent.nelements != 1:
            raise ValueError(
                f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
            )
        return r.recv()

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
        """
        Send a message with the default selection and gather every response
        into a ``ValueMesh`` shaped like the extent returned by ``_send``.
        """
        from monarch._src.actor.actor_mesh import ranked_port, ValueMesh

        p, r = ranked_port(self)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p)

        async def process() -> "ValueMesh[R]":
            from monarch._rust_bindings.monarch_hyperactor.shape import Shape
            from monarch._src.actor.shape import NDSlice

            # Responses arrive as (rank, value) pairs in arbitrary order, so the
            # list is pre-sized and filled by rank rather than appended to.
            results: List[R] = [None] * extent.nelements  # pyre-fixme[9]
            for _ in range(extent.nelements):
                rank, value = await r.recv()
                results[rank] = value
            call_shape = Shape(
                extent.labels,
                NDSlice.new_row_major(extent.sizes),
            )
            return ValueMesh(call_shape, results)

        return Future(impl=process, requires_loop=False)

    async def stream(
        self, *args: P.args, **kwargs: P.kwargs
    ) -> AsyncGenerator[R, None]:
        """
        Broadcasts to all actors and yields their responses as a stream / generator.

        This enables processing results from multiple actors incrementally as
        they become available. Returns an async generator of response values.
        """
        from monarch._src.actor.actor_mesh import port

        p, r = port(self)
        # pyre-ignore
        extent = self._send(args, kwargs, port=p)
        # One response is awaited per element of the extent.
        for _ in range(extent.nelements):
            yield await r.recv()

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """
        Fire-and-forget broadcast to all actors without waiting for actors to
        acknowledge receipt.

        In other words, the return of this method does not guarantee the
        delivery of the message.
        """
        from monarch._src.actor.actor_mesh import send

        # pyre-ignore
        send(self, args, kwargs)

    def _propagate(self, args, kwargs, fake_args, fake_kwargs):
        """
        Dispatch on the propagator given at construction time.

        * ``None`` or ``"cached"``: ``_cached_propagation`` with this
          endpoint's cache and ``self._resolvable`` (not defined on this base
          class; expected to come from the subclass).
        * ``"inspect"``: returns ``None``.
        * ``"mocked"``: raises ``NotImplementedError``.
        * anything else: passed to ``fake_call`` together with the fake
          arguments.
        """
        if self._propagator_arg is None or self._propagator_arg == "cached":
            if self._cache is None:
                self._cache = {}
            return _cached_propagation(self._cache, self._resolvable, args, kwargs)
        elif self._propagator_arg == "inspect":
            return None
        elif self._propagator_arg == "mocked":
            raise NotImplementedError("mocked propagation")
        else:
            return fake_call(self._propagator_arg, *fake_args, **fake_kwargs)

    def _fetch_propagate(self, args, kwargs, fake_args, fake_kwargs):
        """Like ``_propagate``, but returns ``None`` when no propagator was given."""
        if self._propagator_arg is None:
            return  # no propagator provided, so we just assume no mutations
        return self._propagate(args, kwargs, fake_args, fake_kwargs)

    def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
        """
        Like ``_propagate``, but requires the propagator to be callable.

        Raises:
            ValueError: if the propagator is not callable.
        """
        if not callable(self._propagator_arg):
            raise ValueError("Must specify explicit callable for pipe")
        return self._propagate(args, kwargs, fake_args, fake_kwargs)
206
+
207
+
208
class EndpointProperty(Generic[P, R]):
    """
    Descriptor that records a method and its propagator so the method can be
    recognized as an endpoint.
    """

    @overload
    def __init__(
        self,
        method: Callable[Concatenate[Any, P], Awaitable[R]],
        propagator: Propagator,
    ) -> None: ...

    @overload
    def __init__(
        self,
        method: Callable[Concatenate[Any, P], R],
        propagator: Propagator,
    ) -> None: ...

    def __init__(self, method: Any, propagator: Propagator) -> None:
        # Stored as given; nothing is wrapped or validated here.
        self._propagator = propagator
        self._method = method

    def __get__(self, instance, owner) -> Endpoint[P, R]:
        # The declared return type is deliberately untrue: attribute access
        # hands back this descriptor itself, cast to Endpoint for the benefit
        # of type checkers. What matters at runtime is that the attribute can
        # be recognized as an endpoint and its method looked up.
        as_endpoint = cast(Endpoint[P, R], self)
        return as_endpoint
230
+
231
+
232
+ # This can't just be Callable because otherwise we are not
233
+ # allowed to use type arguments in the return value.
234
class EndpointIfy:
    """
    Typing-only callable: its ``__call__`` overloads map a function (sync or
    async) to ``Endpoint[P, R]``. The runtime body does nothing.
    """

    @overload
    def __call__(self, function: Callable[P, Awaitable[R]]) -> Endpoint[P, R]: ...

    @overload
    def __call__(self, function: Callable[P, R]) -> Endpoint[P, R]: ...

    def __call__(self, function: Any):
        # Placeholder implementation; only the overload signatures carry meaning.
        return None
242
+
243
+
244
@overload
def endpoint(
    method: Callable[Concatenate[Any, P], Awaitable[R]],
    *,
    propagate: Propagator = None,
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    method: Callable[Concatenate[Any, P], R],
    *,
    propagate: Propagator = None,
) -> EndpointProperty[P, R]: ...


@overload
def endpoint(
    *,
    propagate: Propagator = None,
) -> EndpointIfy: ...


def endpoint(method=None, *, propagate=None):
    """
    Mark a method as an endpoint.

    Works as a bare decorator (``@endpoint``) and as a decorator factory
    (``@endpoint(propagate=...)``).

    Args:
        method: the method to wrap, or ``None`` when used as a factory.
        propagate: propagator forwarded to ``EndpointProperty``.

    Returns:
        An ``EndpointProperty`` wrapping ``method`` when it is given; otherwise
        a ``functools.partial`` of ``endpoint`` with ``propagate`` pre-bound,
        ready to be applied to the method.
    """
    if method is not None:
        return EndpointProperty(method, propagator=propagate)
    return functools.partial(endpoint, propagate=propagate)
@@ -0,0 +1,97 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Module for managing the event loop used by Monarch Python actors.
9
+ This provides a way to create a Python-aware thread from Rust that runs the worker event loop.
10
+ """
11
+
12
+ import asyncio
13
+ import logging
14
+ import threading
15
+ from typing import Optional
16
+
17
+ from libfb.py.pyre import none_throws
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ _event_loop: Optional[asyncio.AbstractEventLoop] = None
22
+ _lock = threading.Lock()
23
+ _ready = threading.Event()
24
+
25
+
26
def _initialize_event_loop() -> None:
    """
    Internal function to initialize the event loop.
    This creates a new thread with an event loop that runs forever.

    Returns immediately if ``_event_loop`` is already set. Otherwise it starts
    a daemon thread, blocks until that thread signals ``_ready``, and then
    checks that the loop was published.

    Raises:
        RuntimeError: if the thread signalled readiness without publishing a
            loop (i.e. creating the loop failed).

    NOTE(review): nothing in this function resets ``_event_loop`` after the
    loop stops and is closed, so the module keeps referencing the closed loop.
    Confirm whether restarting after a stop is meant to be supported.
    """
    if _event_loop is not None:
        return

    # Create a new thread that will host our event loop
    def event_loop_thread():
        """Target function for the event loop thread."""
        global _event_loop
        try:
            # Create a new event loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            # Publish the loop before signalling, so the waiter always sees it.
            _event_loop = loop
            _ready.set()

            logger.debug(
                "Python worker event loop thread started: %s",
                threading.current_thread().name,
            )
            try:
                # Run the event loop forever
                loop.run_forever()
            finally:
                # Clean up when the loop stops
                logger.debug("Python worker event loop stopped, closing...")
                loop.close()
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception("Error in event loop thread: %s", e)
            # Unblock the waiter even on failure; it then sees that no loop was
            # published and raises.
            _ready.set()
            raise

    # Create and start the thread
    threading.Thread(
        target=event_loop_thread,
        name="asyncio-event-loop",
        daemon=True,  # Make it a daemon thread so it doesn't block program exit
    ).start()

    _ready.wait()  # Wait for the event loop to be ready

    if _event_loop is None:
        raise RuntimeError("Failed to initialize event loop")
73
+
74
+
75
def get_event_loop() -> asyncio.AbstractEventLoop:
    """
    Return the Python worker event loop, starting it first if none has been
    created yet.

    Initialization runs under ``_lock``; ``_initialize_event_loop`` re-checks
    the global itself, so concurrent first callers start only one loop.

    Expected to be called from rust code.
    """
    if _event_loop is None:
        with _lock:
            _initialize_event_loop()
    # Read the global again: the initializer publishes the loop there.
    return none_throws(_event_loop)
87
+
88
+
89
def stop_event_loop() -> None:
    """
    Ask the worker event loop to stop; does nothing if no loop was created.

    The stop is scheduled with ``call_soon_threadsafe`` rather than invoked
    directly.
    """
    event_loop = _event_loop
    if event_loop is None:
        return
    logger.debug("Stopping event loop...")
    event_loop.call_soon_threadsafe(event_loop.stop)
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import traceback
9
+ from functools import partial
10
+ from typing import Generator, Generic, Optional, TypeVar
11
+
12
+ R = TypeVar("R")
13
+
14
+
15
async def _aincomplete(impl, self):
    """
    Await ``impl()`` and record the outcome on ``self``.

    On success the value is passed to ``self._set_result`` and that call's
    return value is returned. On ``Exception`` the error is passed to
    ``self._set_exception`` and then re-raised.
    """
    try:
        outcome = await impl()
        return self._set_result(outcome)
    except Exception as err:
        self._set_exception(err)
        raise
21
+
22
+
23
+ # Future is our generic mechanism for providing both a synchronous and asynchronous API for
24
+ # Monarch Future objects.
25
+
26
+ # We treat all code as running in one of two contexts: synchronous (asyncio._get_running_loop() is None)
27
+ # or asynchronous.
28
+
29
+ # Inside of asynchronous code, clients of our API must use `await` to wait for monarch Futures to prevent
30
+ # blocking the surrounding event loop.
31
+
32
+ # In synchronous code users must call get() because the call is coming from a non-async function so
33
+ # await is not allowed.
34
+
35
+ # [avoiding async code duplication]
36
+ # Because we allow for two modes, it is tempting as developers of Monarch to start to write two copies of
37
+ # code for each mode. However, this results in a lot of confusing code duplication.
38
+ # To avoid this, we utilize the fact that synchronous code is allowed to start/complete an asyncio event loop
39
+ # via asyncio.run in order to complete the `get()` operation. So we can just write the async version and use
40
+ # it to implement the synchronous version.
41
+
42
+ # However, starting and running an event loop is somewhat expensive. For simple messages, using an event loop
43
+ # is about 4x slower than just directly waiting on the tokio result. To avoid this slow down we perform an
44
+ # optimization. For any case where the `impl` coroutine of a future calls `await` only on PythonFuture
45
+ # (a Tokio future returning a Python value) objects, we pass requires_loop=False to the Future. In this mode,
46
+ # the future will just run the coroutine manually, and the PythonFuture object will recognize it is being awaited
47
+ # without an event loop (search [avoiding code duplication]) and simply do a blocking wait. By avoiding the event
48
+ # loop machinery, this gives it the same throughput as if we ran it synchronously.
49
+
50
+
51
class Future(Generic[R]):
    """
    A result that can be consumed synchronously with ``get()`` or
    asynchronously with ``await``.

    ``impl`` is a zero-argument callable returning an awaitable. Once it
    completes or raises an ``Exception``, the outcome replaces ``_aget``, so
    later consumers get the stored value or error instead of re-running
    ``impl``.
    """

    def __init__(self, *, impl, requires_loop=True):
        self._requires_loop = requires_loop
        # Called as self._aget(self); the replacements installed by
        # _set_result/_set_exception keep that same calling convention.
        self._aget = partial(_aincomplete, impl)

    def get(self, timeout: Optional[float] = None) -> R:
        """
        Block until the result is available and return it.

        Raises:
            RuntimeError: if called while an event loop is running, or if
                ``requires_loop=False`` was given but the coroutine suspended
                instead of completing in a single step.
        """
        if asyncio._get_running_loop() is not None:
            raise RuntimeError("get() cannot be called from within an async context")
        if timeout is not None:
            return asyncio.run(asyncio.wait_for(self._aget(self), timeout))
        if self._requires_loop:
            return asyncio.run(self._aget(self))

        # Fast path: drive the coroutine by hand, without an event loop. It
        # must run to completion on the first step, which shows up here as
        # StopIteration carrying the return value.
        coro = self._aget(self)
        try:
            next(coro.__await__())
        except StopIteration as finished:
            return finished.value
        tb_str = "".join(traceback.format_stack(coro.cr_frame))
        raise RuntimeError(
            f"a coroutine paused with a future with requires_loop=False cannot block on a python asyncio.Future. Use requires_loop=True.\n{tb_str}"
        )

    def __await__(self) -> Generator[R, None, R]:
        pending = self._aget(self)
        return pending.__await__()

    def _set_result(self, result):
        """Store ``result`` so later consumers get it directly; return it."""

        async def resolved(self):
            return result

        self._aget = resolved
        return result

    def _set_exception(self, e):
        """Store ``e`` so later consumers re-raise it."""

        async def failed(self):
            raise e

        self._aget = failed

    # compatibility with old tensor engine Future objects
    # hopefully we do not need done(), add_callback because
    # they are harder to implement right.
    def result(self, timeout: Optional[float] = None) -> R:
        """Alias for ``get``."""
        return self.get(timeout)

    def exception(self, timeout: Optional[float] = None):
        """Return the ``Exception`` raised by ``get``, or ``None`` on success."""
        try:
            self.get(timeout)
        except Exception as err:
            return err
        return None
@@ -4,6 +4,7 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-unsafe
7
8
  import bdb
8
9
  import inspect
9
10
  import io
@@ -15,9 +16,10 @@ from dataclasses import dataclass
15
16
  from typing import Dict, TYPE_CHECKING
16
17
 
17
18
  from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
19
+ from monarch._src.actor.sync_state import fake_sync_state
18
20
 
19
21
  if TYPE_CHECKING:
20
- from monarch.debugger import DebugClient
22
+ from monarch._src.actor.debugger import DebugClient
21
23
 
22
24
 
23
25
  @dataclass
@@ -45,35 +47,38 @@ class PdbWrapper(pdb.Pdb):
45
47
  super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self))
46
48
  self._first = True
47
49
 
48
- def setup(self, *args, **kwargs):
49
- r = super().setup(*args, **kwargs)
50
- if self._first:
51
- self._first = False
52
- # when we enter the debugger, we want to present the user's stack frame
53
- # not the nested one inside session.run. This means that the local
54
- # variables are what gets printed, etc. To do this
55
- # we first execute up 2 to get to that frame.
56
- self.do_up(2)
57
- return r
58
-
59
- def set_continue(self) -> None:
60
- r = super().set_continue()
61
- if not self.breaks:
62
- # no more breakpoints so this debugger will not
63
- # be used again, and we detach from the controller io.
64
- self.client_ref.debugger_session_end.call_one(self.rank).get()
65
- # break cycle with itself before we exit
66
- self.stdin = sys.stdin
67
- self.stdout = sys.stdout
68
- return r
69
-
70
- def set_trace(self):
71
- self.client_ref.debugger_session_start.call_one(
50
+ def set_trace(self, frame):
51
+ self.client_ref.debugger_session_start.broadcast(
72
52
  self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id
73
- ).get()
53
+ )
74
54
  if self.header:
75
55
  self.message(self.header)
76
- super().set_trace()
56
+ super().set_trace(frame)
57
+
58
+ def do_clear(self, arg):
59
+ if not arg:
60
+ # Sending `clear` without any argument specified will
61
+ # request confirmation from the user using the `input` function,
62
+ # which bypasses our ReadWrapper and causes a hang on the client.
63
+ # To avoid this, we just clear all breakpoints instead without
64
+ # confirmation.
65
+ super().clear_all_breaks()
66
+ else:
67
+ super().do_clear(arg)
68
+
69
+ def end_debug_session(self):
70
+ self.client_ref.debugger_session_end.broadcast(self.rank)
71
+ # Once the debug client actor is notified of the session being over,
72
+ # we need to prevent any additional requests being sent for the session
73
+ # by redirecting stdin and stdout.
74
+ self.stdin = sys.stdin
75
+ self.stdout = sys.stdout
76
+
77
+ def post_mortem(self, exc_tb):
78
+ self._first = False
79
+ # See builtin implementation of pdb.post_mortem() for reference.
80
+ self.reset()
81
+ self.interaction(None, exc_tb)
77
82
 
78
83
 
79
84
  class ReadWrapper(io.RawIOBase):
@@ -81,16 +86,19 @@ class ReadWrapper(io.RawIOBase):
81
86
  self.session = session
82
87
 
83
88
  def readinto(self, b):
84
- response = self.session.client_ref.debugger_read.call_one(
85
- self.session.rank, len(b)
86
- ).get()
87
- if response == "detach":
88
- # this gets injected by the worker event loop to
89
- # get the worker thread to exit on an Exit command.
90
- raise bdb.BdbQuit
91
- assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(b)
92
- b[: len(response.payload)] = response.payload
93
- return len(response.payload)
89
+ with fake_sync_state():
90
+ response = self.session.client_ref.debugger_read.call_one(
91
+ self.session.rank, len(b)
92
+ ).get()
93
+ if response == "detach":
94
+ # this gets injected by the worker event loop to
95
+ # get the worker thread to exit on an Exit command.
96
+ raise bdb.BdbQuit
97
+ assert isinstance(response, DebuggerWrite) and len(response.payload) <= len(
98
+ b
99
+ )
100
+ b[: len(response.payload)] = response.payload
101
+ return len(response.payload)
94
102
 
95
103
  def readable(self) -> bool:
96
104
  return True
@@ -115,21 +123,14 @@ class WriteWrapper:
115
123
  function = f"{inspect.getmodulename(self.session.curframe.f_code.co_filename)}.{self.session.curframe.f_code.co_name}"
116
124
  # pyre-ignore
117
125
  lineno = self.session.curframe.f_lineno
118
- self.session.client_ref.debugger_write.call_one(
126
+ self.session.client_ref.debugger_write.broadcast(
119
127
  self.session.rank,
120
128
  DebuggerWrite(
121
129
  s.encode(),
122
130
  function,
123
131
  lineno,
124
132
  ),
125
- ).get()
133
+ )
126
134
 
127
135
  def flush(self):
128
136
  pass
129
-
130
-
131
- def remote_breakpointhook(
132
- rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient"
133
- ):
134
- ds = PdbWrapper(rank, coords, actor_id, client_ref)
135
- ds.set_trace()
@@ -6,10 +6,16 @@
6
6
 
7
7
  import io
8
8
  import pickle
9
+ from contextlib import contextmanager, ExitStack
9
10
  from typing import Any, Callable, Iterable, List, Tuple
10
11
 
11
12
  import cloudpickle
12
13
 
14
+ try:
15
+ import torch # @manual
16
+ except ImportError:
17
+ torch = None
18
+
13
19
 
14
20
  class _Pickler(cloudpickle.Pickler):
15
21
  def __init__(self, filter):
@@ -44,5 +50,23 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]:
44
50
 
45
51
 
46
52
  def unflatten(data: bytes, values: Iterable[Any]) -> Any:
47
- up = _Unpickler(data, values)
48
- return up.load()
53
+ with ExitStack() as stack:
54
+ if torch is not None:
55
+ stack.enter_context(load_tensors_on_cpu())
56
+ stack.enter_context(torch.utils._python_dispatch._disable_current_modes())
57
+ up = _Unpickler(data, values)
58
+ return up.load()
59
+
60
+
61
@contextmanager
def load_tensors_on_cpu():
    """
    Temporarily replace ``torch.storage._load_from_bytes`` with a loader that
    uses ``map_location="cpu"``, restoring the original on exit.

    NOTE: this patches a module attribute, so it is visible process-wide while
    the context is active. The replacement loads with ``weights_only=False``.
    """
    original_loader = torch.storage._load_from_bytes

    def _load_on_cpu(b):
        return torch.load(io.BytesIO(b), map_location="cpu", weights_only=False)

    try:
        torch.storage._load_from_bytes = _load_on_cpu
        yield
    finally:
        # Restore even if the body raised.
        torch.storage._load_from_bytes = original_loader