torchmonarch-nightly 2025.6.30__cp313-cp313-manylinux2014_x86_64.whl → 2025.7.25__cp313-cp313-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +13 -9
- monarch/_rust_bindings.so +0 -0
- monarch/{_monarch/selection → _src/actor}/__init__.py +3 -7
- monarch/_src/actor/actor_mesh.py +874 -0
- monarch/{allocator.py → _src/actor/allocator.py} +26 -17
- monarch/_src/actor/bootstrap_main.py +73 -0
- monarch/{code_sync.py → _src/actor/code_sync/__init__.py} +3 -1
- monarch/_src/actor/code_sync/auto_reload.py +223 -0
- monarch/_src/actor/debugger.py +565 -0
- monarch/_src/actor/endpoint.py +270 -0
- monarch/_src/actor/event_loop.py +97 -0
- monarch/_src/actor/future.py +100 -0
- monarch/{pdb_wrapper.py → _src/actor/pdb_wrapper.py} +47 -46
- monarch/{common/pickle_flatten.py → _src/actor/pickle.py} +26 -2
- monarch/_src/actor/proc_mesh.py +500 -0
- monarch/_src/actor/sync_state.py +18 -0
- monarch/{telemetry.py → _src/actor/telemetry/__init__.py} +1 -1
- monarch/_src/actor/telemetry/rust_span_tracing.py +159 -0
- monarch/_src/actor/tensor_engine_shim.py +56 -0
- monarch/_src/tensor_engine/rdma.py +180 -0
- monarch/_testing.py +3 -2
- monarch/actor/__init__.py +51 -0
- monarch/actor_mesh.py +6 -752
- monarch/bootstrap_main.py +8 -47
- monarch/common/client.py +1 -1
- monarch/common/controller_api.py +2 -1
- monarch/common/device_mesh.py +12 -2
- monarch/common/messages.py +12 -1
- monarch/common/recording.py +4 -3
- monarch/common/remote.py +135 -52
- monarch/common/tensor.py +2 -1
- monarch/controller/backend.py +2 -2
- monarch/controller/controller.py +2 -1
- monarch/controller/rust_backend/controller.py +2 -1
- monarch/fetch.py +3 -5
- monarch/mesh_controller.py +201 -139
- monarch/monarch_controller +0 -0
- monarch/opaque_module.py +4 -6
- monarch/opaque_object.py +3 -3
- monarch/proc_mesh.py +6 -309
- monarch/python_local_mesh.py +1 -1
- monarch/rust_backend_mesh.py +2 -1
- monarch/rust_local_mesh.py +4 -2
- monarch/sim_mesh.py +10 -19
- monarch/simulator/command_history.py +1 -1
- monarch/simulator/interface.py +2 -1
- monarch/simulator/mock_controller.py +1 -1
- monarch/simulator/simulator.py +1 -1
- monarch/tensor_engine/__init__.py +23 -0
- monarch/tensor_worker_main.py +3 -1
- monarch/tools/cli.py +3 -1
- monarch/tools/commands.py +95 -35
- monarch/tools/mesh_spec.py +55 -0
- monarch/tools/utils.py +38 -0
- monarch/worker/worker.py +1 -1
- monarch/world_mesh.py +2 -1
- monarch_supervisor/python_executable.py +6 -3
- tests/error_test_binary.py +75 -9
- tests/test_actor_error.py +370 -21
- tests/test_alloc.py +1 -1
- tests/test_allocator.py +373 -17
- tests/test_controller.py +2 -0
- tests/test_debugger.py +416 -0
- tests/test_env_before_cuda.py +162 -0
- tests/test_python_actors.py +184 -332
- tests/test_rdma.py +198 -0
- tests/test_remote_functions.py +40 -12
- tests/test_rust_backend.py +7 -5
- tests/test_sim_backend.py +1 -4
- tests/test_tensor_engine.py +55 -1
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/METADATA +6 -1
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/RECORD +80 -68
- torchmonarch_nightly-2025.7.25.dist-info/entry_points.txt +3 -0
- monarch/_monarch/hyperactor/__init__.py +0 -58
- monarch/_monarch/worker/debugger.py +0 -117
- monarch/_monarch/worker/logging.py +0 -107
- monarch/debugger.py +0 -379
- monarch/future.py +0 -76
- monarch/rdma.py +0 -162
- torchmonarch_nightly-2025.6.30.dist-info/entry_points.txt +0 -3
- /monarch/{_monarch/worker → _src}/__init__.py +0 -0
- /monarch/{common/_device_utils.py → _src/actor/device_utils.py} +0 -0
- /monarch/{common → _src/actor}/shape.py +0 -0
- /monarch/{_monarch → _src/tensor_engine}/__init__.py +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.6.30.dist-info → torchmonarch_nightly-2025.7.25.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,565 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import asyncio
|
9
|
+
import functools
|
10
|
+
import inspect
|
11
|
+
import logging
|
12
|
+
import os
|
13
|
+
import sys
|
14
|
+
from dataclasses import dataclass
|
15
|
+
from typing import cast, Dict, Generator, List, Tuple, Union
|
16
|
+
|
17
|
+
from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
|
18
|
+
from monarch._src.actor.actor_mesh import (
|
19
|
+
_ActorMeshRefImpl,
|
20
|
+
Actor,
|
21
|
+
ActorMeshRef,
|
22
|
+
DebugContext,
|
23
|
+
MonarchContext,
|
24
|
+
)
|
25
|
+
from monarch._src.actor.endpoint import endpoint
|
26
|
+
from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper
|
27
|
+
from monarch._src.actor.sync_state import fake_sync_state
|
28
|
+
|
29
|
+
|
30
|
+
logger = logging.getLogger(__name__)

