torchmonarch-nightly 2025.8.2__cp310-cp310-manylinux2014_x86_64.whl → 2025.9.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. monarch/_rust_bindings.so +0 -0
  2. monarch/_src/actor/actor_mesh.py +504 -218
  3. monarch/_src/actor/allocator.py +75 -6
  4. monarch/_src/actor/bootstrap_main.py +7 -4
  5. monarch/_src/actor/code_sync/__init__.py +2 -0
  6. monarch/_src/actor/debugger/__init__.py +7 -0
  7. monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
  8. monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
  9. monarch/_src/actor/endpoint.py +27 -45
  10. monarch/_src/actor/future.py +86 -24
  11. monarch/_src/actor/host_mesh.py +125 -0
  12. monarch/_src/actor/logging.py +94 -0
  13. monarch/_src/actor/pickle.py +25 -0
  14. monarch/_src/actor/proc_mesh.py +423 -156
  15. monarch/_src/actor/python_extension_methods.py +90 -0
  16. monarch/_src/actor/shape.py +8 -1
  17. monarch/_src/actor/source_loader.py +45 -0
  18. monarch/_src/actor/telemetry/__init__.py +172 -0
  19. monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
  20. monarch/_src/debug_cli/__init__.py +7 -0
  21. monarch/_src/debug_cli/debug_cli.py +43 -0
  22. monarch/_src/tensor_engine/rdma.py +64 -9
  23. monarch/_testing.py +1 -3
  24. monarch/actor/__init__.py +24 -4
  25. monarch/common/_C.so +0 -0
  26. monarch/common/device_mesh.py +14 -0
  27. monarch/common/future.py +10 -0
  28. monarch/common/remote.py +14 -25
  29. monarch/common/tensor.py +12 -0
  30. monarch/debug_cli/__init__.py +7 -0
  31. monarch/debug_cli/__main__.py +12 -0
  32. monarch/fetch.py +2 -2
  33. monarch/gradient/_gradient_generator.so +0 -0
  34. monarch/gradient_generator.py +4 -2
  35. monarch/mesh_controller.py +34 -14
  36. monarch/monarch_controller +0 -0
  37. monarch/tools/colors.py +25 -0
  38. monarch/tools/commands.py +42 -7
  39. monarch/tools/components/hyperactor.py +6 -4
  40. monarch/tools/config/__init__.py +35 -12
  41. monarch/tools/config/defaults.py +15 -5
  42. monarch/tools/config/environment.py +45 -0
  43. monarch/tools/config/workspace.py +165 -0
  44. monarch/tools/mesh_spec.py +3 -3
  45. monarch/utils/__init__.py +9 -0
  46. monarch/utils/utils.py +78 -0
  47. tests/error_test_binary.py +5 -3
  48. tests/python_actor_test_binary.py +52 -0
  49. tests/test_actor_error.py +142 -14
  50. tests/test_alloc.py +1 -1
  51. tests/test_allocator.py +59 -72
  52. tests/test_debugger.py +639 -45
  53. tests/test_env_before_cuda.py +4 -4
  54. tests/test_mesh_trait.py +38 -0
  55. tests/test_python_actors.py +965 -75
  56. tests/test_rdma.py +7 -6
  57. tests/test_tensor_engine.py +6 -6
  58. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/METADATA +82 -4
  59. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/RECORD +63 -47
  60. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/WHEEL +0 -0
  61. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/entry_points.txt +0 -0
  62. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/licenses/LICENSE +0 -0
  63. {torchmonarch_nightly-2025.8.2.dist-info → torchmonarch_nightly-2025.9.4.dist-info}/top_level.txt +0 -0
@@ -11,35 +11,103 @@ import inspect
11
11
  import logging
12
12
  import os
13
13
  import sys
14
+ from abc import abstractmethod
14
15
  from dataclasses import dataclass
15
16
  from typing import cast, Dict, Generator, List, Optional, Tuple, Union
16
17
 
17
- from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
18
- from monarch._src.actor.actor_mesh import (
19
- _ActorMeshRefImpl,
20
- Actor,
21
- ActorMeshRef,
22
- DebugContext,
23
- MonarchContext,
24
- )
18
+ from monarch._src.actor.actor_mesh import Actor, context, DebugContext
19
+ from monarch._src.actor.debugger.pdb_wrapper import DebuggerWrite, PdbWrapper
25
20
  from monarch._src.actor.endpoint import endpoint
