torchmonarch-nightly 2025.6.6__cp310-cp310-manylinux2014_x86_64.whl → 2025.6.8__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monarch/_rust_bindings.so CHANGED
Binary file
monarch/actor_mesh.py CHANGED
@@ -39,7 +39,7 @@ from typing import (
39
39
  import monarch
40
40
  from monarch import ActorFuture as Future
41
41
 
42
- from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
42
+ from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage
43
43
  from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
44
44
  from monarch._rust_bindings.monarch_hyperactor.mailbox import (
45
45
  Mailbox,
@@ -264,9 +264,9 @@ class Endpoint(Generic[P, R]):
264
264
  return self.choose(*args, **kwargs)
265
265
 
266
266
  def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
267
- p, r = port(self, kind=RankedPort)
267
+ p, r = port(self)
268
268
  # pyre-ignore
269
- send(self, args, kwargs, port=p)
269
+ send(self, args, kwargs, port=p, rank_in_response=True)
270
270
 
271
271
  async def process():
272
272
  results = [None] * len(self._actor_mesh)
@@ -361,8 +361,9 @@ def send(
361
361
  endpoint: Endpoint[P, R],
362
362
  args: Tuple[Any, ...],
363
363
  kwargs: Dict[str, Any],
364
- port: "Optional[Port]" = None,
364
+ port: "Optional[PortId]" = None,
365
365
  selection: Selection = "all",
366
+ rank_in_response: bool = False,
366
367
  ) -> None:
367
368
  """
368
369
  Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh.
@@ -370,7 +371,9 @@ def send(
370
371
  This sends the message to all actors but does not wait for any result.
371
372
  """
372
373
  endpoint._signature.bind(None, *args, **kwargs)
373
- message = PythonMessage(endpoint._name, _pickle((args, kwargs, port)))
374
+ message = PythonMessage(
375
+ endpoint._name, _pickle((args, kwargs)), port, rank_in_response
376
+ )
374
377
  endpoint._actor_mesh.cast(message, selection)
375
378
 
376
379
 
@@ -392,28 +395,29 @@ def endpoint(
392
395
 
393
396
 
394
397
  class Port:
395
- def __init__(self, port: PortId, mailbox: Mailbox) -> None:
398
+ def __init__(self, port: PortId, mailbox: Mailbox, rank_in_response: bool) -> None:
396
399
  self._port = port
397
400
  self._mailbox = mailbox
401
+ self._rank_in_response = rank_in_response
398
402
 
399
403
  def send(self, method: str, obj: object) -> None:
404
+ if self._rank_in_response:
405
+ obj = (MonarchContext.get().point.rank, obj)
400
406
  self._mailbox.post(
401
407
  self._port,
402
- PythonMessage(method, _pickle(obj)),
408
+ PythonMessage(method, _pickle(obj), None),
403
409
  )
404
410
 
405
411
 
406
412
  # advance lower-level API for sending messages. This is intentially
407
413
  # not part of the Endpoint API because they way it accepts arguments
408
414
  # and handles concerns is different.
409
- def port(
410
- endpoint: Endpoint[P, R], once=False, kind=Port
411
- ) -> Tuple["Port", "PortReceiver[R]"]:
415
+ def port(endpoint: Endpoint[P, R], once=False) -> Tuple["PortId", "PortReceiver[R]"]:
412
416
  handle, receiver = (
413
417
  endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port()
414
418
  )
415
419
  port_id: PortId = handle.bind()
416
- return kind(port_id, endpoint._mailbox), PortReceiver(endpoint._mailbox, receiver)
420
+ return port_id, PortReceiver(endpoint._mailbox, receiver)
417
421
 
418
422
 
419
423
  class PortReceiver(Generic[R]):
@@ -439,8 +443,8 @@ class PortReceiver(Generic[R]):
439
443
  else:
440
444
  assert msg.method == "exception"
441
445
  if isinstance(payload, tuple):
442
- # If we're receiving on a RankedPort, raise the exception and ignore the rank.
443
- # pyre-ignore do something more structured here
446
+ # If the payload is a tuple, it's because we requested the rank
447
+ # to be included in the response; just ignore it.
444
448
  raise payload[1]
445
449
  else:
446
450
  # pyre-ignore
@@ -453,21 +457,16 @@ class PortReceiver(Generic[R]):
453
457
  singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[]))
454
458
 
455
459
 
456
- class RankedPort(Port):
457
- def send(self, method: str, obj: object) -> None:
458
- super().send(method, (MonarchContext.get().point.rank, obj))
459
-
460
-
461
460
  class _Actor:
462
461
  def __init__(self) -> None:
463
462
  self.instance: object | None = None
464
463
  self.active_requests: asyncio.Queue[asyncio.Future[object]] = asyncio.Queue()
465
- self.complete_task: object | None = None
464
+ self.complete_task: asyncio.Task | None = None
466
465
 
467
466
  def handle(
468
- self, mailbox: Mailbox, message: PythonMessage
467
+ self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
469
468
  ) -> Optional[Coroutine[Any, Any, Any]]:
470
- return self.handle_cast(mailbox, 0, singleton_shape, message)
469
+ return self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag)
471
470
 
472
471
  def handle_cast(
473
472
  self,
@@ -475,14 +474,18 @@ class _Actor:
475
474
  rank: int,
476
475
  shape: Shape,
477
476
  message: PythonMessage,
477
+ panic_flag: PanicFlag,
478
478
  ) -> Optional[Coroutine[Any, Any, Any]]:
