torchmonarch-nightly 2025.6.11__cp310-cp310-manylinux2014_x86_64.whl → 2025.6.13__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monarch/debugger.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import asyncio
8
+ import logging
9
+ import sys
10
+ from dataclasses import dataclass
11
+ from typing import Dict, List, Tuple, Union
12
+
13
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
14
+ from monarch.actor_mesh import Actor, endpoint
15
+
16
+ from monarch.pdb_wrapper import DebuggerWrite
17
+
18
+ from monarch.proc_mesh import local_proc_mesh
19
+ from tabulate import tabulate
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ CANCEL_TOKEN = object()
26
+
27
+
28
async def _debugger_input(prompt=""):
    """Read one line of user input without blocking the event loop.

    The blocking built-in ``input`` is run on a worker thread through
    ``asyncio.to_thread``; the awaited result is whatever ``input`` returned.
    """
    typed_line = await asyncio.to_thread(input, prompt)
    return typed_line
30
+
31
+
32
+ def _debugger_output(msg):
33
+ sys.stdout.write(msg)
34
+ sys.stdout.flush()
35
+
36
+
37
@dataclass
class DebugSessionInfo:
    """Snapshot of one debug session, as built by ``DebugSession.get_info``."""

    # Rank the session belongs to; sessions are keyed by this value.
    rank: int
    # Coordinates reported when the session started.
    # NOTE(review): presumably mesh dimension name -> index; confirm with caller.
    coords: Dict[str, int]
    # Hostname reported when the session started.
    hostname: str
    # Id of the actor being debugged.
    actor_id: ActorId
    # Function name from the most recent debugger write that carried both a
    # function and a line number; None if no such write has been seen yet.
    function: str | None
    # Line number from that same write; None if no such write has been seen.
    lineno: int | None