# Well-known actor name of the DebugManager; DebugManager.ref() addresses it
# as "<proc_id>.debug_manager[0]" on the current proc.
_DEBUG_MANAGER_ACTOR_NAME = "debug_manager"
|
33
|
+
|
34
|
+
|
35
|
+
async def _debugger_input(prompt=""):
    """Read one line from stdin without blocking the event loop.

    ``input()`` blocks, so it is executed on a worker thread through
    ``asyncio.to_thread``. Any exception raised by ``input()`` (such as
    EOFError) propagates to the awaiting caller.
    """
    line = await asyncio.to_thread(input, prompt)
    return line
|
37
|
+
|
38
|
+
|
39
|
+
def _debugger_output(msg):
    """Write ``msg`` to the current stdout verbatim and flush immediately.

    No newline is appended; the flush makes debugger output appear right away
    even when stdout is buffered.
    """
    stream = sys.stdout
    stream.write(msg)
    stream.flush()
|
42
|
+
|
43
|
+
|
44
|
+
@dataclass
class DebugSessionInfo:
    """Snapshot of one debug session, as returned by ``DebugSession.get_info``.

    ``function`` and ``lineno`` stay None until the remote debugger has sent a
    write that carried both values.
    """

    # Rank of the process being debugged (also the key in DebugClient.sessions).
    rank: int
    # Coordinates of that rank; presumably mesh dimension name -> index
    # (built from ctx.point.shape.coordinates(rank)) — confirm.
    coords: Dict[str, int]
    hostname: str
    actor_id: ActorId
    # Last reported location of the remote debugger, if any.
    function: str | None
    lineno: int | None
