torchmonarch-nightly 2025.8.1__cp313-cp313-manylinux2014_x86_64.whl → 2025.9.3__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. monarch/_rust_bindings.so +0 -0
  2. monarch/_src/actor/actor_mesh.py +414 -216
  3. monarch/_src/actor/allocator.py +75 -6
  4. monarch/_src/actor/bootstrap_main.py +7 -4
  5. monarch/_src/actor/code_sync/__init__.py +2 -0
  6. monarch/_src/actor/debugger/__init__.py +7 -0
  7. monarch/_src/actor/{debugger.py → debugger/debugger.py} +246 -135
  8. monarch/_src/actor/{pdb_wrapper.py → debugger/pdb_wrapper.py} +62 -23
  9. monarch/_src/actor/endpoint.py +27 -45
  10. monarch/_src/actor/future.py +86 -24
  11. monarch/_src/actor/host_mesh.py +125 -0
  12. monarch/_src/actor/logging.py +94 -0
  13. monarch/_src/actor/pickle.py +25 -0
  14. monarch/_src/actor/proc_mesh.py +423 -156
  15. monarch/_src/actor/python_extension_methods.py +90 -0
  16. monarch/_src/actor/shape.py +8 -1
  17. monarch/_src/actor/source_loader.py +45 -0
  18. monarch/_src/actor/telemetry/__init__.py +172 -0
  19. monarch/_src/actor/telemetry/rust_span_tracing.py +6 -39
  20. monarch/_src/debug_cli/__init__.py +7 -0
  21. monarch/_src/debug_cli/debug_cli.py +43 -0
  22. monarch/_src/tensor_engine/rdma.py +64 -9
  23. monarch/_testing.py +1 -3
  24. monarch/actor/__init__.py +24 -4
  25. monarch/common/_C.so +0 -0
  26. monarch/common/device_mesh.py +14 -0
  27. monarch/common/future.py +10 -0
  28. monarch/common/remote.py +14 -25
  29. monarch/common/tensor.py +12 -0
  30. monarch/debug_cli/__init__.py +7 -0
  31. monarch/debug_cli/__main__.py +12 -0
  32. monarch/fetch.py +2 -2
  33. monarch/gradient/_gradient_generator.so +0 -0
  34. monarch/gradient_generator.py +4 -2
  35. monarch/mesh_controller.py +34 -14
  36. monarch/monarch_controller +0 -0
  37. monarch/tools/colors.py +25 -0
  38. monarch/tools/commands.py +42 -7
  39. monarch/tools/components/hyperactor.py +1 -1
  40. monarch/tools/config/__init__.py +31 -4
  41. monarch/tools/config/defaults.py +13 -3
  42. monarch/tools/config/environment.py +45 -0
  43. monarch/tools/config/workspace.py +165 -0
  44. monarch/tools/mesh_spec.py +2 -0
  45. monarch/utils/__init__.py +9 -0
  46. monarch/utils/utils.py +78 -0
  47. tests/error_test_binary.py +5 -3
  48. tests/python_actor_test_binary.py +52 -0
  49. tests/test_actor_error.py +142 -14
  50. tests/test_alloc.py +1 -1
  51. tests/test_allocator.py +59 -72
  52. tests/test_coalescing.py +1 -1
  53. tests/test_debugger.py +639 -45
  54. tests/test_env_before_cuda.py +4 -4
  55. tests/test_mesh_trait.py +38 -0
  56. tests/test_python_actors.py +979 -75
  57. tests/test_rdma.py +7 -6
  58. tests/test_tensor_engine.py +6 -6
  59. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/METADATA +82 -4
  60. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/RECORD +64 -48
  61. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/WHEEL +0 -0
  62. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/entry_points.txt +0 -0
  63. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/licenses/LICENSE +0 -0
  64. {torchmonarch_nightly-2025.8.1.dist-info → torchmonarch_nightly-2025.9.3.dist-info}/top_level.txt +0 -0
tests/test_debugger.py CHANGED
@@ -6,34 +6,47 @@
6
6
 
7
7
  # pyre-unsafe
8
8
  import asyncio
9
+ import functools
10
+ import importlib.resources
11
+ import os
9
12
  import re
13
+ import shutil
14
+ import signal
15
+ import subprocess
10
16
  import sys
11
- from typing import cast, List
17
+ from typing import cast, List, Optional, Tuple
12
18
  from unittest.mock import AsyncMock, patch
13
19
 
20
+ import cloudpickle
21
+
14
22
  import monarch
15
23
  import monarch.actor as actor
16
24
 
17
25
  import pytest
18
26
 
19
27
  import torch
