torchmonarch-nightly 2025.9.9__cp312-cp312-manylinux2014_x86_64.whl → 2025.9.11__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. monarch/__init__.py +7 -0
  2. monarch/_rust_bindings.so +0 -0
  3. monarch/_src/actor/actor_mesh.py +1 -1
  4. monarch/_src/actor/bootstrap_main.py +7 -2
  5. monarch/_src/actor/debugger/breakpoint.py +30 -0
  6. monarch/_src/actor/debugger/debug_command.py +183 -0
  7. monarch/_src/actor/debugger/debug_controller.py +246 -0
  8. monarch/_src/actor/debugger/debug_io.py +68 -0
  9. monarch/_src/actor/debugger/debug_session.py +249 -0
  10. monarch/_src/actor/debugger/pdb_wrapper.py +1 -1
  11. monarch/_src/actor/host_mesh.py +10 -2
  12. monarch/_src/actor/pickle.py +4 -10
  13. monarch/_src/actor/proc_mesh.py +80 -19
  14. monarch/_src/tensor_engine/rdma.py +2 -0
  15. monarch/actor/__init__.py +1 -1
  16. monarch/gradient/_gradient_generator.so +0 -0
  17. monarch/monarch_controller +0 -0
  18. monarch/tools/cli.py +26 -0
  19. monarch/tools/commands.py +15 -0
  20. monarch/tools/debug_env.py +34 -0
  21. monarch/tools/mesh_spec.py +2 -0
  22. tests/test_allocator.py +18 -9
  23. tests/test_debugger.py +29 -25
  24. tests/test_mock_cuda.py +11 -3
  25. torchmonarch_nightly-2025.9.11.data/scripts/process_allocator +0 -0
  26. {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/METADATA +1 -1
  27. {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/RECORD +31 -29
  28. monarch/_src/actor/debugger/debugger.py +0 -737
  29. monarch/_src/debug_cli/__init__.py +0 -7
  30. monarch/_src/debug_cli/debug_cli.py +0 -43
  31. monarch/debug_cli/__init__.py +0 -7
  32. monarch/debug_cli/__main__.py +0 -12
  33. {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/WHEEL +0 -0
  34. {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/entry_points.txt +0 -0
  35. {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/licenses/LICENSE +0 -0
  36. {torchmonarch_nightly-2025.9.9.dist-info → torchmonarch_nightly-2025.9.11.dist-info}/top_level.txt +0 -0
@@ -1,737 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the BSD-style license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- # pyre-unsafe
8
- import asyncio
9
- import functools
10
- import inspect
11
- import logging
12
- import os
13
- import sys
14
- from abc import abstractmethod
15
- from dataclasses import dataclass
16
- from typing import cast, Dict, Generator, List, Optional, Tuple, Union
17
-
18
- from monarch._src.actor.actor_mesh import Actor, context, DebugContext
19
- from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite, PdbWrapper
20
- from monarch._src.actor.endpoint import endpoint
21
- from monarch._src.actor.proc_mesh import get_or_spawn_controller
22
- from monarch._src.actor.sync_state import fake_sync_state
23
- from pyre_extensions import none_throws
24
- from tabulate import tabulate
25
-
26
-
27
- logger = logging.getLogger(__name__)
28
-
29
# Settings for the debug server's listening address. Each setting is an
# environment-variable name paired with the default that is used when that
# variable is not set (see the _get_debug_server_* helpers).
_MONARCH_DEBUG_SERVER_HOST_ENV_VAR = "MONARCH_DEBUG_SERVER_HOST"
_MONARCH_DEBUG_SERVER_HOST_DEFAULT = "localhost"
_MONARCH_DEBUG_SERVER_PORT_ENV_VAR = "MONARCH_DEBUG_SERVER_PORT"
# The default port is a string, the same type an environment lookup yields.
_MONARCH_DEBUG_SERVER_PORT_DEFAULT = "27000"
_MONARCH_DEBUG_SERVER_PROTOCOL_ENV_VAR = "MONARCH_DEBUG_SERVER_PROTOCOL"
_MONARCH_DEBUG_SERVER_PROTOCOL_DEFAULT = "tcp"
35
-
36
-
37
def _get_debug_server_host():
    """Return the debug server host from the environment, or the default."""
    env_var = _MONARCH_DEBUG_SERVER_HOST_ENV_VAR
    fallback = _MONARCH_DEBUG_SERVER_HOST_DEFAULT
    return os.getenv(env_var, fallback)
41
-
42
-
43
def _get_debug_server_port():
    """Return the debug server port (as a string) from the environment, or the default."""
    env_var = _MONARCH_DEBUG_SERVER_PORT_ENV_VAR
    fallback = _MONARCH_DEBUG_SERVER_PORT_DEFAULT
    return os.getenv(env_var, fallback)
47
-
48
-
49
def _get_debug_server_protocol():
    """Return the debug server network protocol from the environment, or the default."""
    env_var = _MONARCH_DEBUG_SERVER_PROTOCOL_ENV_VAR
    fallback = _MONARCH_DEBUG_SERVER_PROTOCOL_DEFAULT
    return os.getenv(env_var, fallback)
53
-
54
-
55
class DebugIO:
    """Interface for the channel the debugger uses to exchange text with a user.

    The class does not derive from ``abc.ABC``, so the ``@abstractmethod``
    markers are advisory: instantiation is not blocked and the default
    bodies simply return ``None``.
    """

    @abstractmethod
    async def input(self, prompt: str = "") -> str:
        """Show ``prompt`` and return the user's reply."""

    @abstractmethod
    async def output(self, msg: str) -> None:
        """Send ``msg`` to the user."""

    @abstractmethod
    async def quit(self) -> None:
        """Shut the channel down."""
64
-
65
-
66
class DebugStdIO(DebugIO):
    """DebugIO implementation backed by the process's own stdin and stdout."""

    async def input(self, prompt: str = "") -> str:
        """Read one line via the builtin ``input`` on a worker thread."""
        # input() blocks, so it is pushed to a thread to keep the event loop free.
        reply = await asyncio.to_thread(input, prompt)
        return reply

    async def output(self, msg: str) -> None:
        """Write ``msg`` to stdout and flush immediately."""
        stream = sys.stdout
        stream.write(msg)
        stream.flush()

    async def quit(self) -> None:
        """Nothing to release for stdio."""
        return None
76
-
77
-
78
class DebugIOError(RuntimeError):
    """Raised when a debugger I/O operation fails; always carries a fixed message."""

    def __init__(self):
        message = "Error encountered during debugger I/O operation."
        super().__init__(message)
81
-
82
-
83
class DebugCliIO(DebugIO):
    """DebugIO implementation that talks to a client over an asyncio stream pair."""

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self._reader = reader
        self._writer = writer

    async def input(self, prompt: str = "") -> str:
        """Send ``prompt`` and return the next line from the client, without its newline.

        Raises:
            DebugIOError: if writing the prompt or reading fails, or if the
                stream ends before a complete line has been received.
        """
        try:
            await self.output(prompt)
            msg = (await self._reader.readline()).decode()
            # Incomplete read due to EOF
            if not msg.endswith("\n"):
                raise RuntimeError("Unexpected end of input.")
            # Strip the newline to be consistent with the behavior of input()
            return msg.strip("\n")
        except Exception as e:
            raise DebugIOError() from e

    async def output(self, msg: str) -> None:
        """Write ``msg`` to the client and wait for the write buffer to drain.

        Raises:
            DebugIOError: if the write or the drain fails.
        """
        try:
            self._writer.write(msg.encode())
            await self._writer.drain()
        except Exception as e:
            raise DebugIOError() from e

    async def quit(self) -> None:
        """Send a farewell message and close the connection.

        The writer is closed even when the farewell message cannot be
        delivered, so a broken client connection does not leak the stream.

        Raises:
            DebugIOError: if the farewell message cannot be written.
        """
        try:
            await self.output("Quitting debug session...\n")
        finally:
            self._writer.close()
            await self._writer.wait_closed()
111
-
112
-
113
- @dataclass
114
- class DebugSessionInfo:
115
- actor_name: str
116
- rank: int
117
- coords: Dict[str, int]
118
- hostname: str
119
- function: str | None
120
- lineno: int | None
121
-
122
- def __lt__(self, other):
123
- if self.actor_name < other.actor_name:
124
- return True
125
- elif self.actor_name == other.actor_name:
126
- return self.rank < other.rank
127
- else:
128
- return False
129
-
130
-
131
class DebugSession:
    """Represents a single session with a remote debugger.

    Two queues connect the actor being debugged with the attached user:
    ``_message_queue`` carries "read" / ("write", payload) / "detach" events
    to the event loop, and ``_pending_send_to_actor`` carries user input
    back to the actor.
    """

    def __init__(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str
    ):
        self.rank = rank
        self.coords = coords
        self.hostname = hostname
        self.actor_name = actor_name
        # True only while attach() is running.
        self._active = False
        self._message_queue = asyncio.Queue()
        self._task = None
        self._pending_send_to_actor = asyncio.Queue()
        # Writes received since the last input was forwarded to the actor;
        # replayed when a user re-attaches.
        self._outputs_since_last_input = []
        # (function, lineno) from the most recent write that carried both.
        self._function_lineno = None
        # Set when a "read" was consumed without being answered.
        self._need_read = False

    async def _event_loop(self, debug_io: DebugIO, line=None, suppress_output=False):
        """Process queued events until the session detaches.

        Args:
            debug_io: channel used to talk to the user.
            line: if given, it is sent in response to the first "read"
                instead of prompting the user, and the loop then exits.
            suppress_output: when true, nothing is written to ``debug_io``.
        """
        if not suppress_output:
            # If the user had previously attached to this debug session,
            # then it would have printed various messages from the
            # message queue. When the user re-attaches, we want to
            # print out all of the output that was printed since the
            # last command sent to this session.
            if len(self._outputs_since_last_input) > 0:
                await debug_io.output(
                    f"<last pdb output for {self.actor_name} {self.rank} follows>\n"
                )
                for output in self._outputs_since_last_input:
                    await debug_io.output(output.payload.decode())

        while True:
            # When the user inputs "detach", it uses up a "read" message
            # without actually responding to the actor being debugged. We
            # can't manually reinsert the "read" message into the message queue,
            # so instead the self._need_read flag indicates there's an additional
            # "read" that we need to respond to.
            if self._need_read:
                self._need_read = False
                message = "read"
            else:
                message = await self._message_queue.get()
            if message == "detach":
                # Return to the main outer debug loop.
                break
            elif message == "read":
                try:
                    break_after = False
                    if line is not None:
                        break_after = True
                    else:
                        line = await debug_io.input()
                    if line == "detach":
                        self._need_read = True
                        break
                    else:
                        await self._pending_send_to_actor.put((line + "\n").encode())
                        # Cancel safety: don't clear the previous outputs until we know
                        # the actor will receive the input.
                        self._outputs_since_last_input = []
                        line = None
                        if break_after:
                            break
                except (DebugIOError, asyncio.CancelledError):
                    # See earlier comment about this flag. If either of the awaits inside
                    # the try block is cancelled, we need to redo the read without actually
                    # reinserting "read" into the message queue.
                    self._need_read = True
                    raise
            elif message[0] == "write":
                output = message[1]
                # If the user sees this output but then detaches from the session,
                # its useful to store all outputs since the last input so that
                # they can be printed again when the user re-attaches.
                self._outputs_since_last_input.append(output)
                if not suppress_output:
                    await debug_io.output(output.payload.decode())

        if not suppress_output:
            await debug_io.output(
                f"Detaching from debug session for {self.actor_name} {self.rank} ({self.hostname})\n"
            )

    def get_info(self):
        """Return a DebugSessionInfo describing this session's current state."""
        function = lineno = None
        if self._function_lineno is not None:
            function, lineno = self._function_lineno
        return DebugSessionInfo(
            self.actor_name, self.rank, self.coords, self.hostname, function, lineno
        )

    async def attach(self, debug_io: DebugIO, line=None, suppress_output=False):
        """Attach ``debug_io`` to this session and run its event loop to completion.

        The session is marked inactive again on every exit path, including
        I/O errors and cancellation, so that a later ``detach()`` does not
        queue a stale "detach" message for a session nobody is attached to.

        Args:
            debug_io: channel used to talk to the user.
            line: optional single command to send instead of prompting.
            suppress_output: when true, nothing is written to ``debug_io``.
        """
        self._active = True
        try:
            if not suppress_output:
                await debug_io.output(
                    f"Attached to debug session for {self.actor_name} {self.rank} ({self.hostname})\n"
                )
            self._task = asyncio.create_task(
                self._event_loop(debug_io, line, suppress_output)
            )
            await self._task
            if not suppress_output:
                await debug_io.output(
                    f"Detached from debug session for {self.actor_name} {self.rank} ({self.hostname})\n"
                )
        finally:
            self._active = False

    async def detach(self):
        """Ask the running event loop, if any, to stop."""
        if self._active:
            await self._message_queue.put("detach")

    async def debugger_read(self, size: int) -> DebuggerWrite:
        """Request input for the actor and return at most ``size`` bytes of it."""
        await self._message_queue.put("read")
        input_data = await self._pending_send_to_actor.get()
        if len(input_data) > size:
            input_data = input_data[:size]
        return DebuggerWrite(input_data, None, None)

    async def debugger_write(self, write: DebuggerWrite) -> None:
        """Queue output from the actor and record its location when provided."""
        if write.function is not None and write.lineno is not None:
            self._function_lineno = (write.function, write.lineno)
        await self._message_queue.put(("write", write))
254
-
255
-
256
# Accepted ways to select ranks: a single rank, a list of ranks, a range of
# ranks, or a mapping from dimension name to a range / list / single index.
RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]]
257
-
258
-
259
class DebugSessions:
    """Registry of debug sessions, indexed by actor name and then by rank."""

    def __init__(self):
        self._sessions: Dict[str, Dict[int, DebugSession]] = {}

    def _ranks_for(self, actor_name: str, rank: int) -> Dict[int, DebugSession]:
        """Return the per-rank map for ``actor_name``; raise if the session is absent."""
        by_rank = self._sessions.get(actor_name)
        if by_rank is None:
            raise ValueError(f"No debug sessions for actor {actor_name}")
        if rank not in by_rank:
            raise ValueError(f"No debug session for rank {rank} for actor {actor_name}")
        return by_rank

    def insert(self, session: DebugSession) -> None:
        """Register ``session``; raise ValueError if its (actor, rank) is taken."""
        by_rank = self._sessions.get(session.actor_name)
        if by_rank is None:
            self._sessions[session.actor_name] = {session.rank: session}
            return
        if session.rank in by_rank:
            raise ValueError(
                f"Debug session for rank {session.rank} already exists for actor {session.actor_name}"
            )
        by_rank[session.rank] = session

    def remove(self, actor_name: str, rank: int) -> DebugSession:
        """Remove and return the session; raise ValueError if it does not exist."""
        by_rank = self._ranks_for(actor_name, rank)
        removed = by_rank.pop(rank)
        # Drop the actor entry once its last session is gone.
        if not by_rank:
            del self._sessions[actor_name]
        return removed

    def get(self, actor_name: str, rank: int) -> DebugSession:
        """Return the session; raise ValueError if it does not exist."""
        return self._ranks_for(actor_name, rank)[rank]

    def iter(
        self, selection: Optional[Tuple[str, Optional[RanksType]]]
    ) -> Generator[DebugSession, None, None]:
        """Yield the sessions matched by ``selection``.

        ``None`` selects every session. Otherwise ``selection`` is
        ``(actor_name, ranks)`` where ``ranks`` is ``None`` (all ranks of that
        actor), an int, a list, a range, or a dict of per-dimension
        constraints checked against each session's coords. Ranks that have
        no session are skipped silently.
        """

        def coords_match(coords, dims) -> bool:
            for dim, wanted in dims.items():
                if dim not in coords:
                    return False
                if isinstance(wanted, (range, list)):
                    if coords[dim] not in wanted:
                        return False
                elif isinstance(wanted, int) and coords[dim] != wanted:
                    return False
            return True

        if selection is None:
            for by_rank in self._sessions.values():
                yield from by_rank.values()
            return

        actor_name, ranks = selection
        by_rank = self._sessions.get(actor_name)
        if by_rank is None:
            return

        if ranks is None:
            yield from by_rank.values()
        elif isinstance(ranks, int):
            if ranks in by_rank:
                yield by_rank[ranks]
        elif isinstance(ranks, list):
            for wanted_rank in ranks:
                if wanted_rank in by_rank:
                    yield by_rank[wanted_rank]
        elif isinstance(ranks, dict):
            for candidate in by_rank.values():
                if coords_match(candidate.coords, ranks):
                    yield candidate
        elif isinstance(ranks, range):
            for session_rank, candidate in by_rank.items():
                if session_rank in ranks:
                    yield candidate

    def info(self) -> List[DebugSessionInfo]:
        """Return the info record of every session."""
        return [
            session.get_info()
            for by_rank in self._sessions.values()
            for session in by_rank.values()
        ]

    def __len__(self) -> int:
        total = 0
        for by_rank in self._sessions.values():
            total += len(by_rank)
        return total

    def __contains__(self, item: Tuple[str, int]) -> bool:
        actor_name, rank = item
        by_rank = self._sessions.get(actor_name)
        return by_rank is not None and rank in by_rank
348
-
349
-
350
_debug_input_parser = None


# Wrap the parser in a function so that jobs don't have to import lark
# unless they want to use the debugger.
def _get_debug_input_parser():
    """Return the lark parser for debugger prompt input, building it on first use."""
    global _debug_input_parser
    if _debug_input_parser is not None:
        return _debug_input_parser

    from lark import Lark

    grammar = """
            rank_list: INT "," INT ("," INT)*
            start: INT?
            stop: INT?
            step: INT?
            rank_range: start ":" stop (":" step)?
            dim: CNAME "=" (rank_range | "(" rank_list ")" | INT)
            dims: dim ("," dim)*
            ranks: "ranks(" (dims | rank_range | rank_list | INT) ")"
            pdb_command: /\\w+.*/
            actor_name: /[-_a-zA-Z0-9]+/
            cast: "cast" _WS actor_name ranks pdb_command
            help: "h" | "help"
            attach: ("a" | "attach") _WS actor_name INT
            cont: "c" | "continue"
            quit: "q" | "quit"
            list: "l" | "list"
            command: attach | list | cast | help | cont | quit

            _WS: WS+

            %import common.INT
            %import common.CNAME
            %import common.WS
            %ignore WS
            """
    _debug_input_parser = Lark(grammar, start="command")
    return _debug_input_parser
390
-
391
-
392
_debug_input_transformer = None


# Wrap the transformer in a function so that jobs don't have to import lark
# unless they want to use the debugger.
def _get_debug_input_transformer():
    """Return the transformer that turns a parse tree into a DebugCommand.

    The transformer is built on first use and reused afterwards.
    """
    global _debug_input_transformer
    if _debug_input_transformer is not None:
        return _debug_input_transformer

    from lark import Transformer
    from lark.lexer import Token

    def _first_int(items, default):
        # Optional INT slot: fall back to ``default`` when the token is absent.
        return int(items[0].value) if items else default

    class _IntoDebugCommandTransformer(Transformer):
        # Method names mirror the grammar rule names.

        def rank_list(self, items: List[Token]) -> List[int]:
            return [int(token.value) for token in items]

        def start(self, items: List[Token]) -> int:
            return _first_int(items, 0)

        def stop(self, items: List[Token]) -> int:
            return _first_int(items, sys.maxsize)

        def step(self, items: List[Token]) -> int:
            return _first_int(items, 1)

        def rank_range(self, items: List[int]) -> range:
            return range(*items)

        def dim(
            self, items: Tuple[Token, Union[range, List[int], Token]]
        ) -> Tuple[str, Union[range, List[int], int]]:
            spec = items[1]
            if isinstance(spec, (range, list)):
                return (items[0].value, spec)
            return (items[0].value, int(cast(Token, spec).value))

        def dims(
            self, items: List[Tuple[str, Union[range, List[int], int]]]
        ) -> Dict[str, Union[range, List[int], int]]:
            return {pair[0]: pair[1] for pair in items}

        def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType:
            head = items[0]
            # A bare INT arrives as a token; everything else is already converted.
            if isinstance(head, Token):
                return int(head.value)
            return head

        def pdb_command(self, items: List[Token]) -> str:
            return items[0].value

        def actor_name(self, items: List[Token]) -> str:
            return items[0].value

        def help(self, _items: List[Token]) -> "Help":
            return Help()

        def attach(self, items: Tuple[str, Token]) -> "Attach":
            return Attach(items[0], int(items[1].value))

        def cont(self, _items: List[Token]) -> "Continue":
            return Continue()

        def quit(self, _items: List[Token]) -> "Quit":
            return Quit()

        def cast(self, items: Tuple[str, RanksType, str]) -> "Cast":
            return Cast(*items)

        def list(self, items: List[Token]) -> "ListCommand":
            return ListCommand()

        def command(self, items: List["DebugCommand"]) -> "DebugCommand":
            return items[0]

    _debug_input_transformer = _IntoDebugCommandTransformer()
    return _debug_input_transformer
474
-
475
-
476
class DebugCommand:
    """Base class for commands entered at the debugger prompt."""

    @staticmethod
    async def parse(debug_io: DebugIO, line: str) -> Union["DebugCommand", None]:
        """Parse ``line`` into a command.

        Any failure is reported to ``debug_io`` and ``None`` is returned
        instead of raising.
        """
        try:
            parse_tree = _get_debug_input_parser().parse(line)
            parsed = _get_debug_input_transformer().transform(parse_tree)
        except Exception as err:
            await debug_io.output(f"Error parsing input: {err}\n")
            return None
        return parsed
485
-
486
-
487
@dataclass
class Attach(DebugCommand):
    """Command naming one debug session, by actor name and rank, to attach to."""

    actor_name: str
    rank: int
491
-
492
-
493
@dataclass
class ListCommand(DebugCommand):
    """Command requesting a listing of debug sessions; carries no arguments."""

    pass
496
-
497
-
498
@dataclass
class Quit(DebugCommand):
    """Command requesting that the debugger prompt exit; carries no arguments."""

    pass
501
-
502
-
503
@dataclass
class Help(DebugCommand):
    """Command requesting the help text; carries no arguments."""

    pass
506
-
507
-
508
@dataclass
class Continue(DebugCommand):
    """Command produced by "c" / "continue" input; carries no arguments."""

    pass
511
-
512
-
513
@dataclass
class Cast(DebugCommand):
    """Command that sends one debugger command to a selection of ranks of an actor."""

    # Actor whose sessions are targeted.
    actor_name: str
    # Which ranks of that actor to target.
    ranks: RanksType
    # The command text forwarded to each selected session.
    command: str
518
-
519
-
520
class DebugController(Actor):
    """
    Single actor for both remote debuggers and users to talk to.

    Handles multiple sessions simultaneously
    """

    def __init__(self) -> None:
        self.sessions = DebugSessions()
        # Guards the hand-over from one connected client to the next.
        self._task_lock = asyncio.Lock()
        # Task running the interactive prompt (_enter) for the current client.
        self._task: asyncio.Task | None = None
        # Starts as stdio; replaced by a DebugCliIO when a client connects.
        self._debug_io: DebugIO = DebugStdIO()
        # Resolved with the server once it is listening, or with the exception
        # that stopped it.
        self._server = asyncio.Future()
        self._server_task = asyncio.create_task(self._serve())

    async def _serve(self) -> None:
        """Start the debug server and serve clients until it stops or fails.

        Raises:
            NotImplementedError: if the configured protocol is not "tcp".
        """
        try:
            if (proto := _get_debug_server_protocol()) != "tcp":
                raise NotImplementedError(
                    f"Network protocol {proto} not yet supported."
                )
            server = await asyncio.start_server(
                self._handle_client,
                _get_debug_server_host(),
                _get_debug_server_port(),
            )
            async with server:
                self._server.set_result(server)
                await server.serve_forever()
        except Exception as e:
            # A future that already holds the server cannot take an exception,
            # so swap in a fresh one before recording the failure.
            if self._server.done():
                self._server = asyncio.Future()
            self._server.set_exception(e)
            raise

    async def _handle_client(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Connection callback: make the new client the active debugger user."""
        # Make sure only one external debug process can
        # be attached at a time. If a new request is
        # received, the current task is cancelled.
        async with self._task_lock:
            if self._task is not None:
                self._task.cancel()
                try:
                    await none_throws(self._task)
                except (DebugIOError, asyncio.CancelledError):
                    pass
            self._debug_io = DebugCliIO(reader, writer)
            self._task = asyncio.create_task(self._enter())

    @endpoint
    async def wait_pending_session(self):
        """Return once at least one debug session exists (polled every second)."""
        while len(self.sessions) == 0:
            await asyncio.sleep(1)

    @endpoint
    async def list(self, print_output=True) -> List[DebugSessionInfo]:
        """Return info for all sessions, sorted; optionally print it as a table.

        Args:
            print_output: when true, a grid table is written to the current
                debug I/O channel.
        """
        session_info = sorted(self.sessions.info())
        if print_output:
            await self._debug_io.output(
                tabulate(
                    (
                        (
                            info.actor_name,
                            info.rank,
                            info.coords,
                            info.hostname,
                            info.function,
                            info.lineno,
                        )
                        for info in session_info
                    ),
                    headers=[
                        "Actor Name",
                        "Rank",
                        "Coords",
                        "Hostname",
                        "Function",
                        "Line No.",
                    ],
                    tablefmt="grid",
                )
                + "\n"
            )
        return session_info

    async def _enter(self) -> None:
        """Run the interactive ``monarch_dbg>`` prompt until the user quits.

        I/O errors and cancellation propagate to the caller; any other error
        raised while handling a command is reported to the user and the
        prompt continues.
        """
        # NOTE(review): the reason for this short delay is not evident from this file.
        await asyncio.sleep(0.5)
        await self._debug_io.output(
            "\n\n************************ MONARCH DEBUGGER ************************\n"
        )
        await self._debug_io.output("Enter 'help' for a list of commands.\n")
        await self._debug_io.output("Enter 'list' to show all active breakpoints.\n\n")

        while True:
            try:
                user_input = await self._debug_io.input("monarch_dbg> ")
                if not user_input.strip():
                    continue
                # parse() reports its own errors and returns None, in which
                # case none of the branches below match.
                command = await DebugCommand.parse(self._debug_io, user_input)
                if isinstance(command, Help):
                    await self._debug_io.output("monarch_dbg commands:\n")
                    await self._debug_io.output(
                        "\tattach <actor_name> <rank> - attach to a debug session\n"
                    )
                    await self._debug_io.output("\tlist - list all debug sessions\n")
                    await self._debug_io.output(
                        "\tquit - exit the debugger, leaving all sessions in place\n"
                    )
                    await self._debug_io.output(
                        "\tcast <actor_name> ranks(...) <command> - send a command to a set of ranks on the specified actor mesh.\n"
                        "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n"
                        "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n"
                        "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6))).\n"
                    )
                    await self._debug_io.output(
                        "\tcontinue - clear all breakpoints and tell all ranks to continue\n"
                    )
                    await self._debug_io.output("\thelp - print this help message\n")
                elif isinstance(command, Attach):
                    await self.sessions.get(command.actor_name, command.rank).attach(
                        self._debug_io
                    )
                elif isinstance(command, ListCommand):
                    # NOTE(review): presumably `_method` is the undecorated
                    # function behind the endpoint; it is defined outside this file.
                    # pyre-ignore
                    await self.list._method(self)
                elif isinstance(command, Continue):
                    # Sent to every session: first "clear", then "c".
                    await self._cast_input_and_wait("clear")
                    await self._cast_input_and_wait("c")
                elif isinstance(command, Quit):
                    await self._debug_io.quit()
                    return
                elif isinstance(command, Cast):
                    await self._cast_input_and_wait(
                        command.command, (command.actor_name, command.ranks)
                    )
            except (DebugIOError, asyncio.CancelledError):
                raise
            except Exception as e:
                await self._debug_io.output(f"Error processing command: {e}\n")

    async def _cast_input_and_wait(
        self,
        command: str,
        selection: Optional[Tuple[str, Optional[RanksType]]] = None,
    ) -> None:
        """Send ``command`` to every selected session concurrently and wait for all.

        Args:
            command: the line to deliver to each session.
            selection: passed to ``DebugSessions.iter``; ``None`` selects all sessions.
        """
        tasks = []
        for session in self.sessions.iter(selection):
            tasks.append(session.attach(self._debug_io, command, suppress_output=True))
        await asyncio.gather(*tasks)

    ##########################################################################
    # Debugger APIs
    #
    # These endpoints are called by the remote debuggers to establish sessions
    # and communicate with them.
    @endpoint
    async def debugger_session_start(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str
    ) -> None:
        """Register a debug session for (actor_name, rank) if none exists yet."""
        # Good enough for now to ensure that if the server for processing
        # user interactions never starts, then the rank being debugged will
        # fail instead of hanging indefinitely with no way to send it commands.
        # Of course this isn't sufficient to handle the case where the server
        # fails after the rank's debug session has successfully started.
        # TODO: implement a heartbeat to prevent pdb sessions from hanging.
        await self._server
        # Create a session if it doesn't exist
        if (actor_name, rank) not in self.sessions:
            self.sessions.insert(DebugSession(rank, coords, hostname, actor_name))

    @endpoint
    async def debugger_session_end(self, actor_name: str, rank: int) -> None:
        """Detach from the current debug session."""
        await self.sessions.remove(actor_name, rank).detach()

    @endpoint
    async def debugger_read(
        self, actor_name: str, rank: int, size: int
    ) -> DebuggerWrite | str:
        """Read from the debug session for the given rank."""
        return await self.sessions.get(actor_name, rank).debugger_read(size)

    @endpoint
    async def debugger_write(
        self, actor_name: str, rank: int, write: DebuggerWrite
    ) -> None:
        """Write to the debug session for the given rank."""
        await self.sessions.get(actor_name, rank).debugger_write(write)
712
-
713
-
714
- # Cached so that we don't have to call out to the root client every time,
715
- # which may be on a different host.
716
@functools.cache
def debug_controller() -> DebugController:
    """Return the debug controller, looking it up or spawning it on the first call."""
    with fake_sync_state():
        pending = get_or_spawn_controller("debug_controller", DebugController)
        return pending.get()
720
-
721
-
722
def remote_breakpointhook() -> None:
    """Start a PdbWrapper-driven trace in the frame that called this hook."""
    # Step out of this function's own frame to the caller's frame.
    target_frame = inspect.currentframe()
    assert target_frame is not None
    target_frame = target_frame.f_back
    assert target_frame is not None

    current = context()
    point = current.message_rank
    rank_index = point.rank
    coords = {dim: point[dim] for dim in point}
    actor_id = current.actor_instance.actor_id
    wrapper = PdbWrapper(rank_index, coords, actor_id, debug_controller())
    DebugContext.set(DebugContext(wrapper))
    wrapper.set_trace(target_frame)