45
+
46
+
47
+ class DebugSession:
48
+ """Represents a single session with a remote debugger."""
49
+
50
+ def __init__(
51
+ self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
52
+ ):
53
+ self.rank = rank
54
+ self.coords = coords
55
+ self.hostname = hostname
56
+ self.actor_id = actor_id
57
+ self._active = False
58
+ self._message_queue = asyncio.Queue()
59
+ self._task = None
60
+ self._pending_send_to_actor = asyncio.Queue()
61
+ self._outputs_since_last_input = []
62
+ self._function_lineno = None
63
+ self._need_read = False
64
+
65
+ async def _event_loop(self, line=None, suppress_output=False):
66
+ if not suppress_output:
67
+ # If the user had previously attached to this debug session,
68
+ # then it would have printed various messages from the
69
+ # message queue. When the user re-attaches, we want to
70
+ # print out all of the output that was printed since the
71
+ # last command sent to this session.
72
+ for output in self._outputs_since_last_input:
73
+ _debugger_output(output.payload.decode())
74
+
75
+ while True:
76
+ # When the user inputs "detach", it uses up a "read" message
77
+ # without actually responding to the actor being debugged. We
78
+ # can't manually reinsert the "read" message into the message queue,
79
+ # so instead the self._need_read flag indicates there's an additional
80
+ # "read" that we need to respond to.
81
+ if self._need_read:
82
+ self._need_read = False
83
+ message = "read"
84
+ else:
85
+ message = await self._message_queue.get()
86
+ if message == "detach":
87
+ # Return to the main outer debug loop.
88
+ break
89
+ elif message == "read":
90
+ break_after = False
91
+ if line is not None:
92
+ break_after = True
93
+ else:
94
+ line = await _debugger_input()
95
+ if line.strip("\n") == "detach":
96
+ self._need_read = True
97
+ break
98
+ else:
99
+ self._outputs_since_last_input = []
100
+ await self._pending_send_to_actor.put((line + "\n").encode())
101
+ line = None
102
+ if break_after:
103
+ break
104
+ elif message[0] == "write":
105
+ output = message[1]
106
+ # If the user sees this output but then detaches from the session,
107
+ # its useful to store all outputs since the last input so that
108
+ # they can be printed again when the user re-attaches.
109
+ self._outputs_since_last_input.append(output)
110
+ if not suppress_output:
111
+ _debugger_output(output.payload.decode())
112
+
113
+ if not suppress_output:
114
+ print(
115
+ f"Detaching from debug session for rank {self.rank} ({self.hostname})"
116
+ )
117
+
118
+ def get_info(self):
119
+ function = lineno = None
120
+ if self._function_lineno is not None:
121
+ function, lineno = self._function_lineno
122
+ return DebugSessionInfo(
123
+ self.rank, self.coords, self.hostname, self.actor_id, function, lineno
124
+ )
125
+
126
+ async def attach(self, line=None, suppress_output=False):
127
+ self._active = True
128
+ if not suppress_output:
129
+ print(f"Attached to debug session for rank {self.rank} ({self.hostname})")
130
+ self._task = asyncio.create_task(self._event_loop(line, suppress_output))
131
+ await self._task
132
+ if not suppress_output:
133
+ print(f"Detached from debug session for rank {self.rank} ({self.hostname})")
134
+ self._active = False
135
+
136
+ async def detach(self):
137
+ if self._active:
138
+ await self._message_queue.put("detach")
139
+
140
+ async def debugger_read(self, size: int) -> DebuggerWrite:
141
+ await self._message_queue.put("read")
142
+ input_data = await self._pending_send_to_actor.get()
143
+ if len(input_data) > size:
144
+ input_data = input_data[:size]
145
+ return DebuggerWrite(input_data, None, None)
146
+
147
+ async def debugger_write(self, write: DebuggerWrite) -> None:
148
+ if write.function is not None and write.lineno is not None:
149
+ self._function_lineno = (write.function, write.lineno)
150
+ await self._message_queue.put(("write", write))
151
+
152
+
153
+ class DebugCommand:
154
+ @staticmethod
155
+ def parse(line: str) -> Union["DebugCommand", None]:
156
+ parts = line.strip("\n").split(" ")
157
+ if len(parts) == 0:
158
+ return None
159
+ command = parts[0]
160
+ match command:
161
+ case "attach":
162
+ return Attach._parse(parts)
163
+ case "list":
164
+ return ListCommand()
165
+ case "quit":
166
+ return Quit()
167
+ case "cast":
168
+ return Cast._parse(parts)
169
+ case "help":
170
+ return Help()
171
+ case "continue":
172
+ return Continue()
173
+ case _:
174
+ print(
175
+ f"Unknown command {command}. Expected: attach | list | quit | cast | continue | help"
176
+ )
177
+ return None
178
+
179
+
180
@dataclass
class Attach(DebugCommand):
    """Command that attaches the user to the debug session of one rank."""

    rank: int

    @classmethod
    def _parse(cls, parts: List[str]) -> "Attach":
        """Build an ``Attach`` from the split input ``["attach", "<rank>"]``.

        Raises:
            ValueError: if there is not exactly one argument, or the
                argument is not an integer.
        """
        if len(parts) != 2:
            raise ValueError("Invalid attach command. Expected: attach <rank>")
        raw_rank = parts[1]
        try:
            parsed_rank = int(raw_rank)
        except ValueError:
            raise ValueError(f"Invalid rank {raw_rank}. Expected: int")
        return cls(parsed_rank)
193
+
194
+
195
class ListCommand(DebugCommand):
    """Command produced by the ``list`` input; it carries no arguments."""
197
+
198
+
199
class Quit(DebugCommand):
    """Command produced by the ``quit`` input; it carries no arguments."""
201
+
202
+
203
class Help(DebugCommand):
    """Command produced by the ``help`` input; it carries no arguments."""
205
+
206
+
207
class Continue(DebugCommand):
    """Command produced by the ``continue`` input; it carries no arguments."""