20
-
21
- from monarch._src.actor.actor_mesh import Actor, ActorError, current_rank
22
- from monarch._src.actor.debugger import (
28
+ from monarch._src.actor.actor_mesh import Actor, ActorError, current_rank, IN_PAR
29
+ from monarch._src.actor.debugger.debugger import (
30
+ _MONARCH_DEBUG_SERVER_HOST_ENV_VAR,
31
+ _MONARCH_DEBUG_SERVER_PORT_ENV_VAR,
23
32
  Attach,
24
33
  Cast,
25
34
  Continue,
26
35
  DebugCommand,
36
+ DebugController,
27
37
  DebugSession,
28
38
  DebugSessionInfo,
29
39
  DebugSessions,
40
+ DebugStdIO,
30
41
  Help,
31
42
  ListCommand,
32
43
  Quit,
33
44
  )
34
45
  from monarch._src.actor.endpoint import endpoint
35
-
36
46
  from monarch._src.actor.proc_mesh import proc_mesh
47
+ from monarch._src.actor.source_loader import SourceLoaderController
48
+
49
+ from pyre_extensions import none_throws
37
50
 
38
51
  needs_cuda = pytest.mark.skipif(
39
52
  not torch.cuda.is_available(),
@@ -41,6 +54,70 @@ needs_cuda = pytest.mark.skipif(
41
54
  )
42
55
 
43
56
 
57
# Environment variables passed (via `isolate_in_subprocess(env=...)`) to the
# tests that run in a child process.
# NOTE(review): presumably "0.0.0.0" makes the debug server listen on every
# interface and port "0" asks the OS for a free port -- confirm against the
# debug server implementation.
debug_env = {
    _MONARCH_DEBUG_SERVER_HOST_ENV_VAR: "0.0.0.0",
    _MONARCH_DEBUG_SERVER_PORT_ENV_VAR: "0",
}
61
+
62
+
63
def isolate_in_subprocess(test_fn=None, *, env=None):
    """Decorate an async test so that it executes in a separate process.

    Usable bare (``@isolate_in_subprocess``) or with keyword arguments
    (``@isolate_in_subprocess(env={...})``).  A synchronous shim named
    ``sync_<test name>`` is registered as an attribute of this module; the
    returned wrapper launches a child process that calls that shim and
    asserts that the child exits with status 0.

    Args:
        test_fn: The coroutine function to isolate, or ``None`` when the
            decorator is invoked with keyword arguments only.
        env: Extra environment variables for the child process.  The mapping
            is copied and never modified, so one dict can safely be shared
            between several decorated tests.

    Returns:
        A zero-argument callable that runs the test in a subprocess, or a
        ``functools.partial`` still waiting for ``test_fn`` when it was not
        supplied.
    """
    if test_fn is None:
        return functools.partial(isolate_in_subprocess, env=env)

    def sync_test_fn():
        asyncio.run(test_fn())

    # The child process looks the shim up by name on this module.
    sync_test_fn_name = f"sync_{test_fn.__name__}"
    setattr(sys.modules[__name__], sync_test_fn_name, sync_test_fn)

    # Build a private environment instead of updating the caller's dict in
    # place: callers pass shared module-level dicts, and mutating them would
    # leak the whole process environment into every later use of that dict.
    # Values already present in os.environ take precedence, as before.
    # NOTE(review): if the explicit `env` entries are meant to win over the
    # inherited environment, the merge order must be reversed -- confirm.
    child_env = dict(env) if env is not None else {}
    child_env.update(os.environ)

    def wrapper():
        if IN_PAR:
            # Packaged build: run the bundled test runner binary.
            cmd = [
                str(
                    importlib.resources.files("monarch.python.tests").joinpath(
                        "run_test_bin"
                    )
                ),
                sync_test_fn_name,
            ]
        else:
            # Regular checkout: import this module in a fresh interpreter.
            cmd = [
                sys.executable,
                "-c",
                f"import tests.test_debugger; tests.test_debugger.{sync_test_fn_name}()",
            ]
        assert subprocess.call(cmd, env=child_env) == 0

    # NOTE(review): pytest marks applied *below* this decorator are attached
    # to `test_fn`, not to `wrapper` -- confirm they are still honored.
    return wrapper
108
+
109
+
110
def run_test_from_name():
    """Call the attribute of this module whose name is the first CLI argument."""
    this_module = sys.modules[__name__]
    target_name = sys.argv[1]
    target = getattr(this_module, target_name)
    target()
112
+
113
+
114
# Location of the packaged `debug_cli_bin` resource; it is looked up only
# when IN_PAR is truthy and is the empty string otherwise.
if IN_PAR:
    debug_cli_bin = str(
        importlib.resources.files("monarch.python.tests").joinpath("debug_cli_bin")
    )
else:
    debug_cli_bin = ""
119
+
120
+
44
121
  def _bad_rank():
45
122
  raise ValueError("bad rank")
46
123
 
@@ -75,22 +152,45 @@ class DebugeeActor(Actor):
75
152
  return _debugee_actor_internal(rank)
76
153
 
77
154
 
78
- async def _wait_for_breakpoints(debug_client, n_breakpoints) -> List[DebugSessionInfo]:
155
class DebugControllerForTesting(DebugController):
    """DebugController subclass that adds the extra endpoints these tests call."""

    def __init__(self):
        super().__init__()
        # Route debugger I/O through a DebugStdIO instance.
        self._debug_io = DebugStdIO()

    @endpoint
    async def blocking_enter(self):
        """Await ``_enter`` while holding ``_task_lock``.

        Asserts that no task is currently recorded in ``_task``.
        """
        async with self._task_lock:
            assert self._task is None
            await self._enter()

    @endpoint
    async def server_port(self):
        """Return element 1 of the first server socket's ``getsockname()``.

        Returns ``None`` when the server has no sockets.
        """
        server: asyncio.Server = await self._server
        listening = server.sockets
        if not listening:
            return None
        return listening[0].getsockname()[1]
171
+
172
+
173
+ async def _wait_for_breakpoints(
174
+ debug_controller, n_breakpoints, timeout_sec=20
175
+ ) -> List[DebugSessionInfo]:
79
176
  breakpoints: List[DebugSessionInfo] = []
80
- for i in range(10):
81
- breakpoints = await debug_client.list.call_one()
82
- if len(breakpoints) == n_breakpoints:
83
- break
177
+ for _ in range(timeout_sec):
84
178
  await asyncio.sleep(1)
85
- if i == 9:
86
- raise RuntimeError("timed out waiting for breakpoints")
87
- return breakpoints
179
+ breakpoints = await debug_controller.list.call_one(print_output=False)
180
+ if len(breakpoints) == n_breakpoints:
181
+ return breakpoints
182
+ raise RuntimeError("timed out waiting for breakpoints")
88
183
 
89
184
 
185
+ # We have to run this test in a separate process because there is only one
186
+ # debug controller per process, and we don't want this to interfere with
187
+ # the other two tests that access the debug controller.
188
+ @isolate_in_subprocess(env=debug_env)
90
189
  @pytest.mark.skipif(
91
190
  torch.cuda.device_count() < 2,
92
191
  reason="Not enough GPUs, this test requires at least 2 GPUs",
93
192
  )
193
+ @pytest.mark.timeout(60)
94
194
  async def test_debug() -> None:
95
195
  input_mock = AsyncMock()
96
196
  input_mock.side_effect = [
@@ -122,6 +222,7 @@ async def test_debug() -> None:
122
222
  "c",
123
223
  "quit",
124
224
  "continue",
225
+ "quit",
125
226
  ]
126
227
 
127
228
  outputs = []
@@ -130,16 +231,21 @@ async def test_debug() -> None:
130
231
  nonlocal outputs
131
232
  outputs.append(msg)
132
233
 
234
+ output_mock = AsyncMock()
235
+ output_mock.side_effect = _patch_output
236
+
133
237
  with patch(
134
- "monarch._src.actor.debugger._debugger_input", side_effect=input_mock
135
- ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output):
136
- proc = await proc_mesh(hosts=2, gpus=2)
238
+ "monarch._src.actor.debugger.debugger.DebugStdIO.input", new=input_mock
239
+ ), patch("monarch._src.actor.debugger.debugger.DebugStdIO.output", new=output_mock):
240
+ proc = proc_mesh(hosts=2, gpus=2)
137
241
  debugee = await proc.spawn("debugee", DebugeeActor)
138
- debug_client = actor.debug_client()
242
+ debug_controller = actor.get_or_spawn_controller(
243
+ "debug_controller", DebugControllerForTesting
244
+ ).get()
139
245
 
140
246
  fut = debugee.to_debug.call()
141
- await debug_client.wait_pending_session.call_one()
142
- breakpoints = await _wait_for_breakpoints(debug_client, 4)
247
+ await debug_controller.wait_pending_session.call_one()
248
+ breakpoints = await _wait_for_breakpoints(debug_controller, 4)
143
249
 
144
250
  initial_linenos = {}
145
251
  for i in range(len(breakpoints)):
@@ -150,7 +256,7 @@ async def test_debug() -> None:
150
256
  assert info.function == "test_debugger._debugee_actor_internal"
151
257
  assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank
152
258
 
153
- await debug_client.enter.call_one()
259
+ await debug_controller.blocking_enter.call_one()
154
260
 
155
261
  # Check that when detaching and re-attaching to a session, the last portion of the output is repeated
156
262
  expected_last_output = [
@@ -161,13 +267,22 @@ async def test_debug() -> None:
161
267
  r"\(Pdb\) ",
162
268
  ]
163
269
  output_len = len(expected_last_output)
164
- assert outputs[-2 * output_len : -output_len] == outputs[-output_len:]
270
+ rev_outputs = outputs[::-1]
271
+ last_return = rev_outputs.index("--Return--")
272
+ second_to_last_return = rev_outputs.index("--Return--", last_return + 1)
273
+ last_return = len(rev_outputs) - last_return - 1
274
+ second_to_last_return = len(rev_outputs) - second_to_last_return - 1
275
+ assert (
276
+ outputs[second_to_last_return : second_to_last_return + output_len] # noqa
277
+ == outputs[last_return : last_return + output_len] # noqa
278
+ )
165
279
  for real_output, expected_output in zip(
166
- outputs[-output_len:], expected_last_output
280
+ outputs[last_return : last_return + output_len], # noqa
281
+ expected_last_output,
167
282
  ):
168
283
  assert re.match(expected_output, real_output) is not None
169
284
 
170
- breakpoints = await debug_client.list.call_one()
285
+ breakpoints = await debug_controller.list.call_one(print_output=False)
171
286
  for i in range(len(breakpoints)):
172
287
  if i == 1:
173
288
  assert breakpoints[i].function == "test_debugger.to_debug"
@@ -177,9 +292,9 @@ async def test_debug() -> None:
177
292
  )