26
- from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper
21
+ from monarch._src.actor.proc_mesh import get_or_spawn_controller
27
22
  from monarch._src.actor.sync_state import fake_sync_state
23
+ from pyre_extensions import none_throws
28
24
  from tabulate import tabulate
29
25
 
30
26
 
31
27
  logger = logging.getLogger(__name__)
32
28
 
33
- _DEBUG_MANAGER_ACTOR_NAME = "debug_manager"
29
+ _MONARCH_DEBUG_SERVER_HOST_ENV_VAR = "MONARCH_DEBUG_SERVER_HOST"
30
+ _MONARCH_DEBUG_SERVER_HOST_DEFAULT = "localhost"
31
+ _MONARCH_DEBUG_SERVER_PORT_ENV_VAR = "MONARCH_DEBUG_SERVER_PORT"
32
+ _MONARCH_DEBUG_SERVER_PORT_DEFAULT = "27000"
33
+ _MONARCH_DEBUG_SERVER_PROTOCOL_ENV_VAR = "MONARCH_DEBUG_SERVER_PROTOCOL"
34
+ _MONARCH_DEBUG_SERVER_PROTOCOL_DEFAULT = "tcp"
34
35
 
35
36
 
36
- async def _debugger_input(prompt=""):
37
- return await asyncio.to_thread(input, prompt)
37
+ def _get_debug_server_host():
38
+ return os.environ.get(
39
+ _MONARCH_DEBUG_SERVER_HOST_ENV_VAR, _MONARCH_DEBUG_SERVER_HOST_DEFAULT
40
+ )
41
+
42
+
43
+ def _get_debug_server_port():
44
+ return os.environ.get(
45
+ _MONARCH_DEBUG_SERVER_PORT_ENV_VAR, _MONARCH_DEBUG_SERVER_PORT_DEFAULT
46
+ )
47
+
48
+
49
+ def _get_debug_server_protocol():
50
+ return os.environ.get(
51
+ _MONARCH_DEBUG_SERVER_PROTOCOL_ENV_VAR, _MONARCH_DEBUG_SERVER_PROTOCOL_DEFAULT
52
+ )
53
+
38
54
 
55
+ class DebugIO:
56
+ @abstractmethod
57
+ async def input(self, prompt: str = "") -> str: ...
39
58
 
40
- def _debugger_output(msg):
41
- sys.stdout.write(msg)
42
- sys.stdout.flush()
59
+ @abstractmethod
60
+ async def output(self, msg: str) -> None: ...
61
+
62
+ @abstractmethod
63
+ async def quit(self) -> None: ...
64
+
65
+
66
+ class DebugStdIO(DebugIO):
67
+ async def input(self, prompt: str = "") -> str:
68
+ return await asyncio.to_thread(input, prompt)
69
+
70
+ async def output(self, msg: str) -> None:
71
+ sys.stdout.write(msg)
72
+ sys.stdout.flush()
73
+
74
+ async def quit(self) -> None:
75
+ pass
76
+
77
+
78
+ class DebugIOError(RuntimeError):
79
+ def __init__(self):
80
+ super().__init__("Error encountered during debugger I/O operation.")
81
+
82
+
83
+ class DebugCliIO(DebugIO):
84
+ def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
85
+ self._reader = reader
86
+ self._writer = writer
87
+
88
+ async def input(self, prompt: str = "") -> str:
89
+ try:
90
+ await self.output(prompt)
91
+ msg = (await self._reader.readline()).decode()
92
+ # Incomplete read due to EOF
93
+ if not msg.endswith("\n"):
94
+ raise RuntimeError("Unexpected end of input.")
95
+ # Strip the newline to be consistent with the behavior of input()
96
+ return msg.strip("\n")
97
+ except Exception as e:
98
+ raise DebugIOError() from e
99
+
100
+ async def output(self, msg: str) -> None:
101
+ try:
102
+ self._writer.write(msg.encode())
103
+ await self._writer.drain()
104
+ except Exception as e:
105
+ raise DebugIOError() from e
106
+
107
+ async def quit(self) -> None:
108
+ await self.output("Quitting debug session...\n")
109
+ self._writer.close()
110
+ await self._writer.wait_closed()
43
111
 