|
52
|
+
|
53
|
+
|
54
|
+
class DebugSession:
    """Represents a single session with a remote debugger.

    The remote debugger pushes events into the session. ``debugger_write``
    enqueues output, and ``debugger_read`` enqueues a "read" request and then
    waits until a line of input is available. ``attach`` runs an event loop
    that prints output and answers reads. It answers either interactively
    from stdin or with one pre-supplied line (as ``cast`` does).
    """

    def __init__(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
    ):
        self.rank = rank
        self.coords = coords
        self.hostname = hostname
        self.actor_id = actor_id
        # True while ``attach`` is running for this session.
        self._active = False
        # Events for the event loop: "read", "detach" or ("write", DebuggerWrite).
        self._message_queue = asyncio.Queue()
        self._task = None
        # Encoded input lines waiting to be returned by ``debugger_read``.
        self._pending_send_to_actor = asyncio.Queue()
        # Writes received since the last input line; replayed on re-attach.
        self._outputs_since_last_input = []
        # Most recent (function, lineno) reported by the remote debugger.
        self._function_lineno = None
        # True when a "read" was consumed from the queue without being answered.
        self._need_read = False

    async def _event_loop(self, line=None, suppress_output=False):
        """Process session events until detached or until ``line`` has been sent.

        Args:
            line: If given, answer the next "read" with this line and return.
                Otherwise prompt the user on stdin for every read.
            suppress_output: If True, print nothing to stdout.
        """
        if not suppress_output:
            # If the user had previously attached to this debug session,
            # then it would have printed various messages from the
            # message queue. When the user re-attaches, we want to
            # print out all of the output that was printed since the
            # last command sent to this session.
            for output in self._outputs_since_last_input:
                _debugger_output(output.payload.decode())

        while True:
            # When the user inputs "detach", it uses up a "read" message
            # without actually responding to the actor being debugged. We
            # can't manually reinsert the "read" message into the message queue,
            # so instead the self._need_read flag indicates there's an additional
            # "read" that we need to respond to.
            if self._need_read:
                self._need_read = False
                message = "read"
            else:
                message = await self._message_queue.get()
            if message == "detach":
                # Return to the main outer debug loop.
                break
            elif message == "read":
                break_after = line is not None
                if line is None:
                    try:
                        line = await _debugger_input()
                    except BaseException:
                        # The remote actor is still blocked waiting for an answer
                        # to this read. Remember it so the next attach answers it
                        # instead of waiting forever for a "read" that never comes.
                        self._need_read = True
                        raise
                if line.strip() == "detach":
                    self._need_read = True
                    break
                self._outputs_since_last_input = []
                await self._pending_send_to_actor.put((line + "\n").encode())
                line = None
                if break_after:
                    break
            elif message[0] == "write":
                output = message[1]
                # If the user sees this output but then detaches from the session,
                # it's useful to store all outputs since the last input so that
                # they can be printed again when the user re-attaches.
                self._outputs_since_last_input.append(output)
                if not suppress_output:
                    _debugger_output(output.payload.decode())

        if not suppress_output:
            print(
                f"Detaching from debug session for rank {self.rank} ({self.hostname})"
            )

    def get_info(self):
        """Return a DebugSessionInfo; function/lineno are None until reported."""
        function = lineno = None
        if self._function_lineno is not None:
            function, lineno = self._function_lineno
        return DebugSessionInfo(
            self.rank, self.coords, self.hostname, self.actor_id, function, lineno
        )

    async def attach(self, line=None, suppress_output=False):
        """Run the session event loop until it detaches.

        Args:
            line: Optional single line to send to the remote debugger
                (non-interactive attach); see ``_event_loop``.
            suppress_output: If True, print nothing to stdout.
        """
        self._active = True
        try:
            if not suppress_output:
                print(
                    f"Attached to debug session for rank {self.rank} ({self.hostname})"
                )
            self._task = asyncio.create_task(self._event_loop(line, suppress_output))
            await self._task
            if not suppress_output:
                print(
                    f"Detached from debug session for rank {self.rank} ({self.hostname})"
                )
        finally:
            # Always clear the flag, even if the loop raised or was cancelled,
            # so ``detach`` never enqueues a stale "detach" for a dead loop.
            self._active = False

    async def detach(self):
        """Ask the running event loop (if attached) to return to the outer prompt."""
        if self._active:
            await self._message_queue.put("detach")

    async def debugger_read(self, size: int) -> DebuggerWrite:
        """Request a line of input for the remote debugger and wait for it.

        The returned payload is truncated to at most ``size`` bytes.
        """
        await self._message_queue.put("read")
        input_data = await self._pending_send_to_actor.get()
        if len(input_data) > size:
            input_data = input_data[:size]
        return DebuggerWrite(input_data, None, None)

    async def debugger_write(self, write: DebuggerWrite) -> None:
        """Queue output from the remote debugger and record its location, if given."""
        if write.function is not None and write.lineno is not None:
            self._function_lineno = (write.function, write.lineno)
        await self._message_queue.put(("write", write))