209
+
210
+
211
+ @dataclass
212
+ class Cast(DebugCommand):
213
+ ranks: List[int] | None
214
+ command: str
215
+
216
+ @classmethod
217
+ def _parse(cls, parts: List[str]) -> "Cast":
218
+ if len(parts) < 3:
219
+ raise ValueError(
220
+ "Invalid cast command. Expected: cast {<r0,r1,...> | *} <command>"
221
+ )
222
+ str_ranks = parts[1]
223
+ command = " ".join(parts[2:])
224
+ if str_ranks == "*":
225
+ return cls(None, command)
226
+ else:
227
+ str_ranks = str_ranks.split(",")
228
+ if len(str_ranks) == 0:
229
+ raise ValueError(
230
+ "Invalid rank list for cast. Expected at least one rank."
231
+ )
232
+ ranks = []
233
+ for rank in str_ranks:
234
+ try:
235
+ ranks.append(int(rank))
236
+ except ValueError:
237
+ raise ValueError(f"Invalid rank {rank}. Expected: int")
238
+ return cls(ranks, command)
239
+
240
+
241
class DebugClient(Actor):
    """
    Single actor for both remote debuggers and users to talk to.

    Handles multiple sessions simultaneously. Sessions are stored per rank;
    remote debuggers create, feed and end them through the ``debugger_*``
    endpoints, while the user drives them through ``enter``.
    """

    def __init__(self) -> None:
        self.sessions = {}  # rank -> DebugSession

    @endpoint
    async def wait_pending_session(self):
        """Block until at least one debug session exists (polls once a second)."""
        while len(self.sessions) == 0:
            await asyncio.sleep(1)

    @endpoint
    async def list(
        self,
    ) -> List[Tuple[int, Dict[str, int], str, ActorId, str | None, int | None]]:
        """Print a table of all sessions and return its rows, sorted by rank.

        Each row is ``(rank, coords, hostname, actor_id, function, lineno)``;
        ``function`` and ``lineno`` are None for sessions that have not
        reported a location yet.
        """
        table_data = []
        for session in self.sessions.values():
            info = session.get_info()
            table_data.append(
                (
                    info.rank,
                    info.coords,
                    info.hostname,
                    info.actor_id,
                    info.function,
                    info.lineno,
                )
            )
        table_data = sorted(table_data, key=lambda r: r[0])

        headers = ["Rank", "Coords", "Hostname", "Actor ID", "Function", "Line No."]
        print(tabulate(table_data, headers=headers, tablefmt="grid"))

        return table_data

    @endpoint
    async def enter(self) -> None:
        """Run the interactive ``monarch_dbg>`` prompt.

        Prints the session table, then reads and executes commands until
        the user enters ``quit`` or ``continue``, or stdin reaches EOF.
        Errors raised while processing a command are printed and the
        prompt continues.
        """
        # pyre-ignore
        await getattr(self, "list")._method(self)  # noqa

        while True:
            try:
                user_input = await _debugger_input("monarch_dbg> ")
                command = DebugCommand.parse(user_input)
                if isinstance(command, Help):
                    print("monarch_dbg commands:")
                    print("\tattach <rank> - attach to a debug session")
                    print("\tlist - list all debug sessions")
                    print("\tquit - exit the debugger, leaving all sessions in place")
                    print(
                        "\tcast {<r0,r1,...> | *} <command> - send a command to a comma-separated list of ranks, or all ranks"
                    )
                    print(
                        "\tcontinue - tell all ranks to continue execution, then exit the debugger"
                    )
                    print("\thelp - print this help message")
                elif isinstance(command, Attach):
                    if command.rank not in self.sessions:
                        print(f"No debug session for rank {command.rank}")
                    else:
                        await self.sessions[command.rank].attach()
                elif isinstance(command, ListCommand):
                    await getattr(self, "list")._method(self)  # noqa
                elif isinstance(command, Continue):
                    # Make sure all ranks have exited their debug sessions.
                    # If we sent "quit", it would raise BdbQuit, crashing
                    # the process, which probably isn't what we want.
                    # NOTE(review): this loop only ends once sessions are
                    # removed, which here happens in debugger_session_end.
                    while len(self.sessions) > 0:
                        tasks = [
                            session.attach("c", suppress_output=True)
                            for session in self.sessions.values()
                        ]
                        await asyncio.gather(*tasks)
                    return
                elif isinstance(command, Quit):
                    return
                elif isinstance(command, Cast):
                    if command.ranks is None:
                        ranks = self.sessions.keys()
                    else:
                        ranks = command.ranks
                    tasks = []
                    for rank in ranks:
                        if rank in self.sessions:
                            tasks.append(
                                self.sessions[rank].attach(
                                    command.command,
                                    suppress_output=True,
                                )
                            )
                        else:
                            print(f"No debug session for rank {rank}")
                    await asyncio.gather(*tasks)
            except EOFError:
                # stdin is closed: every further read would fail the same
                # way, so retrying the prompt would spin forever. Treat EOF
                # like "quit" and leave all sessions in place.
                print("EOF on stdin, exiting the debugger")
                return
            except Exception as e:
                print(f"Error processing command: {e}")

    ##########################################################################
    # Debugger APIs
    #
    # These endpoints are called by the remote debuggers to establish sessions
    # and communicate with them.
    @endpoint
    async def debugger_session_start(
        self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId
    ) -> None:
        """Register a session for ``rank``; an existing session is kept as is."""
        # Create a session if it doesn't exist
        if rank not in self.sessions:
            self.sessions[rank] = DebugSession(rank, coords, hostname, actor_id)

    @endpoint
    async def debugger_session_end(self, rank: int) -> None:
        """Remove the session for ``rank`` and detach from it if attached.

        Raises:
            KeyError: if there is no session for ``rank``.
        """
        session = self.sessions.pop(rank)
        await session.detach()

    @endpoint
    async def debugger_read(self, rank: int, size: int) -> DebuggerWrite | str:
        """Read from the debug session for the given rank.

        Raises:
            KeyError: if there is no session for ``rank``.
        """
        session = self.sessions[rank]

        return await session.debugger_read(size)

    @endpoint
    async def debugger_write(self, rank: int, write: DebuggerWrite) -> None:
        """Write to the debug session for the given rank.

        Raises:
            KeyError: if there is no session for ``rank``.
        """
        session = self.sessions[rank]
        await session.debugger_write(write)
