torchmonarch-nightly 2025.6.12__cp310-cp310-manylinux2014_x86_64.whl → 2025.6.14__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monarch/_rust_bindings.so CHANGED
Binary file
monarch/_testing.py CHANGED
@@ -10,7 +10,7 @@ import logging
 import tempfile
 import time
 from contextlib import contextmanager, ExitStack
-from typing import Callable, Generator, Optional
+from typing import Any, Callable, Dict, Generator, Literal, Optional
 
 import monarch_supervisor
 from monarch.common.client import Client
@@ -18,6 +18,8 @@ from monarch.common.device_mesh import DeviceMesh
 from monarch.common.invocation import DeviceException, RemoteException
 from monarch.common.shape import NDSlice
 from monarch.controller.backend import ProcessBackend
+from monarch.mesh_controller import spawn_tensor_engine
+from monarch.proc_mesh import proc_mesh, ProcMesh
 from monarch.python_local_mesh import PythonLocalContext
 from monarch.rust_local_mesh import (
     local_mesh,
@@ -50,6 +52,7 @@ class TestingContext:
         self.cleanup = ExitStack()
         self._py_process_cache = {}
         self._rust_process_cache = None
+        self._proc_mesh_cache: Dict[Any, ProcMesh] = {}
 
     @contextmanager
     def _get_context(self, num_hosts, gpu_per_host):
@@ -75,16 +78,14 @@ class TestingContext:
 
     @contextmanager
     def local_py_device_mesh(
-        self, num_hosts, gpu_per_host, activate=True
+        self,
+        num_hosts,
+        gpu_per_host,
     ) -> Generator[DeviceMesh, None, None]:
         ctx, hosts, processes = self._processes(num_hosts, gpu_per_host)
         dm = world_mesh(ctx, hosts, gpu_per_host, _processes=processes)
         try:
-            if activate:
-                with dm.activate():
-                    yield dm
-            else:
-                yield dm
+            yield dm
             dm.client.shutdown(destroy_pg=False)
         except Exception:
             # abnormal exit, so we just make sure we do not try to communicate in destructors,
@@ -97,7 +98,6 @@ class TestingContext:
         self,
         num_hosts,
         gpu_per_host,
-        activate: bool = True,
         controller_params=None,
     ) -> Generator[DeviceMesh, None, None]:
         # Create a new system and mesh for test.
@@ -115,11 +115,7 @@ class TestingContext:
             controller_params=controller_params,
         ) as dm:
             try:
-                if activate:
-                    with dm.activate():
-                        yield dm
-                else:
-                    yield dm
+                yield dm
                 dm.exit()
             except Exception:
                 dm.client._shutdown = True
@@ -129,21 +125,57 @@ class TestingContext:
             # pyre-ignore: Undefined attribute
             dm.client.inner._actor.stop()
 
+    @contextmanager
+    def local_engine_on_proc_mesh(
+        self,
+        num_hosts,
+        gpu_per_host,
+    ) -> Generator[DeviceMesh, None, None]:
+        key = (num_hosts, gpu_per_host)
+        if key not in self._proc_mesh_cache:
+            self._proc_mesh_cache[key] = proc_mesh(
+                hosts=num_hosts, gpus=gpu_per_host
+            ).get()
+
+        dm = spawn_tensor_engine(self._proc_mesh_cache[key])
+        dm = dm.rename(hosts="host", gpus="gpu")
+        try:
+            yield dm
+            dm.exit()
+        except Exception as e:
+            # abnormal exit, so we just make sure we do not try to communicate in destructors,
+            # but we do not wait for workers to exit since we do not know what state they are in.
+            dm.client._shutdown = True
+            raise
+
     @contextmanager
     def local_device_mesh(
-        self, num_hosts, gpu_per_host, activate=True, rust=False, controller_params=None
+        self,
+        num_hosts,
+        gpu_per_host,
+        activate=True,
+        backend: Literal["py", "rs", "mesh"] = "py",
+        controller_params=None,
     ) -> Generator[DeviceMesh, None, None]:
         start = time.time()
-        if rust:
+        if backend == "rs":
             generator = self.local_rust_device_mesh(
-                num_hosts, gpu_per_host, activate, controller_params=controller_params
+                num_hosts, gpu_per_host, controller_params=controller_params
            )
+        elif backend == "py":
+            generator = self.local_py_device_mesh(num_hosts, gpu_per_host)
+        elif backend == "mesh":
+            generator = self.local_engine_on_proc_mesh(num_hosts, gpu_per_host)
         else:
-            generator = self.local_py_device_mesh(num_hosts, gpu_per_host, activate)
+            raise ValueError(f"invalid backend: {backend}")
         with generator as dm:
             end = time.time()
             logging.info("initialized mesh in {:.2f}s".format(end - start))
-            yield dm
+            if activate:
+                with dm.activate():
+                    yield dm
+            else:
+                yield dm
             start = time.time()
             end = time.time()
             logging.info("shutdown mesh in {:.2f}s".format(end - start))
monarch/actor_mesh.py CHANGED
@@ -15,6 +15,7 @@ import inspect
 import itertools
 import logging
 import random
+import sys
 import traceback
 
 from dataclasses import dataclass
@@ -37,6 +38,7 @@ from typing import (
     ParamSpec,
     Tuple,
     Type,
+    TYPE_CHECKING,
     TypeVar,
 )
 
@@ -57,6 +59,10 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
 
 from monarch.common.pickle_flatten import flatten, unflatten
 from monarch.common.shape import MeshTrait, NDSlice
+from monarch.pdb_wrapper import remote_breakpointhook
+
+if TYPE_CHECKING:
+    from monarch.debugger import DebugClient
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -270,11 +276,11 @@ class Endpoint(Generic[P, R]):
         return self.choose(*args, **kwargs)
 
     def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
-        p: PortId
-        r: PortReceiver[R]
-        p, r = port(self)
+        p: Port[R]
+        r: RankedPortReceiver[R]
+        p, r = ranked_port(self)
         # pyre-ignore
-        send(self, args, kwargs, port=p, rank_in_response=True)
+        send(self, args, kwargs, port=p)
 
         async def process() -> ValueMesh[R]:
             results: List[R] = [None] * len(self._actor_mesh)  # pyre-fixme[9]
@@ -369,9 +375,8 @@ def send(
     endpoint: Endpoint[P, R],
     args: Tuple[Any, ...],
     kwargs: Dict[str, Any],
-    port: "Optional[PortId]" = None,
+    port: "Optional[Port]" = None,
     selection: Selection = "all",
-    rank_in_response: bool = False,
 ) -> None:
     """
     Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh.
@@ -380,7 +385,10 @@ def send(
     """
     endpoint._signature.bind(None, *args, **kwargs)
     message = PythonMessage(
-        endpoint._name, _pickle((args, kwargs)), port, rank_in_response
+        endpoint._name,
+        _pickle((args, kwargs)),
+        None if port is None else port._port,
+        None,
     )
     endpoint._actor_mesh.cast(message, selection)
 
@@ -402,18 +410,16 @@ def endpoint(
     return EndpointProperty(method)
 
 
-class Port:
-    def __init__(self, port: PortId, mailbox: Mailbox, rank_in_response: bool) -> None:
+class Port(Generic[R]):
+    def __init__(self, port: PortId, mailbox: Mailbox, rank: Optional[int]) -> None:
         self._port = port
         self._mailbox = mailbox
-        self._rank_in_response = rank_in_response
+        self._rank = rank
 
-    def send(self, method: str, obj: object) -> None:
-        if self._rank_in_response:
-            obj = (MonarchContext.get().point.rank, obj)
+    def send(self, method: str, obj: R) -> None:
         self._mailbox.post(
             self._port,
-            PythonMessage(method, _pickle(obj), None),
+            PythonMessage(method, _pickle(obj), None, self._rank),
         )
 
 
@@ -422,12 +428,21 @@ class Port:
 # and handles concerns is different.
 def port(
     endpoint: Endpoint[P, R], once: bool = False
-) -> Tuple["PortId", "PortReceiver[R]"]:
+) -> Tuple["Port[R]", "PortReceiver[R]"]:
     handle, receiver = (
         endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port()
     )
     port_id: PortId = handle.bind()
-    return port_id, PortReceiver(endpoint._mailbox, receiver)
+    return Port(port_id, endpoint._mailbox, rank=None), PortReceiver(
+        endpoint._mailbox, receiver
+    )
+
+
+def ranked_port(
+    endpoint: Endpoint[P, R], once: bool = False
+) -> Tuple["Port[R]", "RankedPortReceiver[R]"]:
+    p, receiver = port(endpoint, once)
+    return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver)
 
 
 class PortReceiver(Generic[R]):
@@ -452,18 +467,20 @@ class PortReceiver(Generic[R]):
             return payload
         else:
             assert msg.method == "exception"
-            if isinstance(payload, tuple):
-                # If the payload is a tuple, it's because we requested the rank
-                # to be included in the response; just ignore it.
-                raise payload[1]
-            else:
-                # pyre-ignore
-                raise payload
+            # pyre-ignore
+            raise payload
 
     def recv(self) -> "Future[R]":
         return Future(lambda: self._recv(), self._blocking_recv)
 
 
+class RankedPortReceiver(PortReceiver[Tuple[int, R]]):
+    def _process(self, msg: PythonMessage) -> Tuple[int, R]:
+        if msg.rank is None:
+            raise ValueError("RankedPort receiver got a message without a rank")
+        return msg.rank, super()._process(msg)
+
+
 singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[]))
 
 
@@ -487,7 +504,7 @@ class _Actor:
         panic_flag: PanicFlag,
     ) -> Optional[Coroutine[Any, Any, Any]]:
         port = (
-            Port(message.response_port, mailbox, message.rank_in_response)
+            Port(message.response_port, mailbox, rank)
             if message.response_port
             else None
         )
@@ -519,7 +536,14 @@ class _Actor:
                 enter_span(
                     the_method.__module__, message.method, str(ctx.mailbox.actor_id)
                 )
-                result = await the_method(self.instance, *args, **kwargs)
+                try:
+                    result = await the_method(self.instance, *args, **kwargs)
+                except Exception as e:
+                    logging.critical(
+                        "Unhandled exception in actor endpoint",
+                        exc_info=e,
+                    )
+                    raise e
                 exit_span()
                 return result
 
@@ -624,6 +648,19 @@ class Actor(MeshTrait):
         "actor implementations are not meshes, but we can't convince the typechecker of it..."
     )
 
+    @endpoint
+    async def _set_debug_client(self, client: "DebugClient") -> None:
+        point = MonarchContext.get().point
+        # For some reason, using a lambda instead of functools.partial
+        # confuses the pdb wrapper implementation.
+        sys.breakpointhook = functools.partial(  # pyre-ignore
+            remote_breakpointhook,
+            point.rank,
+            point.shape.coordinates(point.rank),
+            MonarchContext.get().mailbox.actor_id,
+            client,
+        )
+
 
 class ActorMeshRef(MeshTrait):
     def __init__(
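
The net effect of the `Port` refactor: the responding actor's rank no longer piggybacks on the payload tuple behind a `rank_in_response` flag; it travels in a dedicated slot on `PythonMessage`, and `RankedPortReceiver` decodes it. A rough sketch of the new reply path using the diff's own names; `my_endpoint` is a placeholder for any `Endpoint`, and a blocking `.get()` on `Future` is assumed:

```python
# Sketch of the ranked reply flow introduced above (illustrative).
p, r = ranked_port(my_endpoint)   # Port[R] plus RankedPortReceiver[R]
send(my_endpoint, args=(), kwargs={}, port=p)

# Each reply now arrives as a (rank, value) pair: the receiving _Actor
# constructs Port(message.response_port, mailbox, rank), so the rank is
# stamped into PythonMessage.rank rather than wrapped into the payload.
rank, value = r.recv().get()
```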
monarch/bootstrap_main.py CHANGED
@@ -30,28 +30,9 @@ def invoke_main():
     # behavior of std out as if it were a terminal.
     sys.stdout.reconfigure(line_buffering=True)
     global bootstrap_main
-    from monarch._rust_bindings.hyperactor_extension.telemetry import (  # @manual=//monarch/monarch_extension:monarch_extension
-        forward_to_tracing,
-    )
 
     # TODO: figure out what from worker_main.py we should reproduce here.
-
-    class TracingForwarder(logging.Handler):
-        def emit(self, record: logging.LogRecord) -> None:
-            try:
-                forward_to_tracing(
-                    record.getMessage(),
-                    record.filename or "",
-                    record.lineno or 0,
-                    record.levelno,
-                )
-            except AttributeError:
-                forward_to_tracing(
-                    record.__str__(),
-                    record.filename or "",
-                    record.lineno or 0,
-                    record.levelno,
-                )
+    from monarch.telemetry import TracingForwarder
 
     if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
         raise RuntimeError("Error during bootstrap for testing")
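
`invoke_main` now imports `TracingForwarder` from the new `monarch.telemetry` module instead of defining it inline. A minimal sketch of the wiring this enables, assuming the moved class keeps the old `logging.Handler` interface; the diff shows only the import, not where the handler is installed:

```python
import logging

from monarch.telemetry import TracingForwarder

# Forward Python log records into hyperactor's tracing, as the old
# inline handler did by calling forward_to_tracing(message, filename,
# lineno, levelno) for each record.
logging.getLogger().addHandler(TracingForwarder())
```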
@@ -16,11 +16,6 @@ def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
     torch.manual_seed(seed ^ process_idx)
 
 
-@remote(propagate=lambda: 0)
-def initial_seed_remote() -> int:
-    return torch.initial_seed()
-
-
 @remote(propagate=lambda: torch.zeros(1))
 def get_rng_state_remote() -> torch.Tensor:
     return torch.get_rng_state()
@@ -67,3 +62,7 @@ def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
 @remote(propagate="inspect")
 def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
     torch.cuda.set_rng_state_all(states)
+
+
+# initial_seed may sometimes return a uint64 which currently can't be unwrapped by the framework
+# def initial_seed_remote() -> int: ...
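
For context on the seeding scheme these RNG helpers keep: XOR-ing the base seed with the process index gives every worker a distinct but reproducible stream. A quick illustration (not code from the package):

```python
import torch

# Per-process seeding as in set_manual_seed_remote above.
base_seed = 1234
for process_idx in range(4):
    torch.manual_seed(base_seed ^ process_idx)
    print(process_idx, torch.rand(1).item())  # distinct stream per process

# torch.initial_seed() can return a value that only fits in a uint64,
# which is why the remote wrapper for it stays commented out above.
```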
monarch/common/client.py CHANGED
@@ -103,6 +103,13 @@ class Client:
         # workers.
         self.last_processed_seq = -1
 
+        # an error that we have received but know for certain has not
+        # been propagated to a future. This will be reported on shutdown
+        # to avoid hiding the error. This is best effort: we only keep
+        # the error until the point that a future is dependent on
+        # _any_ error, not necessarily the tracked one.
+        self._pending_shutdown_error = None
+
         self.recorder = Recorder()
 
         self.pending_results: Dict[
@@ -174,6 +181,8 @@ class Client:
         destroy_pg: bool = True,
         error_reason: Optional[RemoteException | DeviceException | Exception] = None,
     ) -> None:
+        if self.has_shutdown:
+            return
         logger.info("shutting down the client gracefully")
 
         atexit.unregister(self._atexit)
@@ -303,6 +312,7 @@ class Client:
 
         if error is not None:
             logging.info("Received error for seq %s: %s", seq, error)
+            self._pending_shutdown_error = error
             # We should not have set result if we have an error.
             assert result is None
             if not isinstance(error, RemoteException):
@@ -326,7 +336,11 @@ class Client:
 
         fut, _ = self.pending_results[seq]
         if fut is not None:
-            fut._set_result(result if error is None else error)
+            if error is None:
+                fut._set_result(result)
+            else:
+                fut._set_result(error)
+                self._pending_shutdown_error = None
         elif result is not None:
             logger.debug(f"{seq}: unused result {result}")
         elif error is not None:
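
Taken together, the client changes make `shutdown()` idempotent and add best-effort tracking of errors that never reached a future. A schematic sketch using only attributes visible in the diff:

```python
from monarch.common.client import Client

def shutdown_once(client: Client) -> None:
    # shutdown() is now idempotent: the new has_shutdown guard makes a
    # repeated call return immediately.
    client.shutdown()
    client.shutdown()  # no-op

# Error bookkeeping in _handle_pending_result (schematic):
#   on receiving an error:            client._pending_shutdown_error = error
#   once any future consumes one:     client._pending_shutdown_error = None
# Whatever is still held at shutdown can then be surfaced there instead
# of being silently dropped.
```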