479
- port = None
479
+ port = (
480
+ Port(message.response_port, mailbox, message.rank_in_response)
481
+ if message.response_port
482
+ else None
483
+ )
480
484
  try:
481
- args, kwargs, port = _unpickle(message.message, mailbox)
482
-
483
485
  ctx = MonarchContext(mailbox, mailbox.actor_id.proc_id, Point(rank, shape))
484
486
  _context.set(ctx)
485
487
 
488
+ args, kwargs = _unpickle(message.message, mailbox)
486
489
  if message.method == "__init__":
487
490
  Class, *args = args
488
491
  self.instance = Class(*args, **kwargs)
@@ -495,10 +498,10 @@ class _Actor:
495
498
  port.send("result", result)
496
499
  return None
497
500
 
498
- return self.run_async(ctx, self.run_task(port, result))
501
+ return self.run_async(ctx, self.run_task(port, result, panic_flag))
499
502
  except Exception as e:
500
503
  traceback.print_exc()
501
- s = ActorMeshRefCallFailedException(e)
504
+ s = ActorError(e)
502
505
 
503
506
  # The exception is delivered to exactly one of:
504
507
  # (1) our caller, (2) our supervisor
@@ -510,17 +513,17 @@ class _Actor:
510
513
  async def run_async(self, ctx, coroutine):
511
514
  _context.set(ctx)
512
515
  if self.complete_task is None:
513
- asyncio.create_task(self._complete())
516
+ self.complete_task = asyncio.create_task(self._complete())
514
517
  await self.active_requests.put(create_eager_task(coroutine))
515
518
 
516
- async def run_task(self, port, coroutine):
519
+ async def run_task(self, port, coroutine, panic_flag):
517
520
  try:
518
521
  result = await coroutine
519
522
  if port is not None:
520
523
  port.send("result", result)
521
524
  except Exception as e:
522
525
  traceback.print_exc()
523
- s = ActorMeshRefCallFailedException(e)
526
+ s = ActorError(e)
524
527
 
525
528
  # The exception is delivered to exactly one of:
526
529
  # (1) our caller, (2) our supervisor
@@ -528,6 +531,16 @@ class _Actor:
528
531
  port.send("exception", s)
529
532
  else:
530
533
  raise s from None
534
+ except BaseException as e:
535
+ # A BaseException can be thrown in the case of a Rust panic.
536
+ # In this case, we need a way to signal the panic to the Rust side.
537
+ # See [Panics in async endpoints]
538
+ try:
539
+ panic_flag.signal_panic(e)
540
+ except Exception:
541
+ # The channel might be closed if the Rust side has already detected the error
542
+ pass
543
+ raise
531
544
 
532
545
  async def _complete(self) -> None:
533
546
  while True:
@@ -627,7 +640,6 @@ class ActorMeshRef(MeshTrait):
627
640
  null_func,
628
641
  self._mailbox,
629
642
  )
630
- # pyre-ignore
631
643
  send(ep, (self._class, *args), kwargs)
632
644
 
633
645
  def __reduce_ex__(
@@ -655,7 +667,7 @@ class ActorMeshRef(MeshTrait):
655
667
  )
656
668
 
657
669
 
658
- class ActorMeshRefCallFailedException(Exception):
670
+ class ActorError(Exception):
659
671
  """
660
672
  Deterministic problem with the user's code.
661
673
  For example, an OOM resulting in trying to allocate too much GPU memory, or violating
monarch/bootstrap_main.py CHANGED
@@ -53,6 +53,9 @@ def invoke_main():
53
53
  record.levelno,
54
54
  )
55
55
 
56
+ if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
57
+ raise RuntimeError("Error during bootstrap for testing")
58
+
56
59
  # forward logs to rust tracing. Defaults to on.
57
60
  if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1":
58
61
  logging.root.addHandler(TracingForwarder())
@@ -97,7 +97,7 @@ class RemoteProcessGroup(Referenceable):
97
97
  self._drop_ref()
98
98
 
99
99
  def size(self):
100
- return self.device_mesh.numdevices(self.dims)
100
+ return self.device_mesh.size(self.dims)
101
101
 
102
102
  def _drop_ref(self):
103
103
  if self.ref is None:
@@ -172,21 +172,6 @@ class DeviceMesh(Referenceable, MeshTrait):
172
172
  self.ref = None
173
173
  self._active_mesh_context = None
174
174
 
175
- def numdevices(self, dims: Optional[Dims] = None) -> int:
176
- """
177
- Returns the number of devices (total) of the subset of mesh asked for.
178
- If dims is None, returns the total number of devices in the mesh.
179
- """
180
- if dims is None:
181
- dims = self.names
182
- missing_dims = set(dims) - set(self.names)
183
- if missing_dims:
184
- raise ValueError(f"Dimensions not found: {', '.join(missing_dims)}")
185
- product = 1
186
- for dim in dims:
187
- product *= self.size(dim)
188
- return product
189
-
190
175
  def define_remotely(self):
191
176
  if self.ref is None:
192
177
  self.ref = self.client.new_ref()
@@ -275,25 +260,10 @@ class DeviceMesh(Referenceable, MeshTrait):
275
260
  combined_rank += self.rank(dim)
276
261
  return combined_rank
277
262
 
278
- def size(self, dim: Union[str, Sequence[str]]) -> int:
279
- if isinstance(dim, str):
280
- if dim not in self.names:
281
- raise KeyError(f"{self} does not have dimension {repr(dim)}")
282
- return self.processes.sizes[self.names.index(dim)]
283
- else:
284
- p = 1
285
- for d in dim:
286
- p *= self.size(d)
287
- return p
288
-
289
263
  @property
290
264
  def ranks(self) -> dict[str, int]:
291
265
  return {dim: self.rank(dim) for dim in self.names}
292
266
 
293
- @property
294
- def sizes(self) -> dict[str, int]:
295
- return dict(zip(self.names, self.processes.sizes))
296
-
297
267
  def process_idx(self):
298
268
  self.define_remotely()
299
269
  return _remote(
monarch/common/shape.py CHANGED
@@ -8,7 +8,7 @@ import itertools
8
8
  import operator
9
9
  from abc import ABC, abstractmethod
10
10
 
11
- from typing import Dict, Generator, Sequence, Tuple
11
+ from typing import Dict, Generator, Sequence, Tuple, Union
12
12
 
13
13
  from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
14
14
 
@@ -200,5 +200,27 @@ class MeshTrait(ABC):
200
200
  """