371
+
372
+
373
async def init_debugging(actor_mesh: Actor) -> DebugClient:
    """Spawn a ``DebugClient`` and register it with ``actor_mesh``.

    A local proc mesh (one host, one gpu) is created to host the client,
    the client is spawned on it under the name "debug_client", and the
    result is handed to ``actor_mesh._set_debug_client`` before being
    returned.
    """
    host_mesh = await local_proc_mesh(gpus=1, hosts=1)
    client = await host_mesh.spawn("debug_client", DebugClient)
    await actor_mesh._set_debug_client.call(client)
    return client
@@ -4,7 +4,10 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ import atexit
7
8
  import logging
9
+ import os
10
+ import time
8
11
  import traceback
9
12
  from collections import deque
10
13
  from logging import Logger
@@ -22,6 +25,8 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarc
22
25
  ActorId,
23
26
  )
24
27
  from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
28
+ from monarch._rust_bindings.monarch_hyperactor.shape import Point
29
+
25
30
  from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
26
31
  from monarch.common.client import Client
27
32
  from monarch.common.controller_api import LogMessage, MessageResult
@@ -29,6 +34,7 @@ from monarch.common.device_mesh import DeviceMesh, no_mesh
29
34
  from monarch.common.invocation import DeviceException, RemoteException
30
35
  from monarch.controller.debugger import read as debugger_read, write as debugger_write
31
36
  from monarch.proc_mesh import ProcMesh
37
+ from monarch.rust_local_mesh import _get_worker_exec_info
32
38
  from pyre_extensions import none_throws
33
39
 
34
40
  logger: Logger = logging.getLogger(__name__)
@@ -72,18 +78,8 @@ class Controller(_Controller):
72
78
  def drain_and_stop(
73
79
  self,
74
80
  ) -> List[LogMessage | MessageResult | client.DebuggerMessage]:
75
- logger.info("rust controller shutting down")
76
- results = []
77
- for msg in self._drain_and_stop():
78
- if isinstance(msg, client.WorkerResponse):
79
- results.append(_worker_response_to_result(msg))
80
- elif isinstance(msg, client.LogMessage):
81
- results.append(LogMessage(msg.level, msg.message))
82
- elif isinstance(msg, client.DebuggerMessage):
83
- results.append(msg)
84
- else:
85
- raise RuntimeError(f"Unexpected message type {type(msg)}")
86
- return results
81
+ self._drain_and_stop()
82
+ return []
87
83
 
