torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
@@ -0,0 +1,245 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import logging
10
+ import traceback
11
+ from collections import deque
12
+ from logging import Logger
13
+ from typing import List, NamedTuple, Optional, Sequence, Union
14
+
15
+ from monarch._rust_bindings.monarch_extension import (
16
+ client,
17
+ controller,
18
+ debugger,
19
+ tensor_worker,
20
+ )
21
+ from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension
22
+ ClientActor,
23
+ SystemSnapshotFilter,
24
+ WorldState,
25
+ )
26
+ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
27
+ ActorId,
28
+ Proc,
29
+ )
30
+
31
+ from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
32
+ from monarch.common.controller_api import LogMessage, MessageResult
33
+ from monarch.common.device_mesh import no_mesh
34
+ from monarch.common.invocation import DeviceException, RemoteException
35
+ from monarch.common.messages import SupportsToRustMessage
36
+ from monarch.common.shape import NDSlice
37
+ from monarch.common.tensor import Tensor
38
+ from monarch.controller.debugger import read as debugger_read, write as debugger_write
39
+ from pyre_extensions import none_throws
40
+
41
+ logger: Logger = logging.getLogger(__name__)
42
+
43
+
44
class RustController:
    """Python-side handle for a controller actor reached through a ``ClientActor``.

    The instance attaches the client actor to the controller on construction
    and afterwards forwards messages, pulls responses, and services remote
    debugger sessions through that same client actor.
    """

    def __init__(
        self,
        proc: Proc,
        client_actor: ClientActor,
        controller_id: ActorId,
        worker_world_name: str,
    ) -> None:
        """Store the handles and attach ``client_actor`` to ``controller_id``."""
        self._controller_actor = controller_id
        self._proc = proc
        self._actor = client_actor
        # Attach the client to the controller.
        # Errors will be raised if someone else has attached it already.
        self._actor.attach(self._controller_actor)
        self._worker_world_name = worker_world_name

        # Messages that are not debugger traffic but arrive while a debugger
        # session is being serviced are parked here; next_message() drains
        # this buffer before asking the actor for anything new.
        self._non_debugger_pending_messages: deque[
            Optional[client.LogMessage | client.WorkerResponse]
        ] = deque()
        # Debugger actors that reported Paused and still await a session.
        self._pending_debugger_sessions: deque[ActorId] = deque()

    def send(
        self,
        ranks: Union[NDSlice, List[NDSlice]],
        msg: NamedTuple,
    ) -> None:
        """Forward ``msg`` for ``ranks`` to the controller actor."""
        self._actor.send_obj(self._controller_actor, ranks, msg)

    def drop_refs(self, refs: Sequence[tensor_worker.Ref]) -> None:
        """Ask the controller actor to drop ``refs`` (passed on as a list)."""
        ref_list = list(refs)
        self._actor.drop_refs(self._controller_actor, ref_list)

    def node(
        self,
        seq: int,
        defs: Sequence["Tensor"],
        uses: Sequence["Tensor"],
    ) -> None:
        """Send a serialized ``controller.Node`` for ``seq`` to the controller.

        Tensors whose ``ref`` is ``None`` are left out of both lists.
        """
        defined = [tensor_worker.Ref(id=t.ref) for t in defs if t.ref is not None]
        used = [tensor_worker.Ref(id=t.ref) for t in uses if t.ref is not None]
        graph_node = controller.Node(seq=seq, defs=defined, uses=used)

        self._actor.send(self._controller_actor, graph_node.serialize())

    def next_message(
        self, timeout: Optional[float]
    ) -> Optional[LogMessage | MessageResult]:
        """Return the next converted message, or ``None`` if there is none.

        Buffered non-debugger messages are served first. ``timeout`` is given
        in seconds and converted to milliseconds; ``None`` is treated as 0.
        A debugger message triggers an interactive debugger loop, after which
        ``None`` is returned.
        """
        buffered = self._non_debugger_pending_messages
        if buffered:
            incoming = buffered.popleft()
        else:
            timeout_msec = int((timeout or 0.0) * 1000.0)
            incoming = self._actor.get_next_message(timeout_msec=timeout_msec)

        if incoming is None:
            return None
        if isinstance(incoming, client.WorkerResponse):
            return _worker_response_to_result(incoming)
        if isinstance(incoming, client.LogMessage):
            return LogMessage(incoming.level, incoming.message)
        if isinstance(incoming, client.DebuggerMessage):
            self._run_debugger_loop(incoming)
        return None

    def stop_mesh(self) -> None:
        """Stop both the controller's world and the worker world."""
        logger.info("rust controller stopping the system")
        worlds = [self._controller_actor.world_name, self._worker_world_name]
        self._actor.stop_worlds(worlds)

    def drain_and_stop(
        self,
    ) -> List[LogMessage | MessageResult | client.DebuggerMessage]:
        """Drain the client actor and return every remaining message, converted.

        Raises:
            RuntimeError: if a drained message has an unexpected type.
        """
        logger.info("rust controller shutting down")

        def convert(msg):
            # Debugger messages are handed back untouched.
            if isinstance(msg, client.WorkerResponse):
                return _worker_response_to_result(msg)
            if isinstance(msg, client.LogMessage):
                return LogMessage(msg.level, msg.message)
            if isinstance(msg, client.DebuggerMessage):
                return msg
            raise RuntimeError(f"Unexpected message type {type(msg)}")

        return [convert(msg) for msg in self._actor.drain_and_stop()]

    def _run_debugger_loop(self, message: client.DebuggerMessage) -> None:
        """Service debugger sessions until no paused debugger actor remains.

        ``message`` must carry a ``Paused`` action; anything else raises
        ``RuntimeError``.
        """
        if not isinstance(message.action, DebuggerAction.Paused):
            raise RuntimeError(
                f"Unexpected debugger message {message} when no debugger session is running"
            )

        waiting = self._pending_debugger_sessions
        waiting.append(message.debugger_actor_id)
        # Sessions are handled one at a time, in arrival order; servicing one
        # session may enqueue further paused actors.
        while waiting:
            self._serve_debugger_session(waiting.popleft())

    def _serve_debugger_session(self, debugger_actor_id: ActorId) -> None:
        """Run one attach-to-detach exchange with ``debugger_actor_id``."""
        rank = debugger_actor_id.rank
        proc_id = debugger_actor_id.proc_id
        debugger_write(
            f"pdb attached to proc {proc_id} with rank {rank}, debugger actor {debugger_actor_id} \n"
        )

        attach_request = debugger.DebuggerMessage(action=DebuggerAction.Attach())
        self._actor.send(debugger_actor_id, attach_request.serialize())

        while True:
            # TODO: Add appropriate timeout.
            msg = self._actor.get_next_message(timeout_msec=None)

            if not isinstance(msg, client.DebuggerMessage):
                # Park unrelated traffic for next_message().
                self._non_debugger_pending_messages.append(msg)
                continue

            if msg.debugger_actor_id != debugger_actor_id:
                # Another actor may only announce that it paused; it is
                # queued and serviced after the current session ends.
                if not isinstance(msg.action, DebuggerAction.Paused):
                    raise RuntimeError(
                        f"unexpected debugger message {msg} from rank {msg.debugger_actor_id.rank} "
                        f"when debugging rank {debugger_actor_id.rank}"
                    )
                self._pending_debugger_sessions.append(msg.debugger_actor_id)
                continue

            action = msg.action
            if isinstance(action, DebuggerAction.Detach):
                return
            if isinstance(action, DebuggerAction.Read):
                # The remote debugger wants input: read locally, reply as Write.
                reply = debugger.DebuggerMessage(
                    action=DebuggerAction.Write(
                        bytes=debugger_read(action.requested_size)
                    )
                )
                self._actor.send(debugger_actor_id, reply.serialize())
            elif isinstance(action, DebuggerAction.Write):
                debugger_write(debugger.get_bytes_from_write_action(action).decode())
            else:
                raise RuntimeError(
                    f"unexpected debugger message {msg} when debugging rank {debugger_actor_id.rank}"
                )

    def worker_world_state(self) -> WorldState:
        """Return the snapshot entry for the worker world."""
        snapshot_filter = SystemSnapshotFilter(worlds=[self._worker_world_name])
        worlds_state = self._actor.world_state(snapshot_filter)

        return worlds_state[self._worker_world_name]