178
293
  assert breakpoints[i].lineno == initial_linenos[i]
179
294
 
180
- await debug_client.enter.call_one()
295
+ await debug_controller.blocking_enter.call_one()
181
296
 
182
- breakpoints = await debug_client.list.call_one()
297
+ breakpoints = await debug_controller.list.call_one(print_output=False)
183
298
  for i in range(len(breakpoints)):
184
299
  if i == 1:
185
300
  assert breakpoints[i].function == "test_debugger.to_debug"
@@ -194,14 +309,14 @@ async def test_debug() -> None:
194
309
  )
195
310
  assert breakpoints[i].lineno == initial_linenos[i]
196
311
 
197
- await debug_client.enter.call_one()
312
+ await debug_controller.blocking_enter.call_one()
198
313
 
199
- breakpoints = await debug_client.list.call_one()
314
+ breakpoints = await debug_controller.list.call_one(print_output=False)
200
315
  assert len(breakpoints) == 4
201
316
  # Expect post-mortem debugging for rank 2
202
317
  assert breakpoints[2].function == "test_debugger._bad_rank"
203
318
 
204
- await debug_client.enter.call_one()
319
+ await debug_controller.blocking_enter.call_one()
205
320
 
206
321
  expected_last_output = [
207
322
  r"\s*(/.*/)+test_debugger.py\(\d+\)_debugee_actor_internal\(\)\n-> _bad_rank\(\)",
@@ -211,18 +326,24 @@ async def test_debug() -> None:
211
326
  r"\(Pdb\) ",
212
327
  ]
213
328
 
329
+ rev_outputs = outputs[::-1]
330
+ output_index = len(outputs) - (
331
+ rev_outputs.index("(Pdb) ") + len(expected_last_output)
332
+ )
333
+
214
334
  for output, expected_output in zip(
215
- outputs[-len(expected_last_output) :], expected_last_output
335
+ outputs[output_index : output_index + len(expected_last_output)], # noqa
336
+ expected_last_output,
216
337
  ):
217
338
  assert re.match(expected_output, output) is not None
218
339
 
219
- breakpoints = await debug_client.list.call_one()
340
+ breakpoints = await debug_controller.list.call_one(print_output=False)
220
341
  assert len(breakpoints) == 3