88
84
  def _run_debugger_loop(self, message: client.DebuggerMessage) -> None:
89
85
  if not isinstance(message.action, DebuggerAction.Paused):
@@ -158,7 +154,6 @@ def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
158
154
  traceback.FrameSummary("<unknown>", None, frame)
159
155
  for frame in exc.backtrace.split("\\n")
160
156
  ]
161
- logger.error(f"Worker {exc.actor_id} failed")
162
157
  return MessageResult(
163
158
  seq=result.seq,
164
159
  result=None,
@@ -169,7 +164,7 @@ def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
169
164
  controller_frames=None,
170
165
  worker_frames=worker_frames,
171
166
  source_actor_id=exc.actor_id,
172
- message=f"Worker {exc.actor_id} failed",
167
+ message=f"Remote function in {exc.actor_id} errored.",
173
168
  ),
174
169
  )
175
170
  elif isinstance(exc, client.Failure):
@@ -193,13 +188,75 @@ def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
193
188
  raise RuntimeError(f"Unknown exception type: {type(exc)}")
194
189
 
195
190
 
191
def _initialize_env(worker_point: Point, proc_id: str) -> None:
    """Populate ``os.environ`` with the settings for one worker process.

    The base environment comes from ``_get_worker_exec_info``; rank, world
    size, CUDA device and NCCL variables derived from ``worker_point`` and
    ``proc_id`` are layered on top and win on conflicts. Any failure is
    printed with its traceback and re-raised.
    """
    worker_rank = worker_point.rank
    try:
        _, worker_env = _get_worker_exec_info()
        local_rank = worker_point["gpus"]
        gpus_per_host = worker_point.size("gpus")
        num_worker_procs = len(worker_point.shape)
        host_index = worker_rank // gpus_per_host
        overrides = {
            "HYPERACTOR_MANAGED_SUBPROCESS": "1",
            "CUDA_VISIBLE_DEVICES": str(local_rank),
            "NCCL_HOSTID": f"{proc_id}_host_{host_index}",
            # This is needed to avoid a hard failure in ncclx when we do not
            # have backend topology info (eg. on RE).
            "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
            "LOCAL_RANK": str(local_rank),
            "RANK": str(worker_rank),
            "WORLD_SIZE": str(num_worker_procs),
            "LOCAL_WORLD_SIZE": str(gpus_per_host),
        }
        os.environ.update({**worker_env, **overrides})
    except Exception:
        traceback.print_exc()
        raise
215
+
216
+
217
class MeshClient(Client):
    """``Client`` variant whose shutdown waits for outstanding work first."""

    def shutdown(
        self,
        destroy_pg: bool = True,
        error_reason: Optional[RemoteException | DeviceException | Exception] = None,
    ):
        """Shut the client down once all assigned work has been processed.

        Waits up to 60 seconds for ``last_processed_seq`` to catch up with
        ``last_assigned_seq``, then drains and stops the inner controller.
        Does nothing if the client has already shut down.

        ``destroy_pg`` and ``error_reason`` are accepted for signature
        compatibility and are not used by this implementation.

        Raises:
            RuntimeError: if work is still outstanding after the deadline.
            Exception: whatever ``_pending_shutdown_error`` holds, if it
                gets set while waiting.
        """
        if self.has_shutdown:
            return
        logger.info("shutting down the client gracefully")

        atexit.unregister(self._atexit)
        self._shutdown = True

        # ensure all pending work is finished.
        # all errors must be messaged back at this point
        self.new_node_nocoalesce([], [], None, [])
        self._request_status()

        deadline = time.time() + 60
        while self.last_assigned_seq > self.last_processed_seq:
            remaining = deadline - time.time()
            if remaining <= 0:
                # Decide "timed out" from the outstanding work, not from the
                # clock alone, so a shutdown that completes on the final poll
                # is not misreported; also never poll with a non-positive
                # timeout.
                raise RuntimeError("shutdown timed out")
            self.handle_next_message(remaining)
            if self._pending_shutdown_error:
                raise self._pending_shutdown_error

        # we are not expecting anything more now, because we already
        # waited for the responses
        self.inner.drain_and_stop()