197
+
198
+
199
+ # TODO: Handling conversion of the response can move to a separate module over time
200
+ # especially as we have structured error messages.
201
def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
    """Convert a ``client.WorkerResponse`` into a ``MessageResult``.

    A non-exception response carries its value through unchanged. A
    ``client.Error`` becomes a ``RemoteException`` and a ``client.Failure``
    becomes a ``DeviceException``.

    Raises:
        RuntimeError: if the response holds an exception of any other type.
    """
    if not result.is_exception():
        # The result of the message needs to be unwrapped on a real device.
        # Staying as a fake tensor will fail the tensor deserialization.
        with no_mesh.activate():
            return MessageResult(result.seq, result.result(), None)

    exc = none_throws(result.exception())

    if isinstance(exc, client.Error):
        # NOTE(review): this splits on a literal backslash followed by "n",
        # while the Failure branch below splits on a real newline. Presumably
        # the Error backtrace arrives escaped -- confirm.
        worker_frames = [
            traceback.FrameSummary("<unknown>", None, piece)
            for piece in exc.backtrace.split("\\n")
        ]
        summary = f"Worker {exc.actor_id} failed"
        logger.error(summary)
        return MessageResult(
            seq=result.seq,
            result=None,
            error=RemoteException(
                seq=exc.caused_by_seq,
                exception=RuntimeError(exc.backtrace),
                controller_frame_index=0,  # TODO: fix this once we have recording support in rust
                controller_frames=None,
                worker_frames=worker_frames,
                source_actor_id=exc.actor_id,
                message=summary,
            ),
        )

    if isinstance(exc, client.Failure):
        frames = [
            traceback.FrameSummary("<unknown>", None, piece)
            for piece in exc.backtrace.split("\n")
        ]
        reason = f"Actor {exc.actor_id} crashed on {exc.address}, check the host log for details"
        logger.error(reason)
        return MessageResult(
            seq=0,  # seq is not consumed for DeviceException; it will be directly thrown by the client
            result=None,
            error=DeviceException(
                exception=RuntimeError(reason),
                frames=frames,
                source_actor_id=exc.actor_id,
                message=reason,
            ),
        )

    raise RuntimeError(f"Unknown exception type: {type(exc)}")
