torchmonarch-nightly 2025.9.9__cp312-cp312-manylinux2014_x86_64.whl → 2025.9.11__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +7 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_src/actor/actor_mesh.py +1 -1
- monarch/_src/actor/bootstrap_main.py +7 -2
- monarch/_src/actor/debugger/breakpoint.py +30 -0
- monarch/_src/actor/debugger/debug_command.py +183 -0
- monarch/_src/actor/debugger/debug_controller.py +246 -0
- monarch/_src/actor/debugger/debug_io.py +68 -0
- monarch/_src/actor/debugger/debug_session.py +249 -0
- monarch/_src/actor/debugger/pdb_wrapper.py +1 -1
- monarch/_src/actor/host_mesh.py +10 -2
- monarch/_src/actor/pickle.py +4 -10
- monarch/_src/actor/proc_mesh.py +80 -19
- monarch/_src/tensor_engine/rdma.py +2 -0
- monarch/actor/__init__.py +1 -1
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/monarch_controller +0 -0
- monarch/tools/cli.py +26 -0
- monarch/tools/commands.py +15 -0
- monarch/tools/debug_env.py +34 -0
- monarch/tools/mesh_spec.py +2 -0
- tests/test_allocator.py +18 -9
- tests/test_debugger.py +29 -25
- tests/test_mock_cuda.py +11 -3
- torchmonarch_nightly-2025.9.11.data/scripts/process_allocator +0 -0
- {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/METADATA +1 -1
- {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/RECORD +31 -29
- monarch/_src/actor/debugger/debugger.py +0 -737
- monarch/_src/debug_cli/__init__.py +0 -7
- monarch/_src/debug_cli/debug_cli.py +0 -43
- monarch/debug_cli/__init__.py +0 -7
- monarch/debug_cli/__main__.py +0 -12
- {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/entry_points.txt +0 -0
- {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/top_level.txt +0 -0
monarch/__init__.py
CHANGED
@@ -9,6 +9,13 @@
|
|
9
9
|
from importlib import import_module as _import_module
|
10
10
|
from typing import TYPE_CHECKING
|
11
11
|
|
12
|
+
# Import before monarch to pre-load torch DSOs as, in exploded wheel flows,
|
13
|
+
# our RPATHs won't correctly find them.
|
14
|
+
try:
|
15
|
+
import torch # noqa: F401
|
16
|
+
except ImportError:
|
17
|
+
pass
|
18
|
+
|
12
19
|
# submodules of monarch should not be imported in this
|
13
20
|
# top-level file because it will cause them to get
|
14
21
|
# loaded even if they are not actually being used.
|
monarch/_rust_bindings.so
CHANGED
Binary file
|
monarch/_src/actor/actor_mesh.py
CHANGED
@@ -953,7 +953,7 @@ class _Actor:
|
|
953
953
|
DebugContext.set(DebugContext())
|
954
954
|
|
955
955
|
def _post_mortem_debug(self, exc_tb) -> None:
|
956
|
-
from monarch._src.actor.debugger.
|
956
|
+
from monarch._src.actor.debugger.debug_controller import debug_controller
|
957
957
|
|
958
958
|
if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
|
959
959
|
with fake_sync_state():
|
@@ -17,6 +17,12 @@ import multiprocessing
|
|
17
17
|
import os
|
18
18
|
import sys
|
19
19
|
|
20
|
+
# Import torch to avoid import-time races if a spawned actor tries to import torch.
|
21
|
+
try:
|
22
|
+
import torch # @manual # noqa: F401
|
23
|
+
except ImportError:
|
24
|
+
pass
|
25
|
+
|
20
26
|
|
21
27
|
async def main():
|
22
28
|
from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main
|
@@ -32,7 +38,6 @@ def invoke_main():
|
|
32
38
|
global bootstrap_main
|
33
39
|
|
34
40
|
# TODO: figure out what from worker_main.py we should reproduce here.
|
35
|
-
|
36
41
|
from monarch._src.actor.telemetry import TracingForwarder # noqa
|
37
42
|
|
38
43
|
if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
|
@@ -56,7 +61,7 @@ def invoke_main():
|
|
56
61
|
except Exception as e:
|
57
62
|
logging.warning(f"Failed to set up py-spy: {e}")
|
58
63
|
|
59
|
-
from monarch._src.actor.debugger.
|
64
|
+
from monarch._src.actor.debugger.breakpoint import remote_breakpointhook
|
60
65
|
|
61
66
|
sys.breakpointhook = remote_breakpointhook
|
62
67
|
|
@@ -0,0 +1,30 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import inspect
|
9
|
+
|
10
|
+
from monarch._src.actor.actor_mesh import context, DebugContext
|
11
|
+
from monarch._src.actor.debugger.debug_controller import debug_controller
|
12
|
+
from monarch._src.actor.debugger.pdb_wrapper import PdbWrapper
|
13
|
+
|
14
|
+
|
15
|
+
def remote_breakpointhook(*args, **kwargs) -> None:
    """Breakpoint hook that opens a remote pdb session in the calling actor.

    Intended to be installed as ``sys.breakpointhook`` so that ``breakpoint()``
    inside actor code lands here. A ``PdbWrapper`` is built for the current
    message's rank and actor, registered in the ``DebugContext``, and pdb is
    started in the frame that called ``breakpoint()``.

    Args:
        *args, **kwargs: ``breakpoint()`` forwards all of its arguments to
            ``sys.breakpointhook``. They are accepted so that calls such as
            ``breakpoint(header=...)`` do not raise ``TypeError``, but they are
            otherwise ignored.
    """
    # Skip this hook's own frame: the caller of breakpoint() is where the user
    # expects pdb to stop.
    frame = inspect.currentframe()
    assert frame is not None
    frame = frame.f_back
    assert frame is not None

    ctx = context()
    rank = ctx.message_rank
    pdb_wrapper = PdbWrapper(
        rank.rank,
        # Coordinates of this rank, one entry per dimension name.
        {k: rank[k] for k in rank},
        ctx.actor_instance.actor_id,
        debug_controller(),
    )
    DebugContext.set(DebugContext(pdb_wrapper))
    pdb_wrapper.set_trace(frame)
|
@@ -0,0 +1,183 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import sys
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from typing import cast, Dict, List, Tuple, Union
|
11
|
+
|
12
|
+
from monarch._src.actor.debugger.debug_io import DebugIO
|
13
|
+
|
14
|
+
# Rank selection accepted by the ``cast`` debugger command: a single rank, an
# explicit list of ranks, a range of ranks, or a mapping from dimension name
# to any of those.
RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]]
|
15
|
+
|
16
|
+
_debug_input_parser = None


# Wrap the parser in a function so that jobs don't have to import lark
# unless they want to use the debugger.
def _get_debug_input_parser():
    """Return the lark parser for lines typed at the debugger prompt.

    The parser is built on first use and memoized in the module-level
    ``_debug_input_parser``; its entry rule is ``command``.
    """
    global _debug_input_parser
    if _debug_input_parser is not None:
        return _debug_input_parser

    from lark import Lark

    grammar = """
rank_list: INT "," INT ("," INT)*
start: INT?
stop: INT?
step: INT?
rank_range: start ":" stop (":" step)?
dim: CNAME "=" (rank_range | "(" rank_list ")" | INT)
dims: dim ("," dim)*
ranks: "ranks(" (dims | rank_range | rank_list | INT) ")"
pdb_command: /\\w+.*/
actor_name: /[-_a-zA-Z0-9]+/
cast: "cast" _WS actor_name ranks pdb_command
help: "h" | "help"
attach: ("a" | "attach") _WS actor_name INT
cont: "c" | "continue"
quit: "q" | "quit"
list: "l" | "list"
command: attach | list | cast | help | cont | quit

_WS: WS+

%import common.INT
%import common.CNAME
%import common.WS
%ignore WS
"""
    _debug_input_parser = Lark(grammar, start="command")
    return _debug_input_parser
|
56
|
+
|
57
|
+
|
58
|
+
_debug_input_transformer = None


# Wrap the transformer in a function so that jobs don't have to import lark
# unless they want to use the debugger.
def _get_debug_input_transformer():
    """Return the lark Transformer that turns parse trees into DebugCommands.

    The transformer is built on first use and memoized in the module-level
    ``_debug_input_transformer``.
    """
    global _debug_input_transformer
    if _debug_input_transformer is not None:
        return _debug_input_transformer

    from lark import Transformer
    from lark.lexer import Token

    def _optional_int(items: List[Token], fallback: int) -> int:
        # An omitted slice bound still produces its rule node, just with no
        # children; use the fallback in that case.
        return int(items[0].value) if items else fallback

    class _IntoDebugCommandTransformer(Transformer):
        # Each method is named after a grammar rule; lark invokes it with that
        # rule's (already transformed) children.

        def rank_list(self, items: List[Token]) -> List[int]:
            return [int(tok.value) for tok in items]

        def start(self, items: List[Token]) -> int:
            return _optional_int(items, 0)

        def stop(self, items: List[Token]) -> int:
            return _optional_int(items, sys.maxsize)

        def step(self, items: List[Token]) -> int:
            return _optional_int(items, 1)

        def rank_range(self, items: List[int]) -> range:
            return range(*items)

        def dim(
            self, items: Tuple[Token, Union[range, List[int], Token]]
        ) -> Tuple[str, Union[range, List[int], int]]:
            name_tok, value = items[0], items[1]
            # Ranges and lists arrive already built; a bare INT is still a Token.
            if not isinstance(value, (range, list)):
                value = int(value.value)
            return (name_tok.value, value)

        def dims(
            self, items: List[Tuple[str, Union[range, List[int], int]]]
        ) -> Dict[str, Union[range, List[int], int]]:
            return dict(items)

        def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType:
            selection = items[0]
            if isinstance(selection, Token):
                # ranks(<INT>) selects a single rank.
                return int(selection.value)
            return selection

        def pdb_command(self, items: List[Token]) -> str:
            return items[0].value

        def actor_name(self, items: List[Token]) -> str:
            return items[0].value

        def help(self, _items: List[Token]) -> "Help":
            return Help()

        def attach(self, items: Tuple[str, Token]) -> "Attach":
            actor, rank_tok = items[0], items[1]
            return Attach(actor, int(rank_tok.value))

        def cont(self, _items: List[Token]) -> "Continue":
            return Continue()

        def quit(self, _items: List[Token]) -> "Quit":
            return Quit()

        def cast(self, items: Tuple[str, RanksType, str]) -> "Cast":
            return Cast(*items)

        def list(self, items: List[Token]) -> "ListCommand":
            return ListCommand()

        def command(self, items: List["DebugCommand"]) -> "DebugCommand":
            return items[0]

    _debug_input_transformer = _IntoDebugCommandTransformer()
    return _debug_input_transformer
|
140
|
+
|
141
|
+
|
142
|
+
class DebugCommand:
    """Base class for commands entered at the debugger prompt."""

    @staticmethod
    async def parse(debug_io: DebugIO, line: str) -> Union["DebugCommand", None]:
        """Parse one line of user input into a ``DebugCommand``.

        Any parsing or transformation error is reported on ``debug_io`` and
        ``None`` is returned instead of raising.
        """
        try:
            parser = _get_debug_input_parser()
            tree = parser.parse(line)
            transformer = _get_debug_input_transformer()
            command = transformer.transform(tree)
        except Exception as err:
            await debug_io.output(f"Error parsing input: {err}\n")
            return None
        return command
|
151
|
+
|
152
|
+
|
153
|
+
@dataclass
class Attach(DebugCommand):
    """``attach <actor_name> <rank>``: interactively attach to one debug session."""

    # Name of the actor whose session should be attached.
    actor_name: str
    # Rank of that actor's session.
    rank: int
|
157
|
+
|
158
|
+
|
159
|
+
@dataclass
class ListCommand(DebugCommand):
    """``list``: show every active debug session."""
|
162
|
+
|
163
|
+
|
164
|
+
@dataclass
class Quit(DebugCommand):
    """``quit``: leave the debugger prompt, keeping sessions in place."""
|
167
|
+
|
168
|
+
|
169
|
+
@dataclass
class Help(DebugCommand):
    """``help``: print the list of debugger commands."""
|
172
|
+
|
173
|
+
|
174
|
+
@dataclass
class Continue(DebugCommand):
    """``continue``: clear breakpoints and resume every session."""
|
177
|
+
|
178
|
+
|
179
|
+
@dataclass
class Cast(DebugCommand):
    """``cast <actor_name> ranks(...) <command>``: send a pdb command to many ranks."""

    # Actor mesh whose sessions receive the command.
    actor_name: str
    # Which ranks of that actor to target.
    ranks: RanksType
    # Raw pdb command text forwarded to each selected session.
    command: str
|
@@ -0,0 +1,246 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import asyncio
|
9
|
+
import functools
|
10
|
+
from typing import Dict, List, Optional, Tuple
|
11
|
+
|
12
|
+
from monarch._src.actor.actor_mesh import Actor
|
13
|
+
from monarch._src.actor.debugger.debug_command import (
|
14
|
+
Attach,
|
15
|
+
Cast,
|
16
|
+
Continue,
|
17
|
+
DebugCommand,
|
18
|
+
Help,
|
19
|
+
ListCommand,
|
20
|
+
Quit,
|
21
|
+
RanksType,
|
22
|
+
)
|
23
|
+
from monarch._src.actor.debugger.debug_io import (
|
24
|
+
DebugCliIO,
|
25
|
+
DebugIO,
|
26
|
+
DebugIOError,
|
27
|
+
DebugStdIO,
|
28
|
+
)
|
29
|
+
from monarch._src.actor.debugger.debug_session import (
|
30
|
+
DebugSession,
|
31
|
+
DebugSessionInfo,
|
32
|
+
DebugSessions,
|
33
|
+
)
|
34
|
+
from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite
|
35
|
+
from monarch._src.actor.endpoint import endpoint
|
36
|
+
from monarch._src.actor.proc_mesh import get_or_spawn_controller
|
37
|
+
from monarch._src.actor.sync_state import fake_sync_state
|
38
|
+
from monarch.tools.debug_env import (
|
39
|
+
_get_debug_server_host,
|
40
|
+
_get_debug_server_port,
|
41
|
+
_get_debug_server_protocol,
|
42
|
+
)
|
43
|
+
from pyre_extensions import none_throws
|
44
|
+
from tabulate import tabulate
|
45
|
+
|
46
|
+
|
47
|
+
class DebugController(Actor):
    """
    Single actor for both remote debuggers and users to talk to.

    Handles multiple sessions simultaneously. Remote ranks register, read from,
    and write to their sessions through the ``debugger_*`` endpoints. The user
    drives the ``monarch_dbg>`` prompt (see ``_enter``) over ``self._debug_io``,
    which starts out as ``DebugStdIO`` and is replaced by a ``DebugCliIO`` each
    time a client connects to the server started by ``_serve``.
    """

    def __init__(self) -> None:
        # Registered debug sessions; looked up by (actor_name, rank) below.
        self.sessions = DebugSessions()
        # Held while handing the prompt over to a newly connected client.
        self._task_lock = asyncio.Lock()
        # Task running the prompt loop (_enter) for the attached client, if any.
        self._task: asyncio.Task | None = None
        # Channel used by the prompt loop and by `list` for user-facing I/O.
        self._debug_io: DebugIO = DebugStdIO()
        # Resolved with the asyncio server once it is listening, or failed with
        # the error that stopped it; debugger_session_start awaits this.
        self._server = asyncio.Future()
        self._server_task = asyncio.create_task(self._serve())

    async def _serve(self) -> None:
        """Run the TCP server that external debug clients connect to.

        Protocol, host and port come from ``monarch.tools.debug_env``; only
        "tcp" is supported. On success ``self._server`` is resolved with the
        server; on failure it is failed with the exception, which is also
        re-raised from this task.
        """
        try:
            if (proto := _get_debug_server_protocol()) != "tcp":
                raise NotImplementedError(
                    f"Network protocol {proto} not yet supported."
                )
            server = await asyncio.start_server(
                self._handle_client,
                _get_debug_server_host(),
                _get_debug_server_port(),
            )
            async with server:
                self._server.set_result(server)
                await server.serve_forever()
        except Exception as e:
            # If the server had already started, the future is resolved;
            # replace it so set_exception below does not fail on a done future.
            if self._server.done():
                self._server = asyncio.Future()
            self._server.set_exception(e)
            raise

    async def _handle_client(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Connection callback for the debug server: give the prompt to the new client."""
        # Make sure only one external debug process can
        # be attached at a time. If a new request is
        # received, the current task is cancelled.
        async with self._task_lock:
            if self._task is not None:
                self._task.cancel()
                try:
                    await none_throws(self._task)
                except (DebugIOError, asyncio.CancelledError):
                    # Expected ways for the previous prompt loop to end.
                    pass
            # NOTE(review): the previous client's writer is not explicitly
            # closed here; it is only dropped when _debug_io is replaced.
            self._debug_io = DebugCliIO(reader, writer)
            self._task = asyncio.create_task(self._enter())

    @endpoint
    async def wait_pending_session(self):
        """Return once at least one debug session exists, polling every second."""
        while len(self.sessions) == 0:
            await asyncio.sleep(1)

    @endpoint
    async def list(self, print_output=True) -> List[DebugSessionInfo]:
        """Return info for every debug session, ordered with ``sorted()``.

        If ``print_output`` is true, also write the sessions as a grid table to
        the current debug I/O channel.
        """
        session_info = sorted(self.sessions.info())
        if print_output:
            await self._debug_io.output(
                tabulate(
                    (
                        (
                            info.actor_name,
                            info.rank,
                            info.coords,
                            info.hostname,
                            info.function,
                            info.lineno,
                        )
                        for info in session_info
                    ),
                    headers=[
                        "Actor Name",
                        "Rank",
                        "Coords",
                        "Hostname",
                        "Function",
                        "Line No.",
                    ],
                    tablefmt="grid",
                )
                + "\n"
            )
        return session_info

    async def _enter(self) -> None:
        """Interactive ``monarch_dbg>`` prompt loop for the attached client.

        Returns on ``quit``. I/O errors and cancellation propagate so that
        ``_handle_client`` can tear the loop down; any other error is reported
        to the user and the loop continues.
        """
        # NOTE(review): the reason for this short delay isn't visible here —
        # confirm whether it is still needed.
        await asyncio.sleep(0.5)
        await self._debug_io.output(
            "\n\n************************ MONARCH DEBUGGER ************************\n"
        )
        await self._debug_io.output("Enter 'help' for a list of commands.\n")
        await self._debug_io.output("Enter 'list' to show all active breakpoints.\n\n")

        while True:
            try:
                user_input = await self._debug_io.input("monarch_dbg> ")
                if not user_input.strip():
                    continue
                # parse() reports its own errors and returns None, which falls
                # through every branch below.
                command = await DebugCommand.parse(self._debug_io, user_input)
                if isinstance(command, Help):
                    await self._debug_io.output("monarch_dbg commands:\n")
                    await self._debug_io.output(
                        "\tattach <actor_name> <rank> - attach to a debug session\n"
                    )
                    await self._debug_io.output("\tlist - list all debug sessions\n")
                    await self._debug_io.output(
                        "\tquit - exit the debugger, leaving all sessions in place\n"
                    )
                    await self._debug_io.output(
                        "\tcast <actor_name> ranks(...) <command> - send a command to a set of ranks on the specified actor mesh.\n"
                        "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n"
                        "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n"
                        "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6))).\n"
                    )
                    await self._debug_io.output(
                        "\tcontinue - clear all breakpoints and tell all ranks to continue\n"
                    )
                    await self._debug_io.output("\thelp - print this help message\n")
                elif isinstance(command, Attach):
                    await self.sessions.get(command.actor_name, command.rank).attach(
                        self._debug_io
                    )
                elif isinstance(command, ListCommand):
                    # Invoke the undecorated method directly from inside the actor
                    # (presumably what `_method` exposes on an endpoint).
                    # pyre-ignore
                    await self.list._method(self)
                elif isinstance(command, Continue):
                    # Send "clear" then "c" to every session.
                    await self._cast_input_and_wait("clear")
                    await self._cast_input_and_wait("c")
                elif isinstance(command, Quit):
                    await self._debug_io.quit()
                    return
                elif isinstance(command, Cast):
                    await self._cast_input_and_wait(
                        command.command, (command.actor_name, command.ranks)
                    )
            except (DebugIOError, asyncio.CancelledError):
                raise
            except Exception as e:
                await self._debug_io.output(f"Error processing command: {e}\n")

    async def _cast_input_and_wait(
        self,
        command: str,
        selection: Optional[Tuple[str, Optional[RanksType]]] = None,
    ) -> None:
        """Send ``command`` to the sessions chosen by ``selection`` concurrently.

        ``selection`` is passed straight to ``DebugSessions.iter``; ``continue``
        passes ``None`` to target all ranks. Session output is suppressed.
        """
        tasks = []
        for session in self.sessions.iter(selection):
            tasks.append(session.attach(self._debug_io, command, suppress_output=True))
        await asyncio.gather(*tasks)

    ##########################################################################
    # Debugger APIs
    #
    # These endpoints are called by the remote debuggers to establish sessions
    # and communicate with them.
    @endpoint
    async def debugger_session_start(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str
    ) -> None:
        """Register a debug session for (actor_name, rank) if one does not exist."""
        # Good enough for now to ensure that if the server for processing
        # user interactions never starts, then the rank being debugged will
        # fail instead of hanging indefinitely with no way to send it commands.
        # Of course this isn't sufficient to handle the case where the server
        # fails after the rank's debug session has successfully started.
        # TODO: implement a heartbeat to prevent pdb sessions from hanging.
        await self._server
        # Create a session if it doesn't exist
        if (actor_name, rank) not in self.sessions:
            self.sessions.insert(DebugSession(rank, coords, hostname, actor_name))

    @endpoint
    async def debugger_session_end(self, actor_name: str, rank: int) -> None:
        """Remove the debug session for (actor_name, rank) and detach from it."""
        await self.sessions.remove(actor_name, rank).detach()

    @endpoint
    async def debugger_read(
        self, actor_name: str, rank: int, size: int
    ) -> DebuggerWrite | str:
        """Read from the debug session for the given rank."""
        return await self.sessions.get(actor_name, rank).debugger_read(size)

    @endpoint
    async def debugger_write(
        self, actor_name: str, rank: int, write: DebuggerWrite
    ) -> None:
        """Write to the debug session for the given rank."""
        await self.sessions.get(actor_name, rank).debugger_write(write)
240
|
+
|
241
|
+
# Cached so that we don't have to call out to the root client every time,
|
242
|
+
# which may be on a different host.
|
243
|
+
@functools.cache
def debug_controller() -> DebugController:
    """Return a handle to the ``DebugController`` actor named "debug_controller".

    The handle comes from ``get_or_spawn_controller`` and is memoized by
    ``functools.cache``, so the lookup runs at most once per process.
    """
    with fake_sync_state():
        pending = get_or_spawn_controller("debug_controller", DebugController)
        return pending.get()
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import asyncio
|
9
|
+
import sys
|
10
|
+
from abc import abstractmethod
|
11
|
+
|
12
|
+
|
13
|
+
class DebugIO:
    """Text channel between the debug controller and the user.

    Implementations provide prompted line input, raw output and a ``quit``
    hook. NOTE(review): this class does not use ``ABCMeta``, so the
    ``@abstractmethod`` markers document intent but are not enforced.
    """

    @abstractmethod
    async def input(self, prompt: str = "") -> str:
        """Show ``prompt`` and return one line of user input."""

    @abstractmethod
    async def output(self, msg: str) -> None:
        """Write ``msg`` as-is; callers supply their own newlines."""

    @abstractmethod
    async def quit(self) -> None:
        """Release the channel when the user leaves the debugger."""
|
22
|
+
|
23
|
+
|
24
|
+
class DebugStdIO(DebugIO):
    """DebugIO backed by this process's standard input and output."""

    async def input(self, prompt: str = "") -> str:
        # The builtin input() blocks, so run it on a worker thread to keep the
        # event loop responsive.
        line = await asyncio.to_thread(input, prompt)
        return line

    async def output(self, msg: str) -> None:
        # Flush immediately so prompts and tables appear without buffering.
        stream = sys.stdout
        stream.write(msg)
        stream.flush()

    async def quit(self) -> None:
        """Nothing to release: stdin/stdout are left open."""
|
34
|
+
|
35
|
+
|
36
|
+
class DebugIOError(RuntimeError):
    """Raised when reading from or writing to a debugger client fails."""

    def __init__(self):
        message = "Error encountered during debugger I/O operation."
        super().__init__(message)
|
39
|
+
|
40
|
+
|
41
|
+
class DebugCliIO(DebugIO):
    """DebugIO over the asyncio stream pair of a client connected to the debug server.

    Every I/O failure is surfaced as ``DebugIOError``, with the underlying
    exception chained as ``__cause__``.
    """

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self._reader = reader
        self._writer = writer

    async def input(self, prompt: str = "") -> str:
        """Send ``prompt`` to the client and read one line back.

        Raises:
            DebugIOError: if writing the prompt, reading, or decoding fails, or
                the client closed the stream before sending a full line.
        """
        try:
            await self.output(prompt)
            msg = (await self._reader.readline()).decode()
            # Incomplete read due to EOF
            if not msg.endswith("\n"):
                raise RuntimeError("Unexpected end of input.")
            # Strip the newline to be consistent with the behavior of input()
            return msg.strip("\n")
        except Exception as e:
            raise DebugIOError() from e

    async def output(self, msg: str) -> None:
        """Write ``msg`` to the client and wait for the transport to drain.

        Raises:
            DebugIOError: if writing or draining fails.
        """
        try:
            self._writer.write(msg.encode())
            await self._writer.drain()
        except Exception as e:
            raise DebugIOError() from e

    async def quit(self) -> None:
        """Send a goodbye message and close the connection.

        The writer is closed even if the goodbye cannot be delivered (e.g. the
        client already disconnected), so the socket is never left open.

        Raises:
            DebugIOError: if sending the message or waiting for the close fails.
        """
        try:
            await self.output("Quitting debug session...\n")
        finally:
            self._writer.close()
            try:
                await self._writer.wait_closed()
            except Exception as e:
                # Keep the class contract: all I/O failures are DebugIOError.
                raise DebugIOError() from e
|