44
112
 
45
113
  @dataclass
@@ -78,15 +146,19 @@ class DebugSession:
78
146
  self._function_lineno = None
79
147
  self._need_read = False
80
148
 
81
- async def _event_loop(self, line=None, suppress_output=False):
149
+ async def _event_loop(self, debug_io: DebugIO, line=None, suppress_output=False):
82
150
  if not suppress_output:
83
151
  # If the user had previously attached to this debug session,
84
152
  # then it would have printed various messages from the
85
153
  # message queue. When the user re-attaches, we want to
86
154
  # print out all of the output that was printed since the
87
155
  # last command sent to this session.
156
+ if len(self._outputs_since_last_input) > 0:
157
+ await debug_io.output(
158
+ f"<last pdb output for {self.actor_name} {self.rank} follows>\n"
159
+ )
88
160
  for output in self._outputs_since_last_input:
89
- _debugger_output(output.payload.decode())
161
+ await debug_io.output(output.payload.decode())
90
162
 
91
163
  while True:
92
164
  # When the user inputs "detach", it uses up a "read" message
@@ -103,20 +175,29 @@ class DebugSession:
103
175
  # Return to the main outer debug loop.
104
176
  break
105
177
  elif message == "read":
106
- break_after = False
107
- if line is not None:
108
- break_after = True
109
- else:
110
- line = await _debugger_input()
111
- if line.strip("\n") == "detach":
112
- self._need_read = True
113
- break
114
- else:
115
- self._outputs_since_last_input = []
116
- await self._pending_send_to_actor.put((line + "\n").encode())
117
- line = None
118
- if break_after:
178
+ try:
179
+ break_after = False
180
+ if line is not None:
181
+ break_after = True
182
+ else:
183
+ line = await debug_io.input()
184
+ if line == "detach":
185
+ self._need_read = True
119
186
  break
187
+ else:
188
+ await self._pending_send_to_actor.put((line + "\n").encode())
189
+ # Cancel safety: don't clear the previous outputs until we know
190
+ # the actor will receive the input.
191
+ self._outputs_since_last_input = []
192
+ line = None
193
+ if break_after:
194
+ break
195
+ except (DebugIOError, asyncio.CancelledError):
196
+ # See earlier comment about this flag. If either of the awaits inside
197
+ # the try block is cancelled, we need to redo the read without actually
198
+ # reinserting "read" into the message queue.
199
+ self._need_read = True
200
+ raise
120
201
  elif message[0] == "write":
121
202
  output = message[1]
122
203
  # If the user sees this output but then detaches from the session,
@@ -124,11 +205,11 @@ class DebugSession:
124
205
  # they can be printed again when the user re-attaches.
125
206
  self._outputs_since_last_input.append(output)
126
207
  if not suppress_output:
127
- _debugger_output(output.payload.decode())
208
+ await debug_io.output(output.payload.decode())
128
209
 
129
210
  if not suppress_output:
130
- print(
131
- f"Detaching from debug session for rank {self.rank} ({self.hostname})"
211
+ await debug_io.output(
212
+ f"Detaching from debug session for {self.actor_name} {self.rank} ({self.hostname})\n"
132
213
  )
133
214
 
134
215
  def get_info(self):
@@ -139,14 +220,20 @@ class DebugSession:
139
220
  self.actor_name, self.rank, self.coords, self.hostname, function, lineno
140
221
  )
141
222
 
142
- async def attach(self, line=None, suppress_output=False):
223
+ async def attach(self, debug_io: DebugIO, line=None, suppress_output=False):
143
224
  self._active = True
144
225
  if not suppress_output:
145
- print(f"Attached to debug session for rank {self.rank} ({self.hostname})")
146
- self._task = asyncio.create_task(self._event_loop(line, suppress_output))
226
+ await debug_io.output(
227
+ f"Attached to debug session for {self.actor_name} {self.rank} ({self.hostname})\n"
228
+ )
229
+ self._task = asyncio.create_task(
230
+ self._event_loop(debug_io, line, suppress_output)
231
+ )
147
232
  await self._task
148
233
  if not suppress_output:
149
- print(f"Detached from debug session for rank {self.rank} ({self.hostname})")
234
+ await debug_io.output(
235
+ f"Detached from debug session for {self.actor_name} {self.rank} ({self.hostname})\n"
236
+ )
150
237
  self._active = False