|
158
|
+
|
159
|
+
|
160
|
+
# A rank selector as produced by the ``ranks(...)`` grammar rule. It is one of:
# a single rank, an explicit list of ranks, a range of ranks, or a mapping
# from dimension name to any of those. The mapping form is matched against
# each session's ``coords``.
RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]]
|
161
|
+
|
162
|
+
|
163
|
+
_debug_input_parser = None


# Wrap the parser in a function so that jobs don't have to import lark
# unless they want to use the debugger.
def _get_debug_input_parser():
    """Return the lark parser for ``monarch_dbg>`` input, building it on first use.

    The parser is cached in the module-level ``_debug_input_parser``. Its start
    rule is ``command``.
    """
    global _debug_input_parser
    if _debug_input_parser is not None:
        return _debug_input_parser

    from lark import Lark

    grammar = """
            rank_list: INT "," INT ("," INT)*
            start: INT?
            stop: INT?
            step: INT?
            rank_range: start ":" stop (":" step)?
            dim: CNAME "=" (rank_range | "(" rank_list ")" | INT)
            dims: dim ("," dim)*
            ranks: "ranks(" (dims | rank_range | rank_list | INT) ")"
            pdb_command: /\\w+.*/
            cast: "cast" ranks pdb_command
            help: "h" | "help"
            attach: ("a" | "attach") INT
            cont: "c" | "continue"
            quit: "q" | "quit"
            list: "l" | "list"
            command: attach | list | cast | help | cont | quit

            %import common.INT
            %import common.CNAME
            %import common.WS
            %ignore WS
        """
    _debug_input_parser = Lark(grammar, start="command")
    return _debug_input_parser
|
200
|
+
|
201
|
+
|
202
|
+
_debug_input_transformer = None


# Wrap the transformer in a function so that jobs don't have to import lark
# unless they want to use the debugger.
def _get_debug_input_transformer():
    """Return the lark Transformer that turns a parse tree into a DebugCommand.

    The transformer is built on first use and cached in the module-level
    ``_debug_input_transformer``. Each method of the transformer is named
    after the grammar rule it reduces.
    """
    global _debug_input_transformer
    if _debug_input_transformer is not None:
        return _debug_input_transformer

    from lark import Transformer
    from lark.lexer import Token

    def _optional_int(items: List[Token], default: int) -> int:
        # start/stop/step are optional INTs; use ``default`` when omitted.
        return int(items[0].value) if items else default

    class _IntoDebugCommandTransformer(Transformer):
        def rank_list(self, items: List[Token]) -> List[int]:
            return [int(tok.value) for tok in items]

        def start(self, items: List[Token]) -> int:
            return _optional_int(items, 0)

        def stop(self, items: List[Token]) -> int:
            return _optional_int(items, sys.maxsize)

        def step(self, items: List[Token]) -> int:
            return _optional_int(items, 1)

        def rank_range(self, items: List[int]) -> range:
            # items is [start, stop] or [start, stop, step].
            return range(*items)

        def dim(
            self, items: Tuple[Token, Union[range, List[int], Token]]
        ) -> Tuple[str, Union[range, List[int], int]]:
            name, value = items[0].value, items[1]
            if isinstance(value, (range, list)):
                return (name, value)
            # A bare INT token selects a single coordinate.
            return (name, int(cast(Token, value).value))

        def dims(
            self, items: List[Tuple[str, Union[range, List[int], int]]]
        ) -> Dict[str, Union[range, List[int], int]]:
            return dict(items)

        def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType:
            selector = items[0]
            if isinstance(selector, Token):
                return int(cast(Token, selector).value)
            return cast(RanksType, selector)

        def pdb_command(self, items: List[Token]) -> str:
            return items[0].value

        def help(self, _items: List[Token]) -> "Help":
            return Help()

        def attach(self, items: List[Token]) -> "Attach":
            return Attach(int(items[0].value))

        def cont(self, _items: List[Token]) -> "Continue":
            return Continue()

        def quit(self, _items: List[Token]) -> "Quit":
            return Quit()

        def cast(self, items: Tuple[RanksType, str]) -> "Cast":
            return Cast(items[0], items[1])

        def list(self, items: List[Token]) -> "ListCommand":
            return ListCommand()

        def command(self, items: List["DebugCommand"]) -> "DebugCommand":
            return items[0]

    _debug_input_transformer = _IntoDebugCommandTransformer()
    return _debug_input_transformer