monarch/debugger.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import logging
9
+ import sys
10
+ from dataclasses import dataclass
11
+ from typing import Dict, List, Tuple, Union
12
+
13
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
14
+ from monarch.actor_mesh import Actor, ActorMeshRef, endpoint
15
+
16
+ from monarch.pdb_wrapper import DebuggerWrite
17
+
18
+ from monarch.proc_mesh import local_proc_mesh
19
+ from tabulate import tabulate
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ CANCEL_TOKEN = object()
26
+
27
+
28
async def _debugger_input(prompt=""):
    """Read one line with ``input(prompt)`` on a worker thread and return it."""
    # input() blocks, so it is run via asyncio.to_thread rather than inline.
    typed = await asyncio.to_thread(input, prompt)
    return typed
30
+
31
+
32
def _debugger_output(msg):
    """Write ``msg`` to stdout and flush immediately."""
    stream = sys.stdout
    stream.write(msg)
    stream.flush()
35
+
36
+
37
@dataclass
class DebugSessionInfo:
    """Snapshot of a debug session, as produced by ``DebugSession.get_info``."""

    rank: int
    coords: Dict[str, int]
    hostname: str
    actor_id: ActorId
    # Both stay None until the session has recorded a function/line pair.
    function: str | None
    lineno: int | None