151
238
 
152
239
  async def detach(self):
@@ -281,7 +368,7 @@ def _get_debug_input_parser():
281
368
  dims: dim ("," dim)*
282
369
  ranks: "ranks(" (dims | rank_range | rank_list | INT) ")"
283
370
  pdb_command: /\\w+.*/
284
- actor_name: /\\w+/
371
+ actor_name: /[-_a-zA-Z0-9]+/
285
372
  cast: "cast" _WS actor_name ranks pdb_command
286
373
  help: "h" | "help"
287
374
  attach: ("a" | "attach") _WS actor_name INT
@@ -388,12 +475,12 @@ def _get_debug_input_transformer():
388
475
 
389
476
  class DebugCommand:
390
477
  @staticmethod
391
- def parse(line: str) -> Union["DebugCommand", None]:
478
+ async def parse(debug_io: DebugIO, line: str) -> Union["DebugCommand", None]:
392
479
  try:
393
480
  tree = _get_debug_input_parser().parse(line)
394
481
  return _get_debug_input_transformer().transform(tree)
395
482
  except Exception as e:
396
- print(f"Error parsing input: {e}")
483
+ await debug_io.output(f"Error parsing input: {e}\n")
397
484
  return None
398
485
 
399
486
 
@@ -430,7 +517,7 @@ class Cast(DebugCommand):
430
517
  command: str
431
518
 
432
519
 