|
281
|
+
|
282
|
+
|
283
|
+
class DebugCommand:
    """Base class for commands entered at the ``monarch_dbg>`` prompt."""

    @staticmethod
    def parse(line: str) -> Union["DebugCommand", None]:
        """Parse ``line`` into a DebugCommand.

        Any failure (including a parse error) is printed and reported by
        returning None instead of raising.
        """
        try:
            parsed_tree = _get_debug_input_parser().parse(line)
            command = _get_debug_input_transformer().transform(parsed_tree)
        except Exception as err:
            print(f"Error parsing input: {err}")
            return None
        return command
|
292
|
+
|
293
|
+
|
294
|
+
@dataclass
class Attach(DebugCommand):
    """``attach <rank>`` / ``a <rank>``: interactively attach to one rank's session."""

    # Rank whose debug session should be attached.
    rank: int
|
297
|
+
|
298
|
+
|
299
|
+
@dataclass
class ListCommand(DebugCommand):
    """``list`` / ``l``: show a table of all debug sessions."""

    pass
|
302
|
+
|
303
|
+
|
304
|
+
@dataclass
class Quit(DebugCommand):
    """``quit`` / ``q``: exit the monarch debugger, leaving all sessions in place."""

    pass
|
307
|
+
|
308
|
+
|
309
|
+
@dataclass
class Help(DebugCommand):
    """``help`` / ``h``: print the list of monarch_dbg commands."""

    pass
|
312
|
+
|
313
|
+
|
314
|
+
@dataclass
class Continue(DebugCommand):
    """``continue`` / ``c``: clear breakpoints, continue all ranks, then exit."""

    pass
|
317
|
+
|
318
|
+
|
319
|
+
@dataclass
class Cast(DebugCommand):
    """``cast ranks(...) <command>``: send one pdb command to a set of ranks."""

    # Which ranks receive the command (see RanksType for the accepted forms).
    ranks: RanksType
    # Raw pdb command text sent to every selected rank.
    command: str