221
342
  for i, rank in enumerate((0, 1, 3)):
222
343
  assert breakpoints[i].rank == rank
223
344
 
224
- await debug_client.enter.call_one()
225
- breakpoints = await debug_client.list.call_one()
345
+ await debug_controller.blocking_enter.call_one()
346
+ breakpoints = await debug_controller.list.call_one(print_output=False)
226
347
  assert len(breakpoints) == 0
227
348
 
228
349
  with pytest.raises(
@@ -231,10 +352,13 @@ async def test_debug() -> None:
231
352
  await fut
232
353
 
233
354
 
355
+ # See earlier comment
356
+ @isolate_in_subprocess(env=debug_env)
234
357
  @pytest.mark.skipif(
235
358
  torch.cuda.device_count() < 2,
236
359
  reason="Not enough GPUs, this test requires at least 2 GPUs",
237
360
  )
361
+ @pytest.mark.timeout(60)
238
362
  async def test_debug_multi_actor() -> None:
239
363
  input_mock = AsyncMock()
240
364
  input_mock.side_effect = [
@@ -251,19 +375,24 @@ async def test_debug_multi_actor() -> None:
251
375
  "c",
252
376
  "quit",
253
377
  "continue",
378
+ "quit",
254
379
  ]
255
380
 
256
- with patch("monarch._src.actor.debugger._debugger_input", side_effect=input_mock):
381
+ with patch(
382
+ "monarch._src.actor.debugger.debugger.DebugStdIO.input", side_effect=input_mock
383
+ ):
257
384
  proc = await proc_mesh(hosts=2, gpus=2)
258
385
  debugee_1 = await proc.spawn("debugee_1", DebugeeActor)
259
386
  debugee_2 = await proc.spawn("debugee_2", DebugeeActor)
260
- debug_client = actor.debug_client()
387
+ debug_controller = actor.get_or_spawn_controller(
388
+ "debug_controller", DebugControllerForTesting
389
+ ).get()
261
390
 
262
391
  fut_1 = debugee_1.to_debug.call()
263
392
  fut_2 = debugee_2.to_debug.call()
264
- await debug_client.wait_pending_session.call_one()
393
+ await debug_controller.wait_pending_session.call_one()
265
394
 
266
- breakpoints = await _wait_for_breakpoints(debug_client, 8)
395
+ breakpoints = await _wait_for_breakpoints(debug_controller, 8)
267
396
 
268
397
  initial_linenos = {}
269
398
  for i in range(len(breakpoints)):
@@ -275,9 +404,9 @@ async def test_debug_multi_actor() -> None:
275
404
  assert info.function == "test_debugger._debugee_actor_internal"
276
405
  assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank
277
406
 
278
- await debug_client.enter.call_one()
407
+ await debug_controller.blocking_enter.call_one()
279
408
 
280
- breakpoints = await _wait_for_breakpoints(debug_client, 8)
409
+ breakpoints = await _wait_for_breakpoints(debug_controller, 8)
281
410
  for i in range(len(breakpoints)):
282
411
  if i == 1:
283
412
  assert breakpoints[i].actor_name == "debugee_1"
@@ -294,18 +423,18 @@ async def test_debug_multi_actor() -> None:
294
423
  assert breakpoints[i].rank == i % 4
295
424
  assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank]
296
425
 
297
- await debug_client.enter.call_one()
426
+ await debug_controller.blocking_enter.call_one()
298
427
 
299
- breakpoints = await _wait_for_breakpoints(debug_client, 1)
428
+ breakpoints = await _wait_for_breakpoints(debug_controller, 1)
300
429
  with pytest.raises(ActorError, match="ValueError: bad rank"):
301
430
  await fut_2
302
431
  assert breakpoints[0].actor_name == "debugee_1"
303
432
  assert breakpoints[0].rank == 2
304
433
  assert breakpoints[0].function == "test_debugger._bad_rank"
305
434
 
306
- await debug_client.enter.call_one()
435
+ await debug_controller.blocking_enter.call_one()
307
436
 
308
- breakpoints = await _wait_for_breakpoints(debug_client, 0)
437
+ breakpoints = await _wait_for_breakpoints(debug_controller, 0)
309
438
  with pytest.raises(ActorError, match="ValueError: bad rank"):
310
439
  await fut_1
311
440
 
@@ -512,7 +641,7 @@ async def test_debug_sessions_iter() -> None:
512
641
  ["user_input", "expected_output"],
513
642
  [
514
643
  ("attach debugee 1", Attach("debugee", 1)),
515
- ("a my_awesome_actor 100", Attach("my_awesome_actor", 100)),
644
+ ("a my_awesome_actor-123_DBG 100", Attach("my_awesome_actor-123_DBG", 100)),
516
645
  ("list", ListCommand()),
517
646
  ("l", ListCommand()),
518
647
  ("help", Help()),
@@ -600,7 +729,7 @@ async def test_debug_sessions_iter() -> None:
600
729
  ],
601
730
  )
602
731
  async def test_debug_command_parser_valid_inputs(user_input, expected_output):
603
- assert DebugCommand.parse(user_input) == expected_output
732
+ assert await DebugCommand.parse(DebugStdIO(), user_input) == expected_output
604
733
 
605
734
 
606
735
  @pytest.mark.parametrize(
@@ -641,4 +770,469 @@ async def test_debug_command_parser_valid_inputs(user_input, expected_output):
641
770
  ],
642
771
  )
643
772
  async def test_debug_command_parser_invalid_inputs(invalid_input):