433
- class DebugClient(Actor):
520
+ class DebugController(Actor):
434
521
  """
435
522
  Single actor for both remote debuggers and users to talk to.
436
523
 
@@ -439,6 +526,49 @@ class DebugClient(Actor):
439
526
 
440
527
  def __init__(self) -> None:
441
528
  self.sessions = DebugSessions()
529
+ self._task_lock = asyncio.Lock()
530
+ self._task: asyncio.Task | None = None
531
+ self._debug_io: DebugIO = DebugStdIO()
532
+ self._server = asyncio.Future()
533
+ self._server_task = asyncio.create_task(self._serve())
534
+
535
+ async def _serve(self) -> None:
536
+ try:
537
+ if (proto := _get_debug_server_protocol()) != "tcp":
538
+ raise NotImplementedError(
539
+ f"Network protocol {proto} not yet supported."
540
+ )
541
+ server = await asyncio.start_server(
542
+ self._handle_client,
543
+ _get_debug_server_host(),
544
+ _get_debug_server_port(),
545
+ )
546
+ async with server:
547
+ self._server.set_result(server)
548
+ await server.serve_forever()
549
+ except Exception as e:
550
+ if self._server.done():
551
+ self._server = asyncio.Future()
552
+ self._server.set_exception(e)
553
+ raise
554
+
555
+ async def _handle_client(
556
+ self,
557
+ reader: asyncio.StreamReader,
558
+ writer: asyncio.StreamWriter,
559
+ ) -> None:
560
+ # Make sure only one external debug process can
561
+ # be attached at a time. If a new request is
562
+ # received, the current task is cancelled.
563
+ async with self._task_lock:
564
+ if self._task is not None:
565
+ self._task.cancel()
566
+ try:
567
+ await none_throws(self._task)
568
+ except (DebugIOError, asyncio.CancelledError):
569
+ pass
570
+ self._debug_io = DebugCliIO(reader, writer)
571
+ self._task = asyncio.create_task(self._enter())
442
572
 
443
573
  @endpoint
444
574
  async def wait_pending_session(self):
@@ -446,85 +576,90 @@ class DebugClient(Actor):
446
576
  await asyncio.sleep(1)
447
577
 
448
578
  @endpoint
449
- async def list(self) -> List[DebugSessionInfo]:
579
+ async def list(self, print_output=True) -> List[DebugSessionInfo]:
450
580
  session_info = sorted(self.sessions.info())
451
- print(
452
- tabulate(
453
- (
581
+ if print_output:
582
+ await self._debug_io.output(
583
+ tabulate(
454
584
  (
455
- info.actor_name,
456
- info.rank,
457
- info.coords,
458
- info.hostname,
459
- info.function,
460
- info.lineno,
461
- )
462
- for info in session_info
463
- ),
464
- headers=[
465
- "Actor Name",
466
- "Rank",
467
- "Coords",
468
- "Hostname",
469
- "Function",
470
- "Line No.",
471
- ],
472
- tablefmt="grid",
585
+ (
586
+ info.actor_name,
587
+ info.rank,
588
+ info.coords,
589
+ info.hostname,
590
+ info.function,
591
+ info.lineno,
592
+ )
593
+ for info in session_info
594
+ ),
595
+ headers=[
596
+ "Actor Name",
597
+ "Rank",
598
+ "Coords",
599
+ "Hostname",
600
+ "Function",
601
+ "Line No.",
602
+ ],
603
+ tablefmt="grid",
604
+ )
605
+ + "\n"
473
606
  )
474
- )
475
607
  return session_info
476
608
 
477
- @endpoint
478
- async def enter(self) -> None:
609
+ async def _enter(self) -> None:
479
610
  await asyncio.sleep(0.5)
480
- logger.info("Remote breakpoint hit. Entering monarch debugger...")
481
- print("\n\n************************ MONARCH DEBUGGER ************************")
482
- print("Enter 'help' for a list of commands.")
483
- print("Enter 'list' to show all active breakpoints.\n")
611
+ await self._debug_io.output(
612
+ "\n\n************************ MONARCH DEBUGGER ************************\n"
613
+ )
614
+ await self._debug_io.output("Enter 'help' for a list of commands.\n")
615
+ await self._debug_io.output("Enter 'list' to show all active breakpoints.\n\n")
484
616
 
485
617
  while True:
486
618
  try:
487
- user_input = await _debugger_input("monarch_dbg> ")
619
+ user_input = await self._debug_io.input("monarch_dbg> ")
488
620
  if not user_input.strip():
489
621
  continue
490
- command = DebugCommand.parse(user_input)
622
+ command = await DebugCommand.parse(self._debug_io, user_input)
491
623
  if isinstance(command, Help):
492
- print("monarch_dbg commands:")
493
- print("\tattach <actor_name> <rank> - attach to a debug session")
494
- print("\tlist - list all debug sessions")
495
- print("\tquit - exit the debugger, leaving all sessions in place")
496
- print(
624
+ await self._debug_io.output("monarch_dbg commands:\n")
625
+ await self._debug_io.output(
626
+ "\tattach <actor_name> <rank> - attach to a debug session\n"
627
+ )
628
+ await self._debug_io.output("\tlist - list all debug sessions\n")
629
+ await self._debug_io.output(
630
+ "\tquit - exit the debugger, leaving all sessions in place\n"
631
+ )
632
+ await self._debug_io.output(
497
633
  "\tcast <actor_name> ranks(...) <command> - send a command to a set of ranks on the specified actor mesh.\n"
498
634
  "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n"
499
635
  "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n"
500
- "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))."
636
+ "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6))).\n"
501
637
  )
502
- print(
503
- "\tcontinue - tell all ranks to continue execution, then exit the debugger"
638
+ await self._debug_io.output(
639
+ "\tcontinue - clear all breakpoints and tell all ranks to continue\n"
504
640
  )
505
- print("\thelp - print this help message")
641
+ await self._debug_io.output("\thelp - print this help message\n")
506
642
  elif isinstance(command, Attach):
507
- await self.sessions.get(command.actor_name, command.rank).attach()
643
+ await self.sessions.get(command.actor_name, command.rank).attach(
644
+ self._debug_io
645
+ )
508
646
  elif isinstance(command, ListCommand):
509
647
  # pyre-ignore
510
648
  await self.list._method(self)
511
649
  elif isinstance(command, Continue):
512
- # Clear all breakpoints and make sure all ranks have
513
- # exited their debug sessions. If we sent "quit", it
514
- # would raise BdbQuit, crashing the process, which
515
- # probably isn't what we want.
516
650
  await self._cast_input_and_wait("clear")
517
- while len(self.sessions) > 0:
518
- await self._cast_input_and_wait("c")
519
- return
651
+ await self._cast_input_and_wait("c")
520
652
  elif isinstance(command, Quit):
653
+ await self._debug_io.quit()
521
654
  return
522
655
  elif isinstance(command, Cast):
523
656
  await self._cast_input_and_wait(
524
657
  command.command, (command.actor_name, command.ranks)
525
658
  )
659
+ except (DebugIOError, asyncio.CancelledError):
660
+ raise
526
661
  except Exception as e:
527
- print(f"Error processing command: {e}")
662
+ await self._debug_io.output(f"Error processing command: {e}\n")
528
663
 
529
664
  async def _cast_input_and_wait(
530
665
  self,
@@ -533,7 +668,7 @@ class DebugClient(Actor):
533
668
  ) -> None:
534
669
  tasks = []
535
670
  for session in self.sessions.iter(selection):
536
- tasks.append(session.attach(command, suppress_output=True))
671
+ tasks.append(session.attach(self._debug_io, command, suppress_output=True))
537
672
  await asyncio.gather(*tasks)
538
673
 
539
674
  ##########################################################################
@@ -545,6 +680,13 @@ class DebugClient(Actor):
545
680
  async def debugger_session_start(
546
681
  self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str
547
682
  ) -> None:
683
+ # Good enough for now to ensure that if the server for processing
684
+ # user interactions never starts, then the rank being debugged will
685
+ # fail instead of hanging indefinitely with no way to send it commands.
686
+ # Of course this isn't sufficient to handle the case where the server
687
+ # fails after the rank's debug session has successfully started.
688
+ # TODO: implement a heartbeat to prevent pdb sessions from hanging.
689
+ await self._server
548
690
  # Create a session if it doesn't exist
549
691
  if (actor_name, rank) not in self.sessions:
550
692
  self.sessions.insert(DebugSession(rank, coords, hostname, actor_name))
@@ -569,58 +711,27 @@ class DebugClient(Actor):
569
711
  await self.sessions.get(actor_name, rank).debugger_write(write)
570
712
 
571
713
 
572
- class DebugManager(Actor):
573
- @staticmethod
574
- @functools.cache
575
- def ref() -> "DebugManager":
576
- ctx = MonarchContext.get()
577
- return cast(
578
- DebugManager,
579
- ActorMeshRef(
580
- DebugManager,
581
- _ActorMeshRefImpl.from_actor_id(
582
- ctx.mailbox,
583
- ActorId.from_string(
584
- f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]"
585
- ),
586
- ),
587
- ctx.mailbox,
588
- ),
589
- )
590
-
591
- def __init__(self, debug_client: DebugClient) -> None:
592
- self._debug_client = debug_client
593
-
594
- @endpoint
595
- def get_debug_client(self) -> DebugClient:
596
- return self._debug_client
714
+ # Cached so that we don't have to call out to the root client every time,
715
+ # which may be on a different host.
716
+ @functools.cache
717
+ def debug_controller() -> DebugController:
718
+ with fake_sync_state():
719
+ return get_or_spawn_controller("debug_controller", DebugController).get()
597
720
 
598
721
 
599
- def remote_breakpointhook():
722
+ def remote_breakpointhook() -> None:
600
723
  frame = inspect.currentframe()
601
724
  assert frame is not None
602
725
  frame = frame.f_back
603
726
  assert frame is not None
604
- file = frame.f_code.co_filename
605
- line = frame.f_lineno
606
- module = frame.f_globals.get("__name__", "__main__")
607
- if module == "__main__" and not os.path.exists(file):
608
- raise NotImplementedError(
609
- f"Remote debugging not supported for breakpoint at {file}:{line} because "
610
- f"it is defined inside __main__, and the file does not exist on the host. "
611
- "In this case, cloudpickle serialization does not interact nicely with pdb. "
612
- "To debug your code, move it out of __main__ and into a module that "
613
- "exists on both your client and worker processes."
614
- )
615
727
 
616
- with fake_sync_state():
617
- manager = DebugManager.ref().get_debug_client.call_one().get()
618
- ctx = MonarchContext.get()
728
+ ctx = context()
729
+ rank = ctx.message_rank
619
730
  pdb_wrapper = PdbWrapper(
620
- ctx.point.rank,
621
- ctx.point.shape.coordinates(ctx.point.rank),
622
- ctx.mailbox.actor_id,
623
- manager,
731
+ rank.rank,
732
+ {k: rank[k] for k in rank},
733
+ ctx.actor_instance.actor_id,
734
+ debug_controller(),
624
735
  )
625
736
  DebugContext.set(DebugContext(pdb_wrapper))
626
737
  pdb_wrapper.set_trace(frame)