|
323
|
+
|
324
|
+
|
325
|
+
class DebugClient(Actor):
    """
    Single actor for both remote debuggers and users to talk to.

    Handles multiple sessions simultaneously, keyed by rank.
    """

    def __init__(self) -> None:
        self.sessions = {}  # rank -> DebugSession

    @endpoint
    async def wait_pending_session(self):
        """Poll once per second until at least one debug session exists."""
        while len(self.sessions) == 0:
            await asyncio.sleep(1)

    @endpoint
    async def list(
        self,
    ) -> List[Tuple[int, Dict[str, int], str, ActorId, str | None, int | None]]:
        """Print a grid of all debug sessions and return its rows sorted by rank.

        Each row is (rank, coords, hostname, actor_id, function, lineno).
        ``function`` and ``lineno`` are None for a rank that has not yet
        reported a location.
        """
        session_info = []
        for session in self.sessions.values():
            info = session.get_info()
            session_info.append(
                (
                    info.rank,
                    info.coords,
                    info.hostname,
                    info.actor_id,
                    info.function,
                    info.lineno,
                )
            )
        table_info = sorted(session_info, key=lambda r: r[0])

        from tabulate import tabulate

        print(
            tabulate(
                table_info,
                headers=[
                    "Rank",
                    "Coords",
                    "Hostname",
                    "Actor ID",
                    "Function",
                    "Line No.",
                ],
                tablefmt="grid",
            )
        )
        return table_info

    @endpoint
    async def enter(self) -> None:
        """Run the interactive ``monarch_dbg>`` prompt.

        Returns after ``continue``, after ``quit``, or when stdin reaches EOF.
        """
        await asyncio.sleep(0.5)
        logger.info("Remote breakpoint hit. Entering monarch debugger...")
        print("\n\n************************ MONARCH DEBUGGER ************************")
        print("Enter 'help' for a list of commands.")
        print("Enter 'list' to show all active breakpoints.\n")

        while True:
            try:
                user_input = await _debugger_input("monarch_dbg> ")
            except EOFError:
                # stdin is at EOF (e.g. Ctrl-D or non-interactive stdin).
                # Re-prompting would fail again immediately and spin forever,
                # so leave the debugger the same way "quit" does.
                print()
                return
            except Exception as e:
                print(f"Error processing command: {e}")
                continue
            if not user_input.strip():
                # A blank line is not a command; don't report a parse error.
                continue
            try:
                command = DebugCommand.parse(user_input)
                if isinstance(command, Help):
                    print("monarch_dbg commands:")
                    print("\tattach <rank> - attach to a debug session")
                    print("\tlist - list all debug sessions")
                    print("\tquit - exit the debugger, leaving all sessions in place")
                    print(
                        "\tcast ranks(...) <command> - send a command to a set of ranks.\n"
                        "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n"
                        "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n"
                        "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))."
                    )
                    print(
                        "\tcontinue - tell all ranks to continue execution, then exit the debugger"
                    )
                    print("\thelp - print this help message")
                elif isinstance(command, Attach):
                    if command.rank not in self.sessions:
                        print(f"No debug session for rank {command.rank}")
                    else:
                        await self.sessions[command.rank].attach()
                elif isinstance(command, ListCommand):
                    # pyre-ignore
                    await self.list._method(self)
                elif isinstance(command, Continue):
                    # Clear all breakpoints and make sure all ranks have
                    # exited their debug sessions. If we sent "quit", it
                    # would raise BdbQuit, crashing the process, which
                    # probably isn't what we want.
                    await self._cast_input_and_wait("clear")
                    while len(self.sessions) > 0:
                        await self._cast_input_and_wait("c")
                    return
                elif isinstance(command, Quit):
                    return
                elif isinstance(command, Cast):
                    await self._cast_input_and_wait(command.command, command.ranks)
            except Exception as e:
                print(f"Error processing command: {e}")

    async def _cast_input_and_wait(
        self,
        command: str,
        ranks: RanksType | None = None,
    ) -> None:
        """Send ``command`` once to each selected rank and wait until all are sent.

        Args:
            command: Line of pdb input to deliver.
            ranks: Rank selector; None means every current session.
        """
        if ranks is None:
            ranks = self.sessions.keys()
        elif isinstance(ranks, dict):
            ranks = self._iter_ranks_dict(ranks)
        elif isinstance(ranks, range):
            ranks = self._iter_ranks_range(ranks)
        elif isinstance(ranks, int):
            ranks = [ranks]
        tasks = []
        # dict.fromkeys drops duplicate ranks (e.g. ranks(1,1)) while keeping
        # order. Otherwise two concurrent attaches on one session would
        # deliver the command twice.
        for rank in dict.fromkeys(ranks):
            if rank in self.sessions:
                tasks.append(
                    self.sessions[rank].attach(
                        command,
                        suppress_output=True,
                    )
                )
            else:
                print(f"No debug session for rank {rank}")
        await asyncio.gather(*tasks)

    def _iter_ranks_dict(
        self, dims: Dict[str, Union[range, List[int], int]]
    ) -> Generator[int, None, None]:
        """Yield ranks whose coords match every ``dim -> selector`` entry in ``dims``.

        A session lacking one of the named dims is excluded.
        """
        for rank, session in self.sessions.items():
            include_rank = True
            for dim, ranks in dims.items():
                if dim not in session.coords:
                    include_rank = False
                    break
                elif (
                    isinstance(ranks, range) or isinstance(ranks, list)
                ) and session.coords[dim] not in ranks:
                    include_rank = False
                    break
                elif isinstance(ranks, int) and session.coords[dim] != ranks:
                    include_rank = False
                    break
            if include_rank:
                yield rank

    def _iter_ranks_range(self, rng: range) -> Generator[int, None, None]:
        """Yield the ranks of existing sessions that fall inside ``rng``."""
        for rank in self.sessions.keys():
            if rank in rng:
                yield rank

    ##########################################################################
    # Debugger APIs
    #
    # These endpoints are called by the remote debuggers to establish sessions
    # and communicate with them.
    @endpoint
    async def debugger_session_start(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
    ) -> None:
        """Register a debug session for ``rank``; an existing session is kept as-is."""
        # Create a session if it doesn't exist
        if rank not in self.sessions:
            self.sessions[rank] = DebugSession(rank, coords, hostname, actor_id)

    @endpoint
    async def debugger_session_end(self, rank: int) -> None:
        """Remove the session for ``rank`` and detach from it if attached."""
        session = self.sessions.pop(rank, None)
        if session is None:
            # Nothing to tear down; don't fail the remote caller over it.
            logger.warning(
                "debugger_session_end called for rank %s with no debug session", rank
            )
            return
        await session.detach()

    @endpoint
    async def debugger_read(self, rank: int, size: int) -> DebuggerWrite | str:
        """Read from the debug session for the given rank."""
        session = self.sessions[rank]

        return await session.debugger_read(size)

    @endpoint
    async def debugger_write(self, rank: int, write: DebuggerWrite) -> None:
        """Write to the debug session for the given rank."""
        session = self.sessions[rank]
        await session.debugger_write(write)