201
201
  return self.split(**{k: (v,) for k, v in kwargs.items()})
202
202
 
203
+ def size(self, dim: Union[None, str, Sequence[str]] = None) -> int:
204
+ """
205
+ Returns the number of elements (total) of the subset of mesh asked for.
206
+ If dims is None, returns the total number of devices in the mesh.
207
+ """
208
+
209
+ if dim is None:
210
+ dim = self._labels
211
+ if isinstance(dim, str):
212
+ if dim not in self._labels:
213
+ raise KeyError(f"{self} does not have dimension {repr(dim)}")
214
+ return self._ndslice.sizes[self._labels.index(dim)]
215
+ else:
216
+ p = 1
217
+ for d in dim:
218
+ p *= self.size(d)
219
+ return p
220
+
221
+ @property
222
+ def sizes(self) -> dict[str, int]:
223
+ return dict(zip(self._labels, self._ndslice.sizes))
224
+
203
225
 
204
226
  __all__ = ["NDSlice", "Shape", "MeshTrait"]
Binary file
monarch/proc_mesh.py CHANGED
@@ -18,9 +18,11 @@ from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//mon
18
18
  )
19
19
  from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
20
20
  from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh as HyProcMesh
21
+ from monarch._rust_bindings.monarch_hyperactor.shape import Shape
21
22
  from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef
22
23
 
23
24
  from monarch.common._device_utils import _local_device_count
25
+ from monarch.common.shape import MeshTrait
24
26
  from monarch.rdma import RDMAManager
25
27
 
26
28
  T = TypeVar("T")
@@ -40,12 +42,23 @@ def _allocate_blocking(alloc: Alloc) -> "ProcMesh":
40
42
  return ProcMesh(HyProcMesh.allocate_blocking(alloc))
41
43
 
42
44
 
43
- class ProcMesh:
45
+ class ProcMesh(MeshTrait):
44
46
  def __init__(self, hy_proc_mesh: HyProcMesh) -> None:
45
47
  self._proc_mesh = hy_proc_mesh
46
48
  self._mailbox: Mailbox = self._proc_mesh.client
47
49
  self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)
48
50
 
51
+ @property
52
+ def _ndslice(self):
53
+ return self._proc_mesh.shape.ndslice
54
+
55
+ @property
56
+ def _labels(self):
57
+ return self._proc_mesh.shape.labels
58
+
59
+ def _new_with_shape(self, shape: Shape) -> "ProcMesh":
60
+ raise NotImplementedError("ProcMesh slicing is not implemeted yet.")
61
+
49
62
  def spawn(self, name: str, Class: Type[T], *args: Any, **kwargs: Any) -> Future[T]:
50
63
  return Future(
51
64
  lambda: self._spawn_nonblocking(name, Class, *args, **kwargs),
monarch/rdma.py CHANGED
@@ -6,10 +6,7 @@
6
6
 
7
7
  import ctypes
8
8
 
9
- import traceback
10
-
11
9
  from dataclasses import dataclass
12
- from traceback import extract_tb, StackSummary
13
10
  from typing import cast, Dict, Optional, Tuple
14
11
 
15
12
  import torch
@@ -163,28 +160,3 @@ class RDMABuffer:
163
160
  src.numel(),
164
161
  )
165
162
  await RDMAManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes)
166
-
167
-
168
- class ActorMeshRefCallFailedException(Exception):
169
- """
170
- Deterministic problem with the user's code.
171
- For example, an OOM resulting in trying to allocate too much GPU memory, or violating
172
- some invariant enforced by the various APIs.
173
- """
174
-
175
- def __init__(
176
- self,
177
- exception: Exception,
178
- message: str = "A remote service call has failed asynchronously.",
179
- ) -> None:
180
- self.exception = exception
181
- self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__)
182
- self.message = message
183
-
184
- def __str__(self) -> str:
185
- exe = str(self.exception)
186
- actor_mesh_ref_tb = "".join(traceback.format_list(self.actor_mesh_ref_frames))
187
- return (
188
- f"{self.message}\n"
189
- f"Traceback of where the service call failed (most recent call last):\n{actor_mesh_ref_tb}{type(self.exception).__name__}: {exe}"
190
- )
@@ -7,6 +7,8 @@
7
7
  import ctypes
8
8
  import sys
9
9
 
10
+ import click
11
+
10
12
  from monarch._rust_bindings.monarch_extension.panic import panicking_function
11
13
 