644
- assert DebugCommand.parse(invalid_input) is None
773
+ assert await DebugCommand.parse(DebugStdIO(), invalid_input) is None
774
+
775
+
776
+ # See earlier comment
777
@isolate_in_subprocess(env={"MONARCH_DEBUG_CLI_BIN": debug_cli_bin, **debug_env})
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Not enough GPUs, this test requires at least 2 GPUs",
)
@pytest.mark.timeout(60)
async def test_debug_cli():
    """Drive the debug controller through an external CLI connection.

    The test spawns a 2x2 mesh of ``DebugeeActor``s, waits for all four ranks
    to hit their breakpoints, and then opens several successive CLI sessions
    against the controller's server port.  Between sessions it checks the
    controller's session list, including after the CLI has been interrupted
    rather than shut down cleanly.
    """
    proc = proc_mesh(hosts=2, gpus=2)
    debugee = await proc.spawn("debugee", DebugeeActor)
    debug_controller = actor.get_or_spawn_controller(
        "debug_controller", DebugControllerForTesting
    ).get()

    fut = debugee.to_debug.call()
    # Stupidly high timeout because when CI tries to run many instances of this
    # test in parallel, it can take a long time for breakpoints to actually show
    # up.
    breakpoints = await _wait_for_breakpoints(debug_controller, 4, timeout_sec=180)

    initial_linenos = {}
    for i, info in enumerate(breakpoints):
        initial_linenos[info.rank] = info.lineno
        assert info.rank == i
        assert info.coords == {"hosts": info.rank // 2, "gpus": info.rank % 2}
        assert info.function == "test_debugger._debugee_actor_internal"
        assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank

    port = debug_controller.server_port.call_one().get()

    async def create_debug_cli_proc() -> (
        Tuple[
            Optional[asyncio.subprocess.Process],
            asyncio.StreamWriter,
            asyncio.StreamReader,
        ]
    ):
        """Open a CLI session; returns ``(process or None, writer, reader)``."""
        cmd = None
        if IN_PAR:
            cmd = [
                os.environ["MONARCH_DEBUG_CLI_BIN"],
                "--host",
                os.environ[_MONARCH_DEBUG_SERVER_HOST_ENV_VAR],
                "--port",
                str(port),
            ]
        elif any(shutil.which(nc_cmd) for nc_cmd in ["ncat", "nc", "netcat"]):
            cmd = [
                sys.executable,
                "-m",
                "monarch.debug_cli",
                "--host",
                os.environ[_MONARCH_DEBUG_SERVER_HOST_ENV_VAR],
                "--port",
                str(port),
            ]
        if cmd:
            debug_cli_proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
            )
            debug_cli_stdin = none_throws(debug_cli_proc.stdin)
            debug_cli_stdout = none_throws(debug_cli_proc.stdout)
            return debug_cli_proc, debug_cli_stdin, debug_cli_stdout
        else:
            # Netcat isn't available in our github CI environment, so we can't
            # run the monarch.debug_cli module
            reader, writer = await asyncio.open_connection(
                os.environ[_MONARCH_DEBUG_SERVER_HOST_ENV_VAR], port
            )
            return None, writer, reader

    async def finish_cli_session(cli_proc, cli_stdin) -> None:
        """Close a session cleanly and expect the CLI (if any) to exit with 0."""
        # Yield and wait so that the debug controller has a chance to process
        # the input before we close stdin.
        await asyncio.sleep(1)
        cli_stdin.close()
        await cli_stdin.wait_closed()
        if cli_proc:
            assert await cli_proc.wait() == 0

    async def interrupt_cli_session(cli_proc, cli_stdin) -> None:
        """End a session abruptly: SIGINT the CLI, or just drop the connection."""
        if cli_proc:
            cli_proc.send_signal(signal.SIGINT)
            assert await cli_proc.wait() != 0
        else:
            cli_stdin.close()
            await cli_stdin.wait_closed()

    (
        debug_cli_proc,
        debug_cli_stdin,
        debug_cli_stdout,
    ) = await create_debug_cli_proc()

    debug_cli_stdin.writelines(
        [
            b"attach debugee 1\n",
            b"n\n",
            b"n\n",
            b"n\n",
            b"n\n",
            b"detach\n",
            b"attach debugee 1\n",
            b"print('test separator')\n",
            b"detach\n",
        ]
    )
    await debug_cli_stdin.drain()

    # Check that when detaching and re-attaching to a session, the last portion of the output is repeated
    expected_last_output = (
        r"--Return--\n"
        r"> (?:/.*/)+test_debugger.py\(\d+\)to_debug\(\)->5\n"
        r"-> return _debugee_actor_internal\(rank\)\n"
        r"\(Pdb\) "
    )

    outputs = (await debug_cli_stdout.readuntil(b"test separator")).decode()
    # Compare the two matched transcripts with each other; indexing `outputs`
    # directly would only compare the first two characters of the string.
    repeated_outputs = re.findall(expected_last_output, outputs)
    assert len(repeated_outputs) == 2
    assert repeated_outputs[0] == repeated_outputs[1]

    breakpoints = await debug_controller.list.call_one(print_output=False)
    for i, info in enumerate(breakpoints):
        if i == 1:
            assert info.function == "test_debugger.to_debug"
        else:
            assert info.function == "test_debugger._debugee_actor_internal"
            assert info.lineno == initial_linenos[i]

    debug_cli_stdin.write(b"quit\n")
    await debug_cli_stdin.drain()
    await finish_cli_session(debug_cli_proc, debug_cli_stdin)

    (
        debug_cli_proc,
        debug_cli_stdin,
        debug_cli_stdout,
    ) = await create_debug_cli_proc()

    debug_cli_stdin.writelines(
        [
            b"cast debugee ranks(0,3) n\n",
            b"cast debugee ranks(0,3) n\n",
            # Attaching to 0 and 3 ensures that when we call "list"
            # the next time, their function/lineno info will be
            # up-to-date.
            b"attach debugee 0\n",
            b"detach\n",
            b"attach debugee 3\n",
            b"detach\n",
        ]
    )
    await debug_cli_stdin.drain()

    # Make sure we have run all the commands before killing the CLI, otherwise
    # the commands may not actually be sent to the debug controller.
    await debug_cli_stdout.readuntil(b"Detached from debug session for debugee 3")
    # Even if we kill the proc using a signal, we should be able to reconnect
    # without issue.
    await interrupt_cli_session(debug_cli_proc, debug_cli_stdin)

    breakpoints = await debug_controller.list.call_one(print_output=False)
    for i, info in enumerate(breakpoints):
        if i == 1:
            assert info.function == "test_debugger.to_debug"
        elif i in (0, 3):
            assert info.function == "test_debugger._debugee_actor_internal"
            assert info.lineno == initial_linenos[i] + 2
        else:
            assert info.function == "test_debugger._debugee_actor_internal"
            assert info.lineno == initial_linenos[i]

    (
        debug_cli_proc,
        debug_cli_stdin,
        debug_cli_stdout,
    ) = await create_debug_cli_proc()

    debug_cli_stdin.writelines([b"attach debugee 2\n", b"c\n"])
    await debug_cli_stdin.drain()

    # Make sure we have run all the commands before killing the CLI, otherwise
    # the commands may not actually be sent to the debug controller.
    await debug_cli_stdout.readuntil(b"raise ValueError")
    # Even if we kill the proc using a signal while the debugger is attached to
    # a specific rank, we should be able to reconnect to that rank later without
    # issue.
    await interrupt_cli_session(debug_cli_proc, debug_cli_stdin)

    breakpoints = await debug_controller.list.call_one(print_output=False)
    assert len(breakpoints) == 4
    # Expect post-mortem debugging for rank 2
    assert breakpoints[2].function == "test_debugger._bad_rank"

    (
        debug_cli_proc,
        debug_cli_stdin,
        debug_cli_stdout,
    ) = await create_debug_cli_proc()

    debug_cli_stdin.writelines([b"attach debugee 2\n", b"bt\n", b"c\n"])
    await debug_cli_stdin.drain()

    expected_output = (
        r"(?:/.*/)+test_debugger.py\(\d+\)_debugee_actor_internal\(\)\n-> _bad_rank\(\)\n"
        r'> (?:/.*/)+test_debugger.py\(\d+\)_bad_rank\(\)\n-> raise ValueError\("bad rank"\)\n'
        r"\(Pdb\)"
    )

    output = (
        await debug_cli_stdout.readuntil(b"Detached from debug session for debugee 2")
    ).decode()
    assert len(re.findall(expected_output, output)) == 1

    debug_cli_stdin.writelines([b"quit\n"])
    await debug_cli_stdin.drain()
    # Sleep *before* closing stdin here too, matching the other clean
    # shutdowns in this test.
    await finish_cli_session(debug_cli_proc, debug_cli_stdin)

    breakpoints = await debug_controller.list.call_one(print_output=False)
    assert len(breakpoints) == 3
    for i, rank in enumerate((0, 1, 3)):
        assert breakpoints[i].rank == rank

    debug_cli_proc, debug_cli_stdin, _ = await create_debug_cli_proc()
    debug_cli_stdin.writelines([b"continue\n", b"quit\n"])
    await debug_cli_stdin.drain()
    await finish_cli_session(debug_cli_proc, debug_cli_stdin)

    breakpoints = await _wait_for_breakpoints(debug_controller, 0)
    assert len(breakpoints) == 0

    with pytest.raises(
        monarch._src.actor.actor_mesh.ActorError, match="ValueError: bad rank"
    ):
        await fut