45
+
46
+
47
class DebugSession:
    """Represents a single session with a remote debugger."""

    def __init__(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
    ):
        self.rank = rank
        self.coords = coords
        self.hostname = hostname
        self.actor_id = actor_id
        # True only while attach() is running its event loop.
        self._active = False
        # Carries "read", "detach" and ("write", DebuggerWrite) items.
        self._message_queue = asyncio.Queue()
        self._task = None
        # Encoded user input waiting to be returned from debugger_read().
        self._pending_send_to_actor = asyncio.Queue()
        self._outputs_since_last_input = []
        # (function, lineno) from the most recent write that carried both.
        self._function_lineno = None
        self._need_read = False

    async def _event_loop(self, line=None, suppress_output=False):
        """Process queued messages until the session is detached.

        Args:
            line: if given, it answers the first "read" instead of prompting
                the user, and the loop exits right after sending it.
            suppress_output: if true, nothing is printed to the terminal.
        """
        if not suppress_output:
            # If the user had previously attached to this debug session,
            # then it would have printed various messages from the
            # message queue. When the user re-attaches, we want to
            # print out all of the output that was printed since the
            # last command sent to this session.
            for output in self._outputs_since_last_input:
                _debugger_output(output.payload.decode())

        while True:
            # When the user inputs "detach", it uses up a "read" message
            # without actually responding to the actor being debugged. We
            # can't manually reinsert the "read" message into the message queue,
            # so instead the self._need_read flag indicates there's an additional
            # "read" that we need to respond to.
            if self._need_read:
                self._need_read = False
                message = "read"
            else:
                message = await self._message_queue.get()
            if message == "detach":
                # Return to the main outer debug loop.
                break
            elif message == "read":
                break_after = False
                if line is not None:
                    break_after = True
                else:
                    line = await _debugger_input()
                if line.strip("\n") == "detach":
                    self._need_read = True
                    break
                else:
                    self._outputs_since_last_input = []
                    await self._pending_send_to_actor.put((line + "\n").encode())
                    line = None
                    if break_after:
                        break
            elif message[0] == "write":
                output = message[1]
                # If the user sees this output but then detaches from the session,
                # it's useful to store all outputs since the last input so that
                # they can be printed again when the user re-attaches.
                self._outputs_since_last_input.append(output)
                if not suppress_output:
                    _debugger_output(output.payload.decode())

        if not suppress_output:
            print(
                f"Detaching from debug session for rank {self.rank} ({self.hostname})"
            )

    def get_info(self):
        """Return a ``DebugSessionInfo`` describing this session."""
        function = lineno = None
        if self._function_lineno is not None:
            function, lineno = self._function_lineno
        return DebugSessionInfo(
            self.rank, self.coords, self.hostname, self.actor_id, function, lineno
        )

    async def attach(self, line=None, suppress_output=False):
        """Run the event loop for this session until it detaches.

        Arguments are passed through to ``_event_loop``.
        """
        self._active = True
        try:
            if not suppress_output:
                print(
                    f"Attached to debug session for rank {self.rank} ({self.hostname})"
                )
            self._task = asyncio.create_task(self._event_loop(line, suppress_output))
            await self._task
            if not suppress_output:
                print(
                    f"Detached from debug session for rank {self.rank} ({self.hostname})"
                )
        finally:
            # Reset even when the event loop raises or is cancelled; otherwise
            # a later detach() would queue a stray "detach" for a loop that is
            # no longer running.
            self._active = False

    async def detach(self):
        """Ask a running event loop to exit; does nothing when not attached."""
        if self._active:
            await self._message_queue.put("detach")

    async def debugger_read(self, size: int) -> DebuggerWrite:
        """Request user input and return at most ``size`` bytes of it."""
        await self._message_queue.put("read")
        input_data = await self._pending_send_to_actor.get()
        if len(input_data) > size:
            input_data = input_data[:size]
        return DebuggerWrite(input_data, None, None)

    async def debugger_write(self, write: DebuggerWrite) -> None:
        """Queue ``write`` for display and remember its function/line if present."""
        if write.function is not None and write.lineno is not None:
            self._function_lineno = (write.function, write.lineno)
        await self._message_queue.put(("write", write))
151
+
152
+
153
+ class DebugCommand:
154
+ @staticmethod
155
+ def parse(line: str) -> Union["DebugCommand", None]:
156
+ parts = line.strip("\n").split(" ")
157
+ if len(parts) == 0:
158
+ return None
159
+ command = parts[0]
160
+ match command:
161
+ case "attach":
162
+ return Attach._parse(parts)
163
+ case "list":
164
+ return ListCommand()
165
+ case "quit":
166
+ return Quit()
167
+ case "cast":
168
+ return Cast._parse(parts)
169
+ case "help":
170
+ return Help()
171
+ case "continue":
172
+ return Continue()
173
+ case _:
174
+ print(
175
+ f"Unknown command {command}. Expected: attach | list | quit | cast | continue | help"
176
+ )
177
+ return None
178
+
179
+
180
@dataclass
class Attach(DebugCommand):
    """``attach <rank>``: attach to the debug session of one rank."""

    rank: int

    @classmethod
    def _parse(cls, parts: List[str]) -> "Attach":
        """Build an ``Attach`` from the space-split command line.

        Raises:
            ValueError: if there is not exactly one argument or it is not
                an integer.
        """
        if len(parts) != 2:
            raise ValueError("Invalid attach command. Expected: attach <rank>")
        try:
            rank = int(parts[1])
        except ValueError as err:
            # Chain explicitly so the int() failure remains the cause.
            raise ValueError(f"Invalid rank {parts[1]}. Expected: int") from err
        return cls(rank)