251
+
252
+
196
253
  def spawn_tensor_engine(proc_mesh: ProcMesh) -> DeviceMesh:
197
254
  # This argument to Controller
198
255
  # is currently only used for debug printing. It should be fixed to
199
256
  # report the proc ID instead of the rank it currently does.
200
257
  gpus = proc_mesh.sizes.get("gpus", 1)
201
258
  backend_ctrl = Controller(proc_mesh._proc_mesh)
202
- client = Client(backend_ctrl, proc_mesh.size(), gpus)
259
+ client = MeshClient(backend_ctrl, proc_mesh.size(), gpus)
203
260
  dm = DeviceMesh(
204
261
  client,
205
262
  NDSlice.new_row_major(list(proc_mesh.sizes.values())),
Binary file
monarch/pdb_wrapper.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import bdb
8
+ import inspect
9
+ import io
10
+ import pdb # noqa
11
+ import socket
12
+ import sys
13
+ from dataclasses import dataclass
14
+
15
+ from typing import Dict, TYPE_CHECKING
16
+
17
+ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
18
+
19
+ if TYPE_CHECKING:
20
+ from monarch.debugger import DebugClient
21
+
22
+
23
@dataclass
class DebuggerWrite:
    """A chunk of bytes exchanged with a debugger, plus an optional location."""

    # Raw bytes being transferred.
    payload: bytes
    # "<module>.<function>" of the debugger's current frame, or None when
    # no frame is available.
    function: str | None
    # Line number of the debugger's current frame, or None when no frame is
    # available.
    lineno: int | None
28
+
29
+
30
class PdbWrapper(pdb.Pdb):
    """``pdb.Pdb`` whose stdin/stdout are routed through a debug client.

    Input is read through ``ReadWrapper`` and output is written through
    ``WriteWrapper``; both forward to ``client_ref``.
    """

    def __init__(
        self,
        rank: int,
        coords: Dict[str, int],
        actor_id: ActorId,
        client_ref: "DebugClient",
        header: str | None = None,
    ):
        """Store the session identity and wire pdb's streams to the client.

        Args:
            rank: rank used to identify this session to the client.
            coords: coordinates reported to the client at session start.
            actor_id: id of the actor being debugged.
            client_ref: reference to the debug client endpoints.
            header: optional message printed when tracing starts.
        """
        self.rank = rank
        self.coords = coords
        self.header = header
        self.actor_id = actor_id
        self.client_ref = client_ref
        # pyre-ignore
        super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self))
        # True until the first ``setup`` call has adjusted the current frame.
        self._first = True

    def setup(self, *args, **kwargs):
        """Run pdb's setup and, the first time only, move up two frames."""
        r = super().setup(*args, **kwargs)
        if self._first:
            self._first = False
            # when we enter the debugger, we want to present the user's stack frame
            # not the nested one inside session.run. This means that the local
            # variables are what gets printed, etc. To do this
            # we first execute up 2 to get to that frame.
            # NOTE(review): the count of 2 depends on the exact call depth
            # between the user's frame and set_trace; re-check if the call
            # chain changes.
            self.do_up(2)
        return r

    def set_continue(self) -> None:
        """Continue execution; end the remote session if no breakpoints remain."""
        r = super().set_continue()
        if not self.breaks:
            # no more breakpoints so this debugger will not
            # be used again, and we detach from the controller io.
            self.client_ref.debugger_session_end.call_one(self.rank).get()
            # break cycle with itself before we exit
            self.stdin = sys.stdin
            self.stdout = sys.stdout
        return r

    def set_trace(self):
        """Register the session with the client, print the header, start tracing."""
        self.client_ref.debugger_session_start.call_one(
            self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id
        ).get()
        if self.header:
            self.message(self.header)
        super().set_trace()