1025
+
1026
+
1027
+ class_closure_source = """class ClassClosure:
1028
+ def __init__(self, arg):
1029
+ self.arg = arg
1030
+
1031
+ def closure(self):
1032
+ arg = self.arg
1033
+
1034
+ class Internal:
1035
+ def __init__(self):
1036
+ self.arg = arg
1037
+ # noqa
1038
+ def get_arg(self):
1039
+ breakpoint()
1040
+ return self.arg
1041
+
1042
+ return Internal
1043
+ """
1044
+
1045
+ function_closure_source = """def func_closure(arg, bp):
1046
+ def func(internal):
1047
+ if bp:
1048
+ breakpoint()
1049
+ return internal().get_arg() + arg
1050
+ return func
1051
+ """
1052
+
1053
+
1054
def load_class_closure():
    """Return the object obtained by unpickling an embedded cloudpickle payload.

    The payload references the file name /tmp/monarch_test/class_closure.py.
    NOTE(review): the payload embeds a serialized code object, so it is
    presumably only loadable on the interpreter version that produced it --
    confirm before changing the supported Python versions.
    """
    pickled = b'\x80\x05\x95\xc7\x03\x00\x00\x00\x00\x00\x00\x8c\x17cloudpickle.cloudpickle\x94\x8c\x14_make_skeleton_class\x94\x93\x94(\x8c\x08builtins\x94\x8c\x04type\x94\x93\x94\x8c\x08Internal\x94h\x03\x8c\x06object\x94\x93\x94\x85\x94}\x94\x8c\n__module__\x94\x8c\rclass_closure\x94s\x8c 0f63369d5845486db9033c9f3c3253d5\x94Nt\x94R\x94h\x00\x8c\x0f_class_setstate\x94\x93\x94h\x0f}\x94(\x8c\x07__doc__\x94N\x8c\x08__init__\x94h\x00\x8c\x0e_make_function\x94\x93\x94(h\x00\x8c\r_builtin_type\x94\x93\x94\x8c\x08CodeType\x94\x85\x94R\x94(K\x01K\x00K\x00K\x01K\x02K\x13C\n\x88\x00|\x00_\x00d\x00S\x00\x94N\x85\x94\x8c\x03arg\x94\x85\x94\x8c\x04self\x94\x85\x94\x8c"/tmp/monarch_test/class_closure.py\x94\x8c\x08__init__\x94K\tC\x02\n\x01\x94h\x1e\x85\x94)t\x94R\x94}\x94(\x8c\x0b__package__\x94\x8c\x00\x94\x8c\x08__name__\x94h\x0c\x8c\x08__file__\x94h"uNNh\x00\x8c\x10_make_empty_cell\x94\x93\x94)R\x94\x85\x94t\x94R\x94h\x00\x8c\x12_function_setstate\x94\x93\x94h2}\x94}\x94(h+\x8c\x08__init__\x94\x8c\x0c__qualname__\x94\x8c/ClassClosure.closure.<locals>.Internal.__init__\x94\x8c\x0f__annotations__\x94}\x94\x8c\x0e__kwdefaults__\x94N\x8c\x0c__defaults__\x94Nh\x0bh\x0c\x8c\x07__doc__\x94N\x8c\x0b__closure__\x94h\x00\x8c\n_make_cell\x94\x93\x94K\n\x85\x94R\x94\x85\x94\x8c\x17_cloudpickle_submodules\x94]\x94\x8c\x0b__globals__\x94}\x94u\x86\x94\x86R0\x8c\n__module__\x94h\x0c\x8c\x07get_arg\x94h\x16(h\x1b(K\x01K\x00K\x00K\x01K\x01KSC\x0ct\x00\x83\x00\x01\x00|\x00j\x01S\x00\x94h\x1d\x8c\nbreakpoint\x94h\x1e\x86\x94h \x85\x94h"\x8c\x07get_arg\x94K\x0cC\x04\x06\x01\x06\x01\x94))t\x94R\x94h(NNNt\x94R\x94h4hU}\x94}\x94(h+\x8c\x07get_arg\x94h8\x8c.ClassClosure.closure.<locals>.Internal.get_arg\x94h:}\x94h<Nh=Nh\x0bh\x0ch>Nh?NhE]\x94hG}\x94u\x86\x94\x86R0u}\x94\x86\x94\x86R0.'
    # Unpickle `ClassClosure(10).closure()`
    return cloudpickle.loads(pickled)