193
+
194
+
195
class ListCommand(DebugCommand):
    """``list``: list all debug sessions."""
197
+
198
+
199
class Quit(DebugCommand):
    """``quit``: exit the debugger, leaving all sessions in place."""
201
+
202
+
203
class Help(DebugCommand):
    """``help``: print the list of debugger commands."""
205
+
206
+
207
class Continue(DebugCommand):
    """``continue``: tell all ranks to continue execution, then exit the debugger."""
209
+
210
+
211
@dataclass
class Cast(DebugCommand):
    """``cast {<r0,r1,...> | *} <command>``: send a command to several ranks."""

    # None means "all ranks" (the user typed "*").
    ranks: List[int] | None
    command: str

    @classmethod
    def _parse(cls, parts: List[str]) -> "Cast":
        """Build a ``Cast`` from the space-split command line.

        Raises:
            ValueError: if the rank list or the command is missing, or a
                rank is not an integer.
        """
        if len(parts) < 3:
            raise ValueError(
                "Invalid cast command. Expected: cast {<r0,r1,...> | *} <command>"
            )
        rank_spec = parts[1]
        command = " ".join(parts[2:])
        if rank_spec == "*":
            return cls(None, command)

        # str.split(",") always yields at least one token, so an empty rank
        # list shows up as an empty token and is rejected by int() below.
        ranks = []
        for token in rank_spec.split(","):
            try:
                ranks.append(int(token))
            except ValueError as err:
                raise ValueError(f"Invalid rank {token}. Expected: int") from err
        return cls(ranks, command)
239
+
240
+
241
class DebugClient(Actor):
    """
    Single actor for both remote debuggers and users to talk to.

    Handles multiple sessions simultaneously.
    """

    def __init__(self) -> None:
        self.sessions = {}  # rank -> DebugSession

    @endpoint
    async def wait_pending_session(self):
        """Return once at least one session exists, polling once per second."""
        while len(self.sessions) == 0:
            await asyncio.sleep(1)

    @endpoint
    async def list(
        self,
    ) -> List[Tuple[int, Dict[str, int], str, ActorId, str | None, int | None]]:
        """Print a table of all sessions and return its rows sorted by rank.

        The function and line-number columns are ``None`` for sessions that
        have not recorded a location yet.
        """
        table_data = []
        for session in self.sessions.values():
            info = session.get_info()
            table_data.append(
                (
                    info.rank,
                    info.coords,
                    info.hostname,
                    info.actor_id,
                    info.function,
                    info.lineno,
                )
            )
        table_data = sorted(table_data, key=lambda r: r[0])

        headers = ["Rank", "Coords", "Hostname", "Actor ID", "Function", "Line No."]
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

        return table_data

    @endpoint
    async def enter(self) -> None:
        """Run the interactive ``monarch_dbg>`` prompt until quit or continue.

        Errors raised while handling a command are printed and the prompt
        keeps running.
        """
        # pyre-ignore
        await getattr(self, "list")._method(self)  # noqa

        while True:
            try:
                user_input = await _debugger_input("monarch_dbg> ")
                command = DebugCommand.parse(user_input)
                if isinstance(command, Help):
                    print("monarch_dbg commands:")
                    print("\tattach <rank> - attach to a debug session")
                    print("\tlist - list all debug sessions")
                    print("\tquit - exit the debugger, leaving all sessions in place")
                    print(
                        "\tcast {<r0,r1,...> | *} <command> - send a command to a comma-separated list of ranks, or all ranks"
                    )
                    print(
                        "\tcontinue - tell all ranks to continue execution, then exit the debugger"
                    )
                    print("\thelp - print this help message")
                elif isinstance(command, Attach):
                    if command.rank not in self.sessions:
                        print(f"No debug session for rank {command.rank}")
                    else:
                        await self.sessions[command.rank].attach()
                elif isinstance(command, ListCommand):
                    await getattr(self, "list")._method(self)  # noqa
                elif isinstance(command, Continue):
                    # Make sure all ranks have exited their debug sessions.
                    # If we sent "quit", it would raise BdbQuit, crashing
                    # the process, which probably isn't what we want.
                    while len(self.sessions) > 0:
                        tasks = []
                        for rank in self.sessions:
                            tasks.append(
                                self.sessions[rank].attach("c", suppress_output=True)
                            )
                        await asyncio.gather(*tasks)
                    return
                elif isinstance(command, Quit):
                    return
                elif isinstance(command, Cast):
                    if command.ranks is None:
                        ranks = self.sessions.keys()
                    else:
                        ranks = command.ranks
                    tasks = []
                    for rank in ranks:
                        if rank in self.sessions:
                            tasks.append(
                                self.sessions[rank].attach(
                                    command.command,
                                    suppress_output=True,
                                )
                            )
                        else:
                            print(f"No debug session for rank {rank}")
                    await asyncio.gather(*tasks)
            except Exception as e:
                print(f"Error processing command: {e}")

    ##########################################################################
    # Debugger APIs
    #
    # These endpoints are called by the remote debuggers to establish sessions
    # and communicate with them.
    @endpoint
    async def debugger_session_start(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
    ) -> None:
        """Register a session for ``rank`` unless one already exists."""
        # Create a session if it doesn't exist
        if rank not in self.sessions:
            self.sessions[rank] = DebugSession(rank, coords, hostname, actor_id)

    @endpoint
    async def debugger_session_end(self, rank: int) -> None:
        """Remove the session for ``rank`` and tell it to detach."""
        session = self.sessions.pop(rank)
        await session.detach()

    @endpoint
    async def debugger_read(self, rank: int, size: int) -> DebuggerWrite | str:
        """Read from the debug session for the given rank."""
        session = self.sessions[rank]

        return await session.debugger_read(size)

    @endpoint
    async def debugger_write(self, rank: int, write: DebuggerWrite) -> None:
        """Write to the debug session for the given rank."""
        session = self.sessions[rank]
        await session.debugger_write(write)
371
+
372
+
373
async def init_debugging(
    actor_mesh: ActorMeshRef,
) -> ActorMeshRef[DebugClient]:
    """Spawn a ``DebugClient`` on a local proc mesh and register it with ``actor_mesh``.

    Returns the mesh holding the spawned debug client.
    """
    local_mesh = await local_proc_mesh(gpus=1, hosts=1)
    client_mesh = await local_mesh.spawn("debug_client", DebugClient)
    await actor_mesh._set_debug_client.call(client_mesh)
    return client_mesh
monarch/fetch.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ """
9
+ This is a utility file for fetching a shard of a tensor from remote.
10
+ """
11
+
12
+ from typing import TypeVar
13
+
14
+ from monarch.common.device_mesh import no_mesh
15
+
16
+ from monarch.common.future import Future
17
+
18
+ from monarch.common.remote import _call_on_shard_and_fetch
19
+
20
+ T = TypeVar("T")
21
+
22
+
23
def fetch_shard(
    obj: T, shard: dict[str, int] | None = None, **kwargs: int
) -> Future[T]:
    """
    Retrieve the shard at coordinates `shard` of the current device mesh of each
    tensor in obj. All tensors in `obj` will be fetched to the CPU device.
        obj - a pytree containing the tensors to fetch
        shard - a dictionary from mesh dimension name to coordinate of the shard
                If None, this will fetch from coordinate 0 for all dimensions (useful after all_reduce/all_gather)
        **kwargs - additional keyword arguments are added as entries to the shard dictionary

    The caller's `shard` dictionary is never modified.
    """
    if kwargs:
        # Merge into a fresh dict instead of calling shard.update(), which
        # would leak the extra coordinates into the caller's dictionary.
        shard = {**(shard or {}), **kwargs}

    return _call_on_shard_and_fetch(
        None, lambda *args, **kwargs: None, obj, shard=shard
    )
43
+
44
+
45
def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object:
    """Fetch `obj` with `inspect` and pass the result to `torchshow.show`."""
    fetched = inspect(obj, shard=shard, **kwargs)
    # torchshow is imported only after the fetch, as in the original ordering.
    # pyre-ignore
    from torchshow import show as render  # @manual

    with no_mesh.activate():
        return render(fetched)
52
+
53
+
54
def inspect(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> T:
    """Fetch `obj` via `fetch_shard` and wait for the result."""
    pending = fetch_shard(obj, shard=shard, **kwargs)
    return pending.result()