|
508
|
+
|
509
|
+
|
510
|
+
class DebugManager(Actor):
    """Actor that hands out the DebugClient it was constructed with.

    It lives on the current proc under the well-known name
    ``_DEBUG_MANAGER_ACTOR_NAME`` (index 0), so ``ref()`` can address it
    directly from the MonarchContext.
    """

    @staticmethod
    @functools.cache
    def ref() -> "DebugManager":
        """Return a cached mesh reference to the DebugManager on the current proc."""
        ctx = MonarchContext.get()
        manager_id = ActorId.from_string(
            f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]"
        )
        mesh_impl = _ActorMeshRefImpl.from_actor_id(ctx.mailbox, manager_id)
        mesh_ref = ActorMeshRef(DebugManager, mesh_impl, ctx.mailbox)
        return cast(DebugManager, mesh_ref)

    def __init__(self, debug_client: DebugClient) -> None:
        self._debug_client = debug_client

    # pyre-ignore
    @endpoint
    def get_debug_client(self) -> DebugClient:
        """Return the DebugClient passed to the constructor."""
        return self._debug_client
|
536
|
+
|
537
|
+
|
538
|
+
def remote_breakpointhook():
    """Breakpoint hook that routes pdb through the monarch DebugClient.

    It starts a PdbWrapper on the caller's frame. The wrapper is bound to
    this actor's rank, coordinates and actor id, and talks to the DebugClient
    obtained from DebugManager.

    Raises:
        NotImplementedError: if the breakpoint lives in ``__main__`` and its
            source file does not exist on this host.
    """
    this_frame = inspect.currentframe()
    assert this_frame is not None
    caller = this_frame.f_back
    del this_frame  # don't keep a self-referencing frame alive
    assert caller is not None
    filename = caller.f_code.co_filename
    lineno = caller.f_lineno
    module_name = caller.f_globals.get("__name__", "__main__")
    if module_name == "__main__" and not os.path.exists(filename):
        raise NotImplementedError(
            f"Remote debugging not supported for breakpoint at {filename}:{lineno} because "
            "it is defined inside __main__, and the file does not exist on the host. "
            "In this case, cloudpickle serialization does not interact nicely with pdb. "
            "To debug your code, move it out of __main__ and into a module that "
            "exists on both your client and worker processes."
        )

    with fake_sync_state():
        client = DebugManager.ref().get_debug_client.call_one().get()
    ctx = MonarchContext.get()
    point = ctx.point
    wrapper = PdbWrapper(
        point.rank,
        point.shape.coordinates(point.rank),
        ctx.mailbox.actor_id,
        client,
    )
    DebugContext.set(DebugContext(wrapper))
    wrapper.set_trace(caller)
|