1058
+
1059
+
1060
def load_func_closure():
    """Return the object obtained by unpickling an embedded cloudpickle payload.

    The payload references the file name /tmp/monarch_test/function_closure.py.
    NOTE(review): the payload embeds a serialized code object, so it is
    presumably only loadable on the interpreter version that produced it --
    confirm before changing the supported Python versions.
    """
    pickled = b"\x80\x05\x95\xd9\x02\x00\x00\x00\x00\x00\x00\x8c\x17cloudpickle.cloudpickle\x94\x8c\x0e_make_function\x94\x93\x94(h\x00\x8c\r_builtin_type\x94\x93\x94\x8c\x08CodeType\x94\x85\x94R\x94(K\x01K\x00K\x00K\x01K\x02K\x13C\x18\x88\x01r\x05t\x00\x83\x00\x01\x00|\x00\x83\x00\xa0\x01\xa1\x00\x88\x00\x17\x00S\x00\x94N\x85\x94\x8c\nbreakpoint\x94\x8c\x07get_arg\x94\x86\x94\x8c\x08internal\x94\x85\x94\x8c%/tmp/monarch_test/function_closure.py\x94\x8c\x04func\x94K\x02C\x06\x04\x01\x06\x01\x0e\x01\x94\x8c\x03arg\x94\x8c\x02bp\x94\x86\x94)t\x94R\x94}\x94(\x8c\x0b__package__\x94\x8c\x00\x94\x8c\x08__name__\x94\x8c\x10function_closure\x94\x8c\x08__file__\x94h\x0fuNNh\x00\x8c\x10_make_empty_cell\x94\x93\x94)R\x94h\x1e)R\x94\x86\x94t\x94R\x94h\x00\x8c\x12_function_setstate\x94\x93\x94h#}\x94}\x94(h\x1a\x8c\x04func\x94\x8c\x0c__qualname__\x94\x8c\x1afunc_closure.<locals>.func\x94\x8c\x0f__annotations__\x94}\x94\x8c\x0e__kwdefaults__\x94N\x8c\x0c__defaults__\x94N\x8c\n__module__\x94h\x1b\x8c\x07__doc__\x94N\x8c\x0b__closure__\x94h\x00\x8c\n_make_cell\x94\x93\x94K\x05\x85\x94R\x94h3\x88\x85\x94R\x94\x86\x94\x8c\x17_cloudpickle_submodules\x94]\x94\x8c\x0b__globals__\x94}\x94u\x86\x94\x86R0h\x02(h\x16h\x17NNh\x1e)R\x94h\x1e)R\x94\x86\x94t\x94R\x94h%hB}\x94}\x94(h\x1a\x8c\x04func\x94h)\x8c\x1afunc_closure.<locals>.func\x94h+}\x94h-Nh.Nh/h\x1bh0Nh1h3K\x05\x85\x94R\x94h3\x89\x85\x94R\x94\x86\x94h9]\x94h;}\x94u\x86\x94\x86R0\x86\x94."
    # Unpickle `(func_closure(5, True), func_closure(5, False))`
    return cloudpickle.loads(pickled)
1064
+
1065
+
1066
class SourceLoaderControllerWithMockedSource(SourceLoaderController):
    """SourceLoaderController whose ``get_source`` serves in-module source strings."""

    @endpoint
    def get_source(self, filename: str) -> str:
        """Return the canned source text for one of the two known test files.

        Raises:
            ValueError: if ``filename`` is not one of the two expected paths.
        """
        if filename == "/tmp/monarch_test/class_closure.py":
            return class_closure_source
        if filename == "/tmp/monarch_test/function_closure.py":
            return function_closure_source
        # Any other request means the test asked for a file it should not need.
        raise ValueError(f"Test should not have requested source for {filename}")
1075
+
1076
+
1077
class ClosureDebugeeActor(Actor):
    """Actor whose endpoints invoke callables supplied by the caller."""

    @endpoint
    def debug_class_closure(self, class_closure) -> int:
        """Instantiate ``class_closure`` and return the result of its ``get_arg()``."""
        instance = class_closure()
        return instance.get_arg()

    @endpoint
    def debug_func(self, func, class_closure) -> int:
        """Return ``func(class_closure)``."""
        outcome = func(class_closure)
        return outcome