12
14
  from monarch.actor_mesh import Actor, endpoint
@@ -115,24 +117,33 @@ def _run_error_test(num_procs, sync_endpoint, endpoint_name):
115
117
  asyncio.run(run_test())
116
118
 
117
119
 
120
+ @click.group()
118
121
  def main():
119
- import argparse
122
+ pass
120
123
 
121
- parser = argparse.ArgumentParser()
122
- parser.add_argument("--num-procs", type=int)
123
- parser.add_argument("--sync-test-impl", type=bool)
124
- parser.add_argument("--sync-endpoint", type=bool)
125
- parser.add_argument("--endpoint-name", type=str)
126
- args = parser.parse_args()
127
124
 
125
+ @main.command("error-endpoint")
126
+ @click.option("--num-procs", type=int, required=True)
127
+ @click.option("--sync-test-impl", type=bool, required=True)
128
+ @click.option("--sync-endpoint", type=bool, required=True)
129
+ @click.option("--endpoint-name", type=str, required=True)
130
+ def error_endpoint(num_procs, sync_test_impl, sync_endpoint, endpoint_name):
128
131
  print(
129
- f"Running segfault test: {args.num_procs=} {args.sync_test_impl=} {args.sync_endpoint=}, {args.endpoint_name=}"
132
+ f"Running segfault test: {num_procs=} {sync_test_impl=} {sync_endpoint=}, {endpoint_name=}"
130
133
  )
131
134
 
132
- if args.sync_test_impl:
133
- _run_error_test_sync(args.num_procs, args.sync_endpoint, args.endpoint_name)
135
+ if sync_test_impl:
136
+ _run_error_test_sync(num_procs, sync_endpoint, endpoint_name)
134
137
  else:
135
- _run_error_test(args.num_procs, args.sync_endpoint, args.endpoint_name)
138
+ _run_error_test(num_procs, sync_endpoint, endpoint_name)
139
+
140
+
141
+ @main.command("error-bootstrap")
142
+ def error_bootstrap():
143
+ print("I actually ran")
144
+ sys.stdout.flush()
145
+
146
+ proc_mesh(gpus=4, env={"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}).get()
136
147
 
137
148
 
138
149
  if __name__ == "__main__":
tests/test_actor_error.py CHANGED
@@ -8,47 +8,76 @@ import importlib.resources
8
8
  import subprocess
9
9
 
10
10
  import pytest
11
- from monarch.actor_mesh import Actor, ActorMeshRefCallFailedException, endpoint
11
+ from monarch.actor_mesh import Actor, ActorError, endpoint
12
12
 
13
13
  from monarch.proc_mesh import proc_mesh
14
14
 
15
15
 
16
16
  class ExceptionActor(Actor):
17
- """An actor that has endpoints which raise exceptions."""
18
-
19
17
  @endpoint
20
18
  async def raise_exception(self) -> None:
21
- """Endpoint that raises an exception."""
22
19
  raise Exception("This is a test exception")
23
20
 
21
+ @endpoint
22
+ async def print_value(self, value) -> None:
23
+ """Endpoint that takes a value and prints it."""
24
+ print(f"Value received: {value}")
25
+ return value
24
26
 
25
- class ExceptionActorSync(Actor):
26
- """An actor that has endpoints which raise exceptions."""
27
27
 
28
+ class ExceptionActorSync(Actor):
28
29
  @endpoint # pyre-ignore
29
30
  def raise_exception(self) -> None:
30
- """Endpoint that raises an exception."""
31
31
  raise Exception("This is a test exception")
32
32
 
33
33
 
34
+ class BrokenPickleClass:
35
+ """A class that can be configured to raise exceptions during pickling/unpickling."""
36
+
37
+ def __init__(
38
+ self,
39
+ raise_on_getstate=False,
40
+ raise_on_setstate=False,
41
+ exception_message="Pickle error",
42
+ ):
43
+ self.raise_on_getstate = raise_on_getstate
44
+ self.raise_on_setstate = raise_on_setstate
45
+ self.exception_message = exception_message
46
+ self.value = "test_value"
47
+
48
+ def __getstate__(self):
49
+ """Called when pickling the object."""
50
+ if self.raise_on_getstate:
51
+ raise RuntimeError(f"__getstate__ error: {self.exception_message}")
52
+ return {
53
+ "raise_on_getstate": self.raise_on_getstate,
54
+ "raise_on_setstate": self.raise_on_setstate,
55
+ "exception_message": self.exception_message,
56
+ "value": self.value,
57
+ }
58
+
59
+ def __setstate__(self, state):
60
+ """Called when unpickling the object."""
61
+ if state.get("raise_on_setstate", False):
62
+ raise RuntimeError(
63
+ f"__setstate__ error: {state.get('exception_message', 'Unpickle error')}"
64
+ )
65
+ self.__dict__.update(state)
66
+
67
+
34
68
  @pytest.mark.parametrize(
35
- "actor_class,actor_name",
36
- [
37
- (ExceptionActor, "exception_actor_async_call"),
38
- (ExceptionActorSync, "exception_actor_sync_call"),
39
- ],
69
+ "actor_class",
70
+ [ExceptionActor, ExceptionActorSync],
40
71
  )
41
72
  @pytest.mark.parametrize("num_procs", [1, 2])
42
- async def test_actor_exception(actor_class, actor_name, num_procs):
73
+ async def test_actor_exception(actor_class, num_procs):
43
74
  """
44
75
  Test that exceptions raised in actor endpoints are propagated to the client.
45
76
  """
46
77
  proc = await proc_mesh(gpus=num_procs)
47
- exception_actor = await proc.spawn(actor_name, actor_class)
78
+ exception_actor = await proc.spawn("exception_actor", actor_class)
48
79
 
49
- with pytest.raises(
50
- ActorMeshRefCallFailedException, match="This is a test exception"
51
- ):
80
+ with pytest.raises(ActorError, match="This is a test exception"):
52
81
  if num_procs == 1:
53
82
  await exception_actor.raise_exception.call_one()
54
83
  else:
@@ -56,23 +85,18 @@ async def test_actor_exception(actor_class, actor_name, num_procs):
56
85
 
57
86
 
58
87
  @pytest.mark.parametrize(
59
- "actor_class,actor_name",
60
- [
61
- (ExceptionActor, "exception_actor_async_call"),
62
- (ExceptionActorSync, "exception_actor_sync_call"),
63
- ],
88
+ "actor_class",
89
+ [ExceptionActor, ExceptionActorSync],
64
90
  )
65
91
  @pytest.mark.parametrize("num_procs", [1, 2])
66
- def test_actor_exception_sync(actor_class, actor_name, num_procs):
92
+ def test_actor_exception_sync(actor_class, num_procs):
67
93
  """
68
94
  Test that exceptions raised in actor endpoints are propagated to the client.
69
95
  """
70
96
  proc = proc_mesh(gpus=num_procs).get()
71
- exception_actor = proc.spawn(actor_name, actor_class).get()
97
+ exception_actor = proc.spawn("exception_actor", actor_class).get()
72
98
 
73
- with pytest.raises(
74
- ActorMeshRefCallFailedException, match="This is a test exception"
75
- ):
99
+ with pytest.raises(ActorError, match="This is a test exception"):
76
100
  if num_procs == 1:
77
101
  exception_actor.raise_exception.call_one().get()
78
102
  else:
@@ -85,28 +109,102 @@ def test_actor_exception_sync(actor_class, actor_name, num_procs):
85
109
  @pytest.mark.parametrize("sync_endpoint", [False, True])
86
110
  @pytest.mark.parametrize("sync_test_impl", [False, True])
87
111
  @pytest.mark.parametrize("endpoint_name", ["cause_segfault", "cause_panic"])
88
- def test_actor_segfault(num_procs, sync_endpoint, sync_test_impl, endpoint_name):
112
+ def test_actor_supervision(num_procs, sync_endpoint, sync_test_impl, endpoint_name):
89
113
  """
90
- Test that segfaults in actor endpoints result in a non-zero exit code.
91
- This test spawns a subprocess that will segfault and checks its exit code.
114
+ Test that an endpoint causing spontaenous process exit is handled by the supervisor.
92
115
 
93
- Tests both ExceptionActor and ExceptionActorSync using async API.
116
+ Today, these events are delivered to the client and cause the client process
117
+ to exit with a non-zero code, so the only way we can test it is via a
118
+ subprocess harness.
94
119
  """
95
120
  # Run the segfault test in a subprocess
96
121
  test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
97
122
  cmd = [
98
123
  str(test_bin),
124
+ "error-endpoint",
99
125
  f"--num-procs={num_procs}",
100
126
  f"--sync-endpoint={sync_endpoint}",
101
127
  f"--sync-test-impl={sync_test_impl}",
102
128
  f"--endpoint-name={endpoint_name}",
103
129
  ]
104
- process = subprocess.run(cmd, capture_output=True, timeout=60)
105
- print(process.stdout.decode())
106
- print(process.stderr.decode())
130
+ try:
131
+ process = subprocess.run(cmd, capture_output=True, timeout=180)
132
+ except subprocess.TimeoutExpired as e:
133
+ print("timeout expired")
134
+ if e.stdout is not None:
135
+ print(e.stdout.decode())
136
+ if e.stderr is not None:
137
+ print(e.stderr.decode())
138
+ raise
139
+
140
+ # Assert that the subprocess exited with a non-zero code
141
+ assert "I actually ran" in process.stdout.decode()
142
+ assert (
143
+ process.returncode != 0
144
+ ), f"Expected non-zero exit code, got {process.returncode}"
145
+
146
+
147
+ # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
148
+ @pytest.mark.oss_skip
149
+ def test_proc_mesh_bootstrap_error():
150
+ """
151
+ Test that attempts to spawn a ProcMesh with a failure during bootstrap.
152
+ """
153
+ # Run the segfault test in a subprocess
154
+ test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
155
+ cmd = [
156
+ str(test_bin),
157
+ "error-bootstrap",
158
+ ]
159
+ try:
160
+ process = subprocess.run(cmd, capture_output=True, timeout=180)
161
+ except subprocess.TimeoutExpired as e:
162
+ print("timeout expired")
163
+ if e.stdout is not None:
164
+ print(e.stdout.decode())
165
+ if e.stderr is not None:
166
+ print(e.stderr.decode())
167
+ raise
107
168
 
108
169
  # Assert that the subprocess exited with a non-zero code
109
170
  assert "I actually ran" in process.stdout.decode()
110
171
  assert (
111
172
  process.returncode != 0
112
173
  ), f"Expected non-zero exit code, got {process.returncode}"
174
+
175
+
176
+ @pytest.mark.parametrize("raise_on_getstate", [True, False])
177
+ @pytest.mark.parametrize("raise_on_setstate", [True, False])
178
+ @pytest.mark.parametrize("num_procs", [1, 2])
179
+ async def test_broken_pickle_class(raise_on_getstate, raise_on_setstate, num_procs):
180
+ """
181
+ Test that exceptions during pickling/unpickling are properly handled.
182
+
183
+ This test creates a BrokenPickleClass instance configured to raise exceptions
184
+ during __getstate__ and/or __setstate__, then passes it to an ExceptionActor's
185
+ print_value endpoint and verifies that an ActorError is raised.
186
+ """
187
+ if not raise_on_getstate and not raise_on_setstate:
188
+ # Pass this test trivially
189
+ return
190
+
191
+ proc = await proc_mesh(gpus=num_procs)
192
+ exception_actor = await proc.spawn("exception_actor", ExceptionActor)
193
+
194
+ # Create a BrokenPickleClass instance configured to raise exceptions
195
+ broken_obj = BrokenPickleClass(
196
+ raise_on_getstate=raise_on_getstate,
197
+ raise_on_setstate=raise_on_setstate,
198
+ exception_message="Test pickle error",
199
+ )
200
+
201
+ # On the getstate path, we expect a RuntimeError to be raised locally.
202
+ # On the setstate path, we expect an ActorError to be raised remotely.
203
+ error_type = RuntimeError if raise_on_getstate else ActorError
204
+ error_pattern = "__getstate__ error" if raise_on_getstate else "__setstate__ error"
205
+
206
+ with pytest.raises(error_type, match=error_pattern):
207
+ if num_procs == 1:
208
+ await exception_actor.print_value.call_one(broken_obj)
209
+ else:
210
+ await exception_actor.print_value.call(broken_obj)
@@ -320,6 +320,11 @@ def test_sync_actor_sync_client():
320
320
  assert r == 5
321
321
 
322
322
 
323
+ def test_proc_mesh_size() -> None:
324
+ proc = local_proc_mesh(gpus=2).get()
325
+ assert 2 == proc.size("gpus")
326
+
327
+
323
328
  def test_rank_size_sync() -> None:
324
329
  proc = local_proc_mesh(gpus=2).get()
325
330
  r = proc.spawn("runit", RunIt).get()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: torchmonarch-nightly
3
- Version: 2025.6.6
3
+ Version: 2025.6.8
4
4
  Summary: Monarch: Single controller library
5
5
  Author: Meta
6
6
  Author-email: oncall+monarch@xmail.facebook.com
@@ -1,23 +1,23 @@
1
1
  monarch/__init__.py,sha256=iUvWHc0-7Q2tovRoRxOIiA3TsefMXCbWl-jEfQ2djew,6897
2
- monarch/_rust_bindings.so,sha256=b3W_9ICPY2P_yDKwRY6liTdhGGhi2euITQwhj0-4hZk,39102560
2
+ monarch/_rust_bindings.so,sha256=HiisXwHtZrYKATL6RdJxw2u_y7Wjgjtwt52V1LIR6ss,39151608
3
3
  monarch/_testing.py,sha256=MN8DK1e-wzV0-R_nFW1b_7-O5oKfWvZ12BMGD4Z7PQk,6755
4
- monarch/actor_mesh.py,sha256=KfvHST97c35kyL7FslAfkRxo22UN33u3_gk2yrbZu4A,22320
4
+ monarch/actor_mesh.py,sha256=5DbU9OrmNk5I9yasmE-rkTgHyO07oiLlAG0jbJBOXgI,23000
5
5
  monarch/allocator.py,sha256=_2DKFP9pSD33zDgH7xZJC8Tq7BQrCeQEUmMB7_xCT0Y,1784
6
- monarch/bootstrap_main.py,sha256=_LgEvfI_kFHj2QWH8CLRBQI1tbxS0uWrnHqwzOVbjeI,2417
6
+ monarch/bootstrap_main.py,sha256=SYTOz-pTXiJNk78PPD5HAOJDSb8t2JfitRWdmWB3ogo,2559
7
7
  monarch/cached_remote_function.py,sha256=kYdB6r4OHx_T_uX4q3tCNcp1t2DJwF8tPTIahUiT2pU,8785
8
8
  monarch/fetch.py,sha256=61jxo7sx4QNUTkc0_rF5NaJROen4tKbAaiIjrXWLOvg,1705
9
9
  monarch/future.py,sha256=lcdFEe7m1shYPPuvZ1RkS6JUIChEKGBWe3v7x_nu4Hg,731
10
10
  monarch/gradient_generator.py,sha256=Rl3dmXGceTdCc1mYBg2JciR88ywGPnW7TVkL86KwqEA,6366
11
11
  monarch/memory.py,sha256=ol86dBhFAJqg78iF25-BuK0wuwj1onR8FIioZ_B0gjw,1377
12
- monarch/monarch_controller,sha256=LlUJ69r9spjCivWYdew83cFM4YszUx5Djqp9k_xxppo,20386688
12
+ monarch/monarch_controller,sha256=5TKjcz7U7K8OttrwYv-w7yYtPUm2aMOQV4gt0u_Vj5c,20385960
13
13
  monarch/notebook.py,sha256=zu9MKDFKf1-rCM2TqFSRJjMBeiWuKcJSyUFLvoZRQzs,25949
14
14
  monarch/opaque_module.py,sha256=oajOu_WD1hD4hxE8HDdO-tvWY7KDHWd7VaAhJEa5L2I,10446
15
15
  monarch/opaque_object.py,sha256=IVpll4pyuKZMo_EnPh4s0qnx8RlAcJrJ1yoLX6E75wQ,2782
16
- monarch/proc_mesh.py,sha256=sTMmwQLKqM0h-yY0mn8uSzOb9B_MX9DKWCI9EsyfD6s,6384
16
+ monarch/proc_mesh.py,sha256=pVN0BLnjGaty6-UGn1U81rNdmfiDvD4gO1c4bISHtqs,6807
17
17
  monarch/profiler.py,sha256=TQ9fnVM8H7smBWtYdB_6Irtzz8DBOmcp7U1T3wlUmco,4911
18
18
  monarch/python_local_mesh.py,sha256=YsureIzR9uGlNVrKd4vRghxOXBeYabkt9lICRErfRAI,3536
19
19
  monarch/random.py,sha256=f9QR7Esu4Vxqxs-KCf5QYyVqlWvXJ3-UtG90L_h4j40,1527
20
- monarch/rdma.py,sha256=eWwYKurW-Y6j68m0xH8jeyE3bfmSgB5ZwM2j-RmbCHc,6397
20
+ monarch/rdma.py,sha256=1pNh11S_FWeETRgkdUpauTMUlodrRohIq1UfQjKVnN8,5418
21
21
  monarch/remote_class.py,sha256=-OAowzU1aDP6i4ik_SjXntVUC9h4dqAzgqwohkQ6Grc,4167
22
22
  monarch/rust_backend_mesh.py,sha256=1htC62of4MgFtkezWGlsxSFtKJdc0CIeqeSuOx7yu3M,9944
23
23
  monarch/rust_local_mesh.py,sha256=7ASptybn3wy4J7eoBc7LhGW4j4AA6bigl5Kuhyflw8s,47405
@@ -46,7 +46,7 @@ monarch/common/client.py,sha256=wOAnoaLmabrcv7mK_z_HVnk_ivGe5igPy3iWZI4LVZc,2451
46
46
  monarch/common/constants.py,sha256=ohvsVYMpfeWopv3KXDAeHWDFLukwc-OY37VRxpKNBE8,300
47
47
  monarch/common/context_manager.py,sha256=GOeyaFbyCqvQmkJ0oI7q6IxRd8_0mVyYKZRccI8iaug,1067
48
48
  monarch/common/controller_api.py,sha256=djGkK5aSd-V6pBkr3uBCXbfJv3OKf2o2VbBXJgFF2WI,3202
49
- monarch/common/device_mesh.py,sha256=PyVONLa0EDOzVobU-PK-mGAQyj1Dyo9dr__lDmx2uKY,13144
49
+ monarch/common/device_mesh.py,sha256=fBZMYDpfAp5tAEXTe9l6eJxDI4-TMWVOMrAJXp5hzvI,12082
50
50
  monarch/common/fake.py,sha256=h57Cggz2qXNqImZ7yPuOZOSe9-l9i553ki1z-YHlgQA,1801
51
51
  monarch/common/function.py,sha256=V8kdgSRTvild2SpcewWa5IETX3QiWDZQ2BEIDFa5zz8,4374
52
52
  monarch/common/function_caching.py,sha256=HVdbWtv6Eea7ENMWi8iv36w1G1TaVuUJhkUX_JxGx5A,5060
@@ -63,7 +63,7 @@ monarch/common/recording.py,sha256=hoI9VY_FyW_xVx-jmfsKydqX5vW2GulwcDWsBdUVOm8,4
63
63
  monarch/common/reference.py,sha256=O26lkzEeVwj0S1xEy-OLqdHVnACmmlbQCUmXRrW4n1Q,938
64
64
  monarch/common/remote.py,sha256=qZWXkShX20l07TseQSpVECh2yXZaVKYUvQXkeEM-zvY,9220
65
65
  monarch/common/selection.py,sha256=lpWFbZs3ArYy29e-53eoAVAjQFksf1RvZz9NvM0CUW4,308
66
- monarch/common/shape.py,sha256=jEHneh190QI7zGOVAARpXtkxI9mXV1YbnycXlpYQGuc,7388
66
+ monarch/common/shape.py,sha256=k6-0S0U19PmrfP62SMb9Ihx6_I4QQFUGErloZn8GcZ0,8144
67
67
  monarch/common/stream.py,sha256=J9UCqhSXSbKYFGtbKaqAq1Vgmg6DJcLzsXXm-tsBQ-w,3499
68
68
  monarch/common/tensor.py,sha256=mSXiHoD0Up4m2RLdQcsbesaz2N4QCFS34UNNX3Dbldk,28842
69
69
  monarch/common/tensor_factory.py,sha256=qm8NZx-5ezMAFjNLiXQvb66okm5XgdboB_GRarGOdN0,801
@@ -127,9 +127,9 @@ monarch_supervisor/python_executable.py,sha256=WfCiK3wdAvm9Jxx5jgjGF991NgGc9-oHU
127
127
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
128
128
  tests/dispatch_bench.py,sha256=sU_m-8KAjQgYTsxI5khV664NdgLLutidni69Rtowk98,3933
129
129
  tests/dispatch_bench_helper.py,sha256=1ORgAMrRgjAjmmWeCHLLQd_bda9mJk0rS2ucEbRu28s,633
130
- tests/error_test_binary.py,sha256=r9-mm4eDqaJYnBo3gXcuqwhpYq1HeH6xem3a4p8rakI,4600
130
+ tests/error_test_binary.py,sha256=64H-ucdkQ2i7GD8sidStl227cOy7gyeqvO4kTm1y7Ic,4817
131
131
  tests/sleep_binary.py,sha256=XfLYaAfwm9xgzM-svs8fhAeFhwYIg6SyVEnx4e6wbUw,1009
132
- tests/test_actor_error.py,sha256=YBDS6BKwZqgKTFtydEJt4qwJGXRfWx3hgxup9ayVbhY,3827
132
+ tests/test_actor_error.py,sha256=z3Sf4lteUggTryPLOhRKJ55v0MwVK3a7QN7-U2U9iJg,7484
133
133
  tests/test_alloc.py,sha256=D6DdQbtOZEvvnnc7LV-WyWFMk0Xb77eblH6Oz90zJTA,745
134
134
  tests/test_coalescing.py,sha256=-KtAWzTaeXbyzltplfojavx0iFeeZnvej-tFTlu2p5k,15616
135
135
  tests/test_controller.py,sha256=yxuVp2DG3TDKJlwuE3cFm9dbWMlbrYtG1uHfvVWRYbw,30935
@@ -139,7 +139,7 @@ tests/test_future.py,sha256=cXzaNi2YDwVyjR541ScXmgktX1YFsKzbl8wep0DMVbk,3032
139
139
  tests/test_grad_generator.py,sha256=p4Pm4kMEeGldt2jUVAkGKCB0mLccKI28pltH6OTGbQA,3412
140
140
  tests/test_mock_cuda.py,sha256=5hisElxeLJ5MHw3KM9gwxBiXiMaG-Rm382u3AsQcDOI,3068
141
141
  tests/test_pdb_actor.py,sha256=5KJhuhcZDPWMdjC6eAtDdwnz1W7jNFXvIrMSFaCWaPw,3858
142
- tests/test_python_actors.py,sha256=dY109ofFtmmni9wJWNVb3W7YQH_tMZWSIGovnuAsrUw,10786
142
+ tests/test_python_actors.py,sha256=fDvHUIWNZeL3CWnTJMbdh98i1tnH1-LJEG1pIFkGYF8,10898
143
143
  tests/test_remote_functions.py,sha256=ExqYlRQWRabpGBuKvNIOa8Hwj-iXuP87Jfb9i5RhaGs,50066
144
144
  tests/test_rust_backend.py,sha256=nXSa0ZQ0NniZm4PzvKhrWvVLD-RKvIWYkPXm1BEBXq8,6235
145
145
  tests/test_signal_safe_block_on.py,sha256=bmal0XgzJowZXJV6T1Blow5a-vZluYWusCThLMGxyTE,3336
@@ -149,9 +149,9 @@ tests/simulator/test_profiling.py,sha256=TGYCfzTLdkpIwnOuO6KApprmrgPIRQe60KRX3wk
149
149
  tests/simulator/test_simulator.py,sha256=LO8lA0ssY-OGEBL5ipEu74f97Y765TEwfUOv-DtIptM,14568
150
150
  tests/simulator/test_task.py,sha256=ipqBDuDAysuo1xOB9S5psaFvwe6VATD43IovCTSs0t4,2327
151
151
  tests/simulator/test_worker.py,sha256=QrWWIJ3HDgDLkBPRc2mwYPlOQoXQcj1qRfc0WUfKkFY,3507
152
- torchmonarch_nightly-2025.6.6.dist-info/licenses/LICENSE,sha256=e0Eotbf_rHOYPuEUlppIbvwy4SN98CZnl_hqwvbDA4Q,1530
153
- torchmonarch_nightly-2025.6.6.dist-info/METADATA,sha256=q8PbST0aYM1zzxSGM9AjrDdDbTmJbg-7gaMyH6sqDPQ,2771
154
- torchmonarch_nightly-2025.6.6.dist-info/WHEEL,sha256=_wZSFk0d90K9wOBp8Q-UGxshyiJ987JoPiyUBNC6VLk,104
155
- torchmonarch_nightly-2025.6.6.dist-info/entry_points.txt,sha256=sqfQ16oZqjEvttUI-uj9BBXIIE6jt05bYFSmy-2hyXI,106
156
- torchmonarch_nightly-2025.6.6.dist-info/top_level.txt,sha256=E-ZssZzyM17glpVrh-S9--qJ-w9p2EjuYOuNw9tQ4Eg,33
157
- torchmonarch_nightly-2025.6.6.dist-info/RECORD,,
152
+ torchmonarch_nightly-2025.6.8.dist-info/licenses/LICENSE,sha256=e0Eotbf_rHOYPuEUlppIbvwy4SN98CZnl_hqwvbDA4Q,1530
153
+ torchmonarch_nightly-2025.6.8.dist-info/METADATA,sha256=AfGuuk6TyhejOLotJWjRt3Hsl80lkEWS4iOaZ61YHj4,2771
154
+ torchmonarch_nightly-2025.6.8.dist-info/WHEEL,sha256=_wZSFk0d90K9wOBp8Q-UGxshyiJ987JoPiyUBNC6VLk,104
155
+ torchmonarch_nightly-2025.6.8.dist-info/entry_points.txt,sha256=sqfQ16oZqjEvttUI-uj9BBXIIE6jt05bYFSmy-2hyXI,106
156
+ torchmonarch_nightly-2025.6.8.dist-info/top_level.txt,sha256=E-ZssZzyM17glpVrh-S9--qJ-w9p2EjuYOuNw9tQ4Eg,33
157
+ torchmonarch_nightly-2025.6.8.dist-info/RECORD,,