77
+
78
+
79
class ReadWrapper(io.RawIOBase):
    """Raw input stream that fetches its bytes from the debug client."""

    def __init__(self, session: "PdbWrapper"):
        self.session = session

    def readinto(self, b):
        """Fill ``b`` with the next chunk of input and return the byte count.

        Asks the client for at most ``len(b)`` bytes for this session's
        rank. A reply equal to the string "detach" raises ``bdb.BdbQuit``.
        """
        capacity = len(b)
        reply = self.session.client_ref.debugger_read.call_one(
            self.session.rank, capacity
        ).get()
        if reply == "detach":
            # this gets injected by the worker event loop to
            # get the worker thread to exit on an Exit command.
            raise bdb.BdbQuit
        assert isinstance(reply, DebuggerWrite) and len(reply.payload) <= len(b)
        data = reply.payload
        count = len(data)
        b[:count] = data
        return count

    def readable(self) -> bool:
        """Report that this stream supports reading."""
        return True

    @classmethod
    def create(cls, session: "PdbWrapper"):
        """Return a buffered text stream layered over a new ``ReadWrapper``."""
        raw = cls(session)
        buffered = io.BufferedReader(raw)
        return io.TextIOWrapper(buffered)
101
+
102
+
103
class WriteWrapper:
    """File-like output sink that forwards debugger output to the client."""

    def __init__(self, session: "PdbWrapper"):
        self.session = session

    def writable(self) -> bool:
        """Report that this stream supports writing."""
        return True

    def write(self, s: str) -> int:
        """Send ``s`` to the debug client, tagged with the current location.

        When the session has a current frame, the message carries
        "<module>.<function>" and the frame's line number; otherwise both
        are None.

        Returns:
            The number of characters in ``s``, as file-like ``write``
            methods conventionally do.
        """
        function = None
        lineno = None
        # ``curframe`` may not exist yet when output is produced before pdb
        # has set up a frame (PdbWrapper.set_trace emits the header before
        # calling super().set_trace()), so look it up defensively.
        frame = getattr(self.session, "curframe", None)
        if frame is not None:
            # pyre-ignore
            function = f"{inspect.getmodulename(frame.f_code.co_filename)}.{frame.f_code.co_name}"
            # pyre-ignore
            lineno = frame.f_lineno
        self.session.client_ref.debugger_write.call_one(
            self.session.rank,
            DebuggerWrite(
                s.encode(),
                function,
                lineno,
            ),
        ).get()
        return len(s)

    def flush(self):
        """No-op: every write is forwarded immediately."""
        pass
129
+
130
+
131
def remote_breakpointhook(
    rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient"
):
    """Create a ``PdbWrapper`` for this rank and start tracing with it.

    No header is passed to the wrapper. The call to ``set_trace`` is made
    directly from this function so the call depth seen by the debugger is
    unchanged.
    """
    debugger = PdbWrapper(rank, coords, actor_id, client_ref)
    debugger.set_trace()
monarch/proc_mesh.py CHANGED
@@ -4,9 +4,11 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-strict
8
+
7
9
  import sys
8
10
 
9
- from typing import Any, cast, Optional, Type, TypeVar
11
+ from typing import Any, cast, List, Optional, Type, TypeVar
10
12
 
11
13
  import monarch
12
14
  from monarch import ActorFuture as Future
@@ -18,7 +20,7 @@ from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//mon
18
20
  )
19
21
  from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
20
22
  from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
21
- from monarch._rust_bindings.monarch_hyperactor.shape import Shape
23
+ from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
22
24
  from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
23
25
 
24
26
  from monarch.common._device_utils import _local_device_count
@@ -46,14 +48,16 @@ class ProcMesh(MeshTrait):
46
48
  def __init__(self, hy_proc_mesh: HyProcMesh) -> None:
47
49
  self._proc_mesh = hy_proc_mesh
48
50
  self._mailbox: Mailbox = self._proc_mesh.client
49
- self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)
51
+ self._rdma_manager: RDMAManager = self._spawn_blocking(
52
+ "rdma_manager", RDMAManager
53
+ )
50
54
 
51
55
  @property
52
- def _ndslice(self):
56
+ def _ndslice(self) -> Slice:
53
57
  return self._proc_mesh.shape.ndslice
54
58
 
55
59
  @property
56
- def _labels(self):
60
+ def _labels(self) -> List[str]:
57
61
  return self._proc_mesh.shape.labels
58
62
 
59
63
  def _new_with_shape(self, shape: Shape) -> "ProcMesh":