1085
+
1086
+
1087
+ # We have to run this test in a subprocess because it requires a special
1088
+ # instantiation of the debug controller singleton.
1089
@isolate_in_subprocess(env=debug_env)
@pytest.mark.timeout(60)
async def test_debug_with_pickle_by_value():
    """
    Exercise the debugger on breakpoints inside code that was pickled by value.

    Pickling by reference stores what is essentially "from <module> import
    <code>"; pickling by value stores the code itself. Cloudpickle pickles by
    value for a few reasons, the primary among them being:
        - The function, class, etc. was defined in the __main__ module
        - The function, class, etc. is a closure
        - The function is a lambda
    When by-value code hits a breakpoint and the file it originally came from
    does not exist on the host, `monarch._src.actor.debugger.pdb_wrapper`
    needs special handling so that all the pdb commands work as expected.

    The fixtures for this test were produced from two files,
    /tmp/monarch_test/class_closure.py and
    /tmp/monarch_test/function_closure.py, whose source is kept in the
    variables `class_closure_source` and `function_closure_source` above.
    `load_class_closure` and `load_func_closure` above hold the results of
    `cloudpickle.dumps(ClassClosure(10).closure())` and
    `cloudpickle.dumps((func(5, True), func(5, False)))`, respectively.

    The test unpickles those objects and sends them to an actor endpoint,
    where the breakpoints are hit and the special pdb handling is checked.
    """

    # Scripted debugger input, consumed one entry per call to the patched
    # DebugStdIO.input.
    input_mock = AsyncMock(
        side_effect=[
            "attach debugee 0",
            "c",
            "quit",
            "attach debugee 0",
            "bt",
            "c",
            "quit",
            "attach debugee 0",
            "b /tmp/monarch_test/class_closure:10",
            "c",
            "detach",
            "quit",
            "attach debugee 0",
            "c",
            "detach",
            "quit",
            "c",
            "quit",
        ]
    )

    # Every message passed to the patched DebugStdIO.output lands here, in order.
    outputs = []

    def _record_output(msg):
        outputs.append(msg)

    output_mock = AsyncMock(side_effect=_record_output)

    input_patch = patch(
        "monarch._src.actor.debugger.debugger.DebugStdIO.input", new=input_mock
    )
    output_patch = patch(
        "monarch._src.actor.debugger.debugger.DebugStdIO.output", new=output_mock
    )

    with input_patch, output_patch:
        pm = proc_mesh(gpus=1, hosts=1)

        debug_controller = actor.get_or_spawn_controller(
            "debug_controller", DebugControllerForTesting
        ).get()

        async def _expect_single_breakpoint(function, lineno):
            # Wait until the controller reports one breakpoint, then check
            # where it is.
            hits = await _wait_for_breakpoints(debug_controller, 1)
            assert hits[0].function == function
            assert hits[0].lineno == lineno

        # Spawn a special source loader that knows how to retrieve the source
        # code for /tmp/monarch_test/class_closure.py and
        # /tmp/monarch_test/function_closure.py
        actor.get_or_spawn_controller(
            "source_loader", SourceLoaderControllerWithMockedSource
        ).get()

        debugee = pm.spawn("debugee", ClosureDebugeeActor)

        class_closure = load_class_closure()
        func_bp_true, func_bp_false = load_func_closure()

        # Scenario 1: invoke the class closure directly.
        pending = debugee.debug_class_closure.call_one(class_closure)
        await _expect_single_breakpoint("class_closure.get_arg", 14)

        debug_controller.blocking_enter.call_one().get()

        assert (
            "> /tmp/monarch_test/class_closure.py(14)get_arg()\n-> return self.arg"
            in outputs
        )

        await pending

        # Scenario 2: reach the same breakpoint through the function closure
        # and check the backtrace that the scripted "bt" command prints.
        pending = debugee.debug_func.call_one(func_bp_false, class_closure)
        await _expect_single_breakpoint("class_closure.get_arg", 14)

        debug_controller.blocking_enter.call_one().get()

        # NOTE(review): pdb normally prefixes non-current frames with two
        # spaces; the first literal below is reproduced exactly as it appears
        # in this listing -- confirm its leading whitespace against the real
        # file.
        expected_backtrace = [
            (
                " /tmp/monarch_test/function_closure.py(5)func()\n"
                "-> return internal().get_arg() + arg"
            ),
            "\n",
            "> /tmp/monarch_test/class_closure.py(14)get_arg()\n-> return self.arg",
            "\n",
            "(Pdb) ",
        ]
        start = outputs.index(expected_backtrace[0])
        assert expected_backtrace == outputs[start : start + len(expected_backtrace)]  # noqa

        await pending

        # Scenario 3: the function closure variant that stops inside `func`;
        # the scripted session then sets a breakpoint at class_closure:10.
        pending = debugee.debug_func.call_one(func_bp_true, class_closure)
        await _expect_single_breakpoint("function_closure.func", 5)

        debug_controller.blocking_enter.call_one().get()

        assert (
            "> /tmp/monarch_test/function_closure.py(5)func()\n-> return internal().get_arg() + arg"
            in outputs
        )
        assert "Breakpoint 1 at /tmp/monarch_test/class_closure.py:10" in outputs
        assert (
            "> /tmp/monarch_test/class_closure.py(10)__init__()\n-> self.arg = arg"
            in outputs
        )

        await _expect_single_breakpoint("class_closure.__init__", 10)

        debug_controller.blocking_enter.call_one().get()

        await _expect_single_breakpoint("class_closure.get_arg", 14)

        debug_controller.blocking_enter.call_one().get()

        # Nothing should be left waiting once the final session has run.
        remaining = debug_controller.list.call_one().get()
        assert len(remaining) == 0

        await pending
        await pm.stop()