torchmonarch-nightly 2025.6.17__cp310-cp310-manylinux2014_x86_64.whl → 2025.6.19__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
monarch/_rust_bindings.so CHANGED
Binary file
monarch/actor_mesh.py CHANGED
@@ -6,7 +6,6 @@

  # pyre-unsafe

- import asyncio
  import collections
  import contextvars
  import functools
@@ -27,9 +26,7 @@ from typing import (
  Callable,
  cast,
  Concatenate,
- Coroutine,
  Dict,
- Generator,
  Generic,
  Iterable,
  List,
@@ -51,8 +48,9 @@ from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh
  from monarch._rust_bindings.monarch_hyperactor.mailbox import (
  Mailbox,
  OncePortReceiver,
- PortId,
+ OncePortRef,
  PortReceiver as HyPortReceiver,
+ PortRef,
  )
  from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
  from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape
@@ -99,39 +97,6 @@ _context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
  )


- # this was implemented in python 3.12 as an argument to task
- # but I have to backport to 3.10/3.11.
- def create_eager_task(coro: Awaitable[None]) -> asyncio.Future:
- iter = coro.__await__()
- try:
- first_yield = next(iter)
- return asyncio.create_task(RestOfCoroutine(first_yield, iter).run())
- except StopIteration as e:
- t = asyncio.Future()
- t.set_result(e.value)
- return t
-
-
- class RestOfCoroutine(Generic[T1, T2]):
- def __init__(self, first_yield: T1, iter: Generator[T2, None, T2]) -> None:
- self.first_yield: T1 | None = first_yield
- self.iter: Generator[T2, None, T2] = iter
-
- def __await__(self) -> Generator[T1, None, T1] | Generator[T2, None, T2]:
- first_yield = self.first_yield
- assert first_yield is not None
- yield first_yield
- self.first_yield = None
- while True:
- try:
- yield next(self.iter)
- except StopIteration as e:
- return e.value
-
- async def run(self) -> T1 | T2:
- return await self
-
-
  T = TypeVar("T")
  P = ParamSpec("P")
  R = TypeVar("R")
@@ -263,6 +228,8 @@ class Endpoint(Generic[P, R]):

  Load balanced RPC-style entrypoint for request/response messaging.
  """
+ p: Port[R]
+ r: PortReceiver[R]
  p, r = port(self, once=True)
  # pyre-ignore
  send(self, args, kwargs, port=p, selection="choose")
@@ -285,7 +252,18 @@ class Endpoint(Generic[P, R]):
  async def process() -> ValueMesh[R]:
  results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
  for _ in range(len(self._actor_mesh)):
- rank, value = await r.recv() # pyre-fixme[23]
+ rank, value = await r.recv()
+ results[rank] = value
+ call_shape = Shape(
+ self._actor_mesh._shape.labels,
+ NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
+ )
+ return ValueMesh(call_shape, results)
+
+ def process_blocking() -> ValueMesh[R]:
+ results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
+ for _ in range(len(self._actor_mesh)):
+ rank, value = r.recv().get()
  results[rank] = value
  call_shape = Shape(
  self._actor_mesh._shape.labels,
@@ -293,7 +271,7 @@ class Endpoint(Generic[P, R]):
  )
  return ValueMesh(call_shape, results)

- return Future(process)
+ return Future(process, process_blocking)

  async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
  """
@@ -362,6 +340,9 @@ class ValueMesh(MeshTrait, Generic[R]):
  def __len__(self) -> int:
  return len(self._shape)

+ def __repr__(self) -> str:
+ return f"ValueMesh({self._shape})"
+
  @property
  def _ndslice(self) -> NDSlice:
  return self._shape.ndslice
@@ -387,7 +368,7 @@ def send(
  message = PythonMessage(
  endpoint._name,
  _pickle((args, kwargs)),
- None if port is None else port._port,
+ None if port is None else port._port_ref,
  None,
  )
  endpoint._actor_mesh.cast(message, selection)
@@ -411,14 +392,16 @@ def endpoint(


  class Port(Generic[R]):
- def __init__(self, port: PortId, mailbox: Mailbox, rank: Optional[int]) -> None:
- self._port = port
+ def __init__(
+ self, port_ref: PortRef | OncePortRef, mailbox: Mailbox, rank: Optional[int]
+ ) -> None:
+ self._port_ref = port_ref
  self._mailbox = mailbox
  self._rank = rank

  def send(self, method: str, obj: R) -> None:
- self._mailbox.post(
- self._port,
+ self._port_ref.send(
+ self._mailbox,
  PythonMessage(method, _pickle(obj), None, self._rank),
  )

@@ -432,8 +415,8 @@ def port(
  handle, receiver = (
  endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port()
  )
- port_id: PortId = handle.bind()
- return Port(port_id, endpoint._mailbox, rank=None), PortReceiver(
+ port_ref: PortRef | OncePortRef = handle.bind()
+ return Port(port_ref, endpoint._mailbox, rank=None), PortReceiver(
  endpoint._mailbox, receiver
  )

@@ -485,24 +468,36 @@ singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[]))


  class _Actor:
+ """
+ This is the message handling implementation of a Python actor.
+
+ The layering goes:
+ Rust `PythonActor` -> `_Actor` -> user-provided `Actor` instance
+
+ Messages are received from the Rust backend, and forwarded to the `handle`
+ methods on this class.
+
+ This class wraps the actual `Actor` instance provided by the user, and
+ routes messages to it, managing argument serialization/deserialization and
+ error handling.
+ """
+
  def __init__(self) -> None:
  self.instance: object | None = None
- self.active_requests: asyncio.Queue[asyncio.Future[object]] = asyncio.Queue()
- self.complete_task: asyncio.Task | None = None

- def handle(
+ async def handle(
  self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag
- ) -> Optional[Coroutine[Any, Any, Any]]:
- return self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag)
+ ) -> None:
+ return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag)

- def handle_cast(
+ async def handle_cast(
  self,
  mailbox: Mailbox,
  rank: int,
  shape: Shape,
  message: PythonMessage,
  panic_flag: PanicFlag,
- ) -> Optional[Coroutine[Any, Any, Any]]:
+ ) -> None:
  port = (
  Port(message.response_port, mailbox, rank)
  if message.response_port
@@ -515,26 +510,21 @@ class _Actor:
  _context.set(ctx)

  args, kwargs = _unpickle(message.message, mailbox)
+
  if message.method == "__init__":
  Class, *args = args
  self.instance = Class(*args, **kwargs)
  return None
- else:
- the_method = getattr(self.instance, message.method)._method

- if not inspect.iscoroutinefunction(the_method):
- enter_span(
- the_method.__module__, message.method, str(ctx.mailbox.actor_id)
- )
- result = the_method(self.instance, *args, **kwargs)
- exit_span()
- if port is not None:
- port.send("result", result)
- return None
+ the_method = getattr(self.instance, message.method)._method
+
+ if inspect.iscoroutinefunction(the_method):

  async def instrumented():
  enter_span(
- the_method.__module__, message.method, str(ctx.mailbox.actor_id)
+ the_method.__module__,
+ message.method,
+ str(ctx.mailbox.actor_id),
  )
  try:
  result = await the_method(self.instance, *args, **kwargs)
@@ -547,39 +537,14 @@
  exit_span()
  return result

- return self.run_async(
- ctx,
- self.run_task(port, instrumented(), panic_flag),
- )
- except Exception as e:
- traceback.print_exc()
- s = ActorError(e)
-
- # The exception is delivered to exactly one of:
- # (1) our caller, (2) our supervisor
- if port is not None:
- port.send("exception", s)
+ result = await instrumented()
  else:
- raise s from None
-
- async def run_async(
- self,
- ctx: MonarchContext,
- coroutine: Awaitable[None],
- ) -> None:
- _context.set(ctx)
- if self.complete_task is None:
- self.complete_task = asyncio.create_task(self._complete())
- await self.active_requests.put(create_eager_task(coroutine))
+ enter_span(
+ the_method.__module__, message.method, str(ctx.mailbox.actor_id)
+ )
+ result = the_method(self.instance, *args, **kwargs)
+ exit_span()

- async def run_task(
- self,
- port: Port | None,
- coroutine: Awaitable[Any],
- panic_flag: PanicFlag,
- ) -> None:
- try:
- result = await coroutine
  if port is not None:
  port.send("result", result)
  except Exception as e:
@@ -603,11 +568,6 @@ class _Actor:
  pass
  raise

- async def _complete(self) -> None:
- while True:
- task = await self.active_requests.get()
- await task
-

  def _is_mailbox(x: object) -> bool:
  return isinstance(x, Mailbox)
@@ -648,8 +608,8 @@ class Actor(MeshTrait):
  "actor implementations are not meshes, but we can't convince the typechecker of it..."
  )

- @endpoint
- async def _set_debug_client(self, client: "DebugClient") -> None:
+ @endpoint # pyre-ignore
+ def _set_debug_client(self, client: "DebugClient") -> None:
  point = MonarchContext.get().point
  # For some reason, using a lambda instead of functools.partial
  # confuses the pdb wrapper implementation.
@@ -750,6 +710,9 @@ class ActorMeshRef(MeshTrait):
  self._mailbox,
  )

+ def __repr__(self) -> str:
+ return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})"
+

  class ActorError(Exception):
  """
monarch/allocator.py CHANGED
@@ -74,7 +74,7 @@ class RemoteAllocInitializer(abc.ABC):
  """

  @abc.abstractmethod
- async def initialize_alloc(self) -> list[str]:
+ async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
  """
  Return the addresses of the servers that should be used to allocate processes
  for the proc mesh. The addresses should be running hyperactor's RemoteProcessAllocator.
@@ -88,6 +88,10 @@ class RemoteAllocInitializer(abc.ABC):
  in the future this method can be called multiple times and should return the current set of
  addresses that are eligible to handle allocation requests.

+ Arguments:
+ - `match_labels`: The match labels specified in `AllocSpec.AllocConstraints`. Initializer implementations
+ can read specific labels for matching a set of hosts that will service `allocate()` requests.
+
  """
  ...

@@ -102,7 +106,8 @@ class StaticRemoteAllocInitializer(RemoteAllocInitializer):
  super().__init__()
  self.addrs: list[str] = list(addrs)

- async def initialize_alloc(self) -> list[str]:
+ async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
+ _ = match_labels # Suppress unused variable warning
  return list(self.addrs)

 
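A sketch of a custom initializer consuming the new `match_labels` argument; the `pool` label key and the `LabelFilteringInitializer` class are hypothetical, and real keys would come from `AllocSpec.AllocConstraints` as described in the docstring above:

```python
from monarch.allocator import RemoteAllocInitializer


class LabelFilteringInitializer(RemoteAllocInitializer):
    """Hypothetical initializer that picks hosts based on a match label."""

    def __init__(self, hosts_by_pool: dict[str, list[str]]) -> None:
        super().__init__()
        self.hosts_by_pool = hosts_by_pool

    async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
        # "pool" is an illustrative label key; allocate() requests carrying
        # that label are served by the matching host list.
        pool = match_labels.get("pool", "default")
        return self.hosts_by_pool.get(pool, [])
```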
monarch/common/messages.py CHANGED
@@ -25,7 +25,6 @@ from monarch._rust_bindings.monarch_extension import tensor_worker
  from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction
  from monarch.common.invocation import DeviceException, RemoteException
  from monarch.common.reference import Referenceable
- from monarch.common.stream import StreamRef
  from monarch.common.tree import flattener
  from pyre_extensions import none_throws

@@ -33,6 +32,8 @@ from .shape import NDSlice
  from .tensor_factory import TensorFactory

  if TYPE_CHECKING:
+ from monarch.common.stream import StreamRef
+
  from .device_mesh import DeviceMesh, RemoteProcessGroup
  from .pipe import Pipe
  from .recording import Recording
@@ -98,7 +99,7 @@ class CreateDeviceMesh(NamedTuple):


  class CreateStream(NamedTuple):
- result: StreamRef
+ result: "StreamRef"
  default: bool

  def to_rust_message(self) -> tensor_worker.WorkerMessage:
@@ -132,7 +133,7 @@ class CallFunction(NamedTuple):
  function: ResolvableFunction
  args: Tuple[object, ...]
  kwargs: Dict[str, object]
- stream: StreamRef
+ stream: "StreamRef"
  device_mesh: DeviceMesh
  remote_process_groups: List[RemoteProcessGroup]

@@ -199,7 +200,7 @@ class RecordingFormal(NamedTuple):
  class RecordingResult(NamedTuple):
  input: Tensor | tensor_worker.Ref
  output_index: int
- stream: StreamRef
+ stream: "StreamRef"

  def to_rust_message(self) -> tensor_worker.WorkerMessage:
  return tensor_worker.RecordingResult(
monarch/monarch_controller CHANGED
Binary file
monarch/tools/cli.py CHANGED
@@ -112,7 +112,7 @@ class InfoCmd:
  file=sys.stderr,
  )
  else:
- json.dump(server_spec.to_json(), fp=sys.stdout)
+ json.dump(server_spec.to_json(), indent=2, fp=sys.stdout)


  class KillCmd:
monarch/tools/commands.py CHANGED
@@ -9,7 +9,10 @@
  import argparse
  import functools
  import inspect
+ import logging
  import os
+ import time
+ from datetime import timedelta
  from typing import Any, Callable, Mapping, Optional, Union

  from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/config/meta:defaults
@@ -18,12 +21,13 @@ from monarch.tools.config import ( # @manual=//monarch/python/monarch/tools/con
  )

  from monarch.tools.mesh_spec import mesh_spec_from_metadata, ServerSpec
-
  from torchx.runner import Runner
- from torchx.specs import AppDef, AppDryRunInfo, CfgVal
+ from torchx.specs import AppDef, AppDryRunInfo, AppState, CfgVal
  from torchx.specs.builders import parse_args
  from torchx.util.types import decode, decode_optional

+ logger: logging.Logger = logging.getLogger(__name__)
+

  def torchx_runner() -> Runner:
  # namespace is currently unused so make it empty str
@@ -165,15 +169,73 @@ def info(server_handle: str) -> Optional[ServerSpec]:
  if appdef is None:
  return None

+ # host status grouped by mesh (role) names
+ replica_status = {r.role: r.replicas for r in status.roles}
+
  mesh_specs = []
  for role in appdef.roles:
  spec = mesh_spec_from_metadata(appdef, role.name)
  assert spec is not None, "cannot be 'None' since we iterate over appdef's roles"
+
+ # null-guard since some schedulers do not fill replica_status
+ if host_status := replica_status.get(role.name):
+ spec.hostnames = [h.hostname for h in host_status]
+
  mesh_specs.append(spec)

  return ServerSpec(name=appdef.name, state=status.state, meshes=mesh_specs)


+ _5_SECONDS = timedelta(seconds=5)
+
+
+ async def server_ready(
+ server_handle: str, check_interval: timedelta = _5_SECONDS
+ ) -> Optional[ServerSpec]:
+ """Waits until the server's job is in RUNNING state to returns the server spec.
+ Returns `None` if the server does not exist.
+
+ NOTE: Certain fields such as `hostnames` is only filled (and valid) when the server is RUNNING.
+
+ Usage:
+
+ .. code-block:: python
+
+ server_info = await server_ready("slurm:///123")
+ if not server_info:
+ print(f"Job does not exist")
+ else:
+ if server_info.is_running:
+ for mesh in server_info.meshes:
+ connect_to(mesh.hostnames)
+ else:
+ print(f"Job in {server_info.state} state. Hostnames are not available")
+
+ """
+
+ while True:
+ server_spec = info(server_handle)
+
+ if not server_spec: # server not found
+ return None
+
+ if server_spec.state <= AppState.PENDING: # UNSUBMITTED or SUBMITTED or PENDING
+ # NOTE: TorchX currently does not have async APIs so need to loop-on-interval
+ # TODO maybe inverse exponential backoff instead of constant interval?
+ check_interval_seconds = check_interval.total_seconds()
+ logger.info(
+ "waiting for %s to be %s (current: %s), will check again in %g seconds...",
+ server_handle,
+ AppState.RUNNING,
+ server_spec.state,
+ check_interval_seconds,
+ )
+ time.sleep(check_interval_seconds)
+ continue
+ else:
+ return server_spec
+
+
  def kill(server_handle: str) -> None:
  with torchx_runner() as runner:
  runner.cancel(server_handle)
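A short sketch of driving the new `server_ready` helper from synchronous code; `slurm:///123` is just the placeholder handle from the docstring above:

```python
import asyncio

from monarch.tools.commands import server_ready


async def main() -> None:
    server_info = await server_ready("slurm:///123")
    if server_info is None:
        print("job does not exist")
    elif server_info.is_running:
        for mesh in server_info.meshes:
            print(mesh.name, mesh.hostnames)  # hostnames are filled once RUNNING
    else:
        print(f"job is in {server_info.state} state")


asyncio.run(main())
```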
monarch/tools/mesh_spec.py CHANGED
@@ -6,7 +6,7 @@

  # pyre-strict
  import string
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from typing import Any, Optional

  from torchx import specs
@@ -29,6 +29,7 @@ class MeshSpec:
  host_type: str
  gpus: int
  port: int = DEFAULT_REMOTE_ALLOCATOR_PORT
+ hostnames: list[str] = field(default_factory=list)


  def _tag(mesh_name: str, tag_template: str) -> str:
@@ -84,6 +85,10 @@ class ServerSpec:
  state: specs.AppState
  meshes: list[MeshSpec]

+ @property
+ def is_running(self) -> bool:
+ return self.state == specs.AppState.RUNNING
+
  def get_mesh_spec(self, mesh_name: str) -> MeshSpec:
  for mesh_spec in self.meshes:
  if mesh_spec.name == mesh_name:
@@ -115,6 +120,7 @@ class ServerSpec:
  "host_type": mesh.host_type,
  "hosts": mesh.num_hosts,
  "gpus": mesh.gpus,
+ "hostnames": mesh.hostnames,
  }
  for mesh in self.meshes
  },
monarch/tools/network.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ import logging
+ import socket
+ from typing import Optional
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ def get_ip_addr(hostname: str) -> str:
+ """Resolves and returns the ip address of the given hostname.
+
+ This function will return an ipv6 address if one that can bind
+ `SOCK_STREAM` (TCP) socket is found. Otherwise it will fall-back
+ to resolving an ipv4 `SOCK_STREAM` address.
+
+ Raises a `RuntimeError` if neither ipv6 or ipv4 ip can be resolved from hostname.
+ """
+
+ def get_sockaddr(family: socket.AddressFamily) -> Optional[str]:
+ try:
+ # patternlint-disable-next-line python-dns-deps (only used for oss)
+ addrs = socket.getaddrinfo(
+ hostname, port=None, family=family, type=socket.SOCK_STREAM
+ ) # tcp
+ if addrs:
+ # socket.getaddrinfo return a list of addr 5-tuple addr infos
+ _, _, _, _, sockaddr = addrs[0] # use the first address
+
+ # sockaddr is a tuple (ipv4) or a 4-tuple (ipv6) where the first element is the ip addr
+ ipaddr = str(sockaddr[0])
+
+ logger.info(
+ "Resolved %s address: `%s` for host: `%s`",
+ family.name,
+ ipaddr,
+ hostname,
+ )
+ return str(ipaddr)
+ else:
+ return None
+ except socket.gaierror as e:
+ logger.info(
+ "No %s address that can bind TCP sockets for host: %s. %s",
+ family.name,
+ hostname,
+ e,
+ )
+ return None
+
+ ipaddr = get_sockaddr(socket.AF_INET6) or get_sockaddr(socket.AF_INET)
+ if not ipaddr:
+ raise RuntimeError(
+ f"Unable to resolve `{hostname}` to ipv6 or ipv4 address that can bind TCP socket."
+ " Check the network configuration on the host."
+ )
+ return ipaddr
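Example use of the new `monarch.tools.network.get_ip_addr` helper; the printed addresses are illustrative and depend on the host's network configuration:

```python
from monarch.tools.network import get_ip_addr

# Prefers an IPv6 address that can bind a TCP socket, falling back to IPv4;
# raises RuntimeError if neither can be resolved.
print(get_ip_addr("localhost"))  # e.g. "::1" or "127.0.0.1"
```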
tests/error_test_binary.py CHANGED
@@ -4,6 +4,7 @@
  # This source code is licensed under the BSD-style license found in the
  # LICENSE file in the root directory of this source tree.

+ import asyncio
  import ctypes
  import sys

@@ -11,7 +12,7 @@ import click

  from monarch._rust_bindings.monarch_extension.panic import panicking_function

- from monarch.actor_mesh import Actor, endpoint
+ from monarch.actor_mesh import Actor, endpoint, send
  from monarch.proc_mesh import proc_mesh


@@ -35,6 +36,12 @@ class ErrorActor(Actor):
  """Endpoint that calls a Rust function that panics."""
  panicking_function()

+ @endpoint
+ async def await_then_error(self) -> None:
+ await asyncio.sleep(0.1)
+ await asyncio.sleep(0.1)
+ raise RuntimeError("oh noez")
+

  class ErrorActorSync(Actor):
  """An actor that has endpoints cause segfaults."""
@@ -146,5 +153,28 @@ def error_bootstrap():
  proc_mesh(gpus=4, env={"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}).get()


+ async def _error_unmonitored():
+ print("I actually ran")
+ sys.stdout.flush()
+
+ proc = await proc_mesh(gpus=1)
+ actor = await proc.spawn("error_actor", ErrorActor)
+
+ # fire and forget
+ send(actor.await_then_error, (), {}, None, "all")
+
+ # Wait. Eventually a supervision event will get propagated and the process
+ # will exit.
+ #
+ # If an event is not delivered, the test will time out before this sleep
+ # finishes.
+ await asyncio.sleep(300)
+
+
+ @main.command("error-unmonitored")
+ def error_unmonitored():
+ asyncio.run(_error_unmonitored())
+
+
  if __name__ == "__main__":
  main()
tests/test_actor_error.py CHANGED
@@ -4,11 +4,12 @@
  # This source code is licensed under the BSD-style license found in the
  # LICENSE file in the root directory of this source tree.

+ import asyncio
  import importlib.resources
  import subprocess

  import pytest
- from monarch.actor_mesh import Actor, ActorError, endpoint
+ from monarch.actor_mesh import Actor, ActorError, endpoint, send

  from monarch.proc_mesh import proc_mesh

@@ -128,6 +129,7 @@ def test_actor_supervision(num_procs, sync_endpoint, sync_test_impl, endpoint_name
  f"--endpoint-name={endpoint_name}",
  ]
  try:
+ print("running cmd", " ".join(cmd))
  process = subprocess.run(cmd, capture_output=True, timeout=180)
  except subprocess.TimeoutExpired as e:
  print("timeout expired")
@@ -157,6 +159,7 @@ def test_proc_mesh_bootstrap_error():
  "error-bootstrap",
  ]
  try:
+ print("running cmd", " ".join(cmd))
  process = subprocess.run(cmd, capture_output=True, timeout=180)
  except subprocess.TimeoutExpired as e:
  print("timeout expired")
@@ -208,3 +211,30 @@ async def test_broken_pickle_class(raise_on_getstate, raise_on_setstate, num_pro
  await exception_actor.print_value.call_one(broken_obj)
  else:
  await exception_actor.print_value.call(broken_obj)
+
+
+ # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
+ @pytest.mark.oss_skip
+ async def test_exception_after_wait_unmonitored():
+ # Run the test in a subprocess
+ test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
+ cmd = [
+ str(test_bin),
+ "error-unmonitored",
+ ]
+ try:
+ print("running cmd", " ".join(cmd))
+ process = subprocess.run(cmd, capture_output=True, timeout=180)
+ except subprocess.TimeoutExpired as e:
+ print("timeout expired")
+ if e.stdout is not None:
+ print(e.stdout.decode())
+ if e.stderr is not None:
+ print(e.stderr.decode())
+ raise
+
+ # Assert that the subprocess exited with a non-zero code
+ assert "I actually ran" in process.stdout.decode()
+ assert (
+ process.returncode != 0
+ ), f"Expected non-zero exit code, got {process.returncode}"
tests/test_allocator.py CHANGED
@@ -116,8 +116,8 @@ class TestRemoteAllocator(unittest.IsolatedAsyncioTestCase):
  used to test that the state of the initializer is preserved across calls to allocate()
  """

- async def initialize_alloc(self) -> list[str]:
- alloc = await super().initialize_alloc()
+ async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
+ alloc = await super().initialize_alloc(match_labels)
  self.addrs.pop(-1)
  return alloc

@@ -142,7 +142,8 @@ class TestRemoteAllocator(unittest.IsolatedAsyncioTestCase):
  class EmptyAllocInitializer(StaticRemoteAllocInitializer):
  """test initializer that returns an empty list of addresses"""

- async def initialize_alloc(self) -> list[str]:
+ async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]:
+ _ = match_labels # Suppress unused variable warning
  return []

  empty_initializer = EmptyAllocInitializer()
tests/test_python_actors.py CHANGED
@@ -9,6 +9,7 @@ import operator
  import os
  import re
  import threading
+ import time
  from types import ModuleType
  from unittest.mock import AsyncMock, patch

@@ -391,6 +392,16 @@ def test_rust_binding_modules_correct() -> None:
  check(bindings, "monarch._rust_bindings")


+ def test_proc_mesh_liveness() -> None:
+ mesh = proc_mesh(gpus=2).get()
+ counter = mesh.spawn("counter", Counter, 1).get()
+ del mesh
+ # Give some time for the mesh to have been shut down.
+ # (It only would if there were a bug.)
+ time.sleep(0.5)
+ counter.value.call().get()
+
+
  two_gpu = pytest.mark.skipif(
  torch.cuda.device_count() < 2,
  reason="Not enough GPUs, this test requires at least 2 GPUs",
@@ -584,16 +595,40 @@ async def test_actor_tls() -> None:
  pm = await proc_mesh(gpus=1)
  am = await pm.spawn("tls", TLSActor)
  await am.increment.call_one()
- # TODO(suo): TLS is NOT preserved across async/sync endpoints, because currently
- # we run async endpoints on a different thread than sync ones.
- # Will fix this in a followup diff.
+ await am.increment_async.call_one()
+ await am.increment.call_one()
+ await am.increment_async.call_one()
+
+ assert 4 == await am.get.call_one()
+ assert 4 == await am.get_async.call_one()
+
+
+ class TLSActorFullSync(Actor):
+ """An actor that manages thread-local state."""
+
+ def __init__(self):
+ self.local = threading.local()
+ self.local.value = 0
+
+ @endpoint
+ def increment(self):
+ self.local.value += 1
+
+ @endpoint
+ def get(self):
+ return self.local.value
+

- # await am.increment_async.call_one()
+ async def test_actor_tls_full_sync() -> None:
+ """Test that thread-local state is respected."""
+ pm = await proc_mesh(gpus=1)
+ am = await pm.spawn("tls", TLSActorFullSync)
+ await am.increment.call_one()
+ await am.increment.call_one()
+ await am.increment.call_one()
  await am.increment.call_one()
- # await am.increment_async.call_one()

- assert 2 == await am.get.call_one()
- # assert 4 == await am.get_async.call_one()
+ assert 4 == await am.get.call_one()


  @two_gpu
@@ -611,3 +646,29 @@ def test_proc_mesh_tensor_engine() -> None:
  assert a == 0
  assert b == 10
  assert c == 100
+
+
+ class AsyncActor(Actor):
+ def __init__(self):
+ self.should_exit = False
+
+ @endpoint
+ async def sleep(self) -> None:
+ while True and not self.should_exit:
+ await asyncio.sleep(1)
+
+ @endpoint
+ async def no_more(self) -> None:
+ self.should_exit = True
+
+
+ @pytest.mark.timeout(15)
+ async def test_async_concurrency():
+ """Test that async endpoints will be processed concurrently."""
+ pm = await proc_mesh(gpus=1)
+ am = await pm.spawn("async", AsyncActor)
+ fut = am.sleep.call()
+ # This call should go through and exit the sleep loop, as long as we are
+ # actually concurrently processing messages.
+ await am.no_more.call()
+ await fut
torchmonarch_nightly-2025.6.17.dist-info/METADATA → torchmonarch_nightly-2025.6.19.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: torchmonarch-nightly
- Version: 2025.6.17
+ Version: 2025.6.19
  Summary: Monarch: Single controller library
  Author: Meta
  Author-email: oncall+monarch@xmail.facebook.com
torchmonarch_nightly-2025.6.17.dist-info/RECORD → torchmonarch_nightly-2025.6.19.dist-info/RECORD CHANGED
@@ -1,8 +1,8 @@
  monarch/__init__.py,sha256=iUvWHc0-7Q2tovRoRxOIiA3TsefMXCbWl-jEfQ2djew,6897
- monarch/_rust_bindings.so,sha256=BIOc6AH_iVbNSGCnF3de-4l9bp82KlPwWxBWUCMKf40,40709968
+ monarch/_rust_bindings.so,sha256=EUkkinIuX45ihfDu4ot656fOd0CxaepnmaZdUv0cOMY,41044112
  monarch/_testing.py,sha256=jOIOG6jcZBzvEvG_DwSnwCkaMVXvSun6sJAG6nXemww,7859
- monarch/actor_mesh.py,sha256=nAW65WFEWMJWCv8zuH9GSOyTNXwFN8QNqZxMZTuSYxw,25537
- monarch/allocator.py,sha256=ylvYTf31o-PT385cYJPhi17uNbC4yl_RAraqD0fVe4g,4112
+ monarch/actor_mesh.py,sha256=m6QapbZHqYujXya28jW1II2wkBUV_nKGvxmWPSW9lsQ,24327
+ monarch/allocator.py,sha256=UEaVLntH4xQ8Lr84TbgcXusvuK8FhSMJmav-omztUbw,4473
  monarch/bootstrap_main.py,sha256=RCUQhJk07yMFiKp6HzQuqZFUpkgsT9kVEyimiwjn6_E,1827
  monarch/cached_remote_function.py,sha256=kYdB6r4OHx_T_uX4q3tCNcp1t2DJwF8tPTIahUiT2pU,8785
  monarch/debugger.py,sha256=AdlvOG3X-9Pw9c1DLQYEy4vjEfh0ZtwtsNJEFLFzN8o,13312
@@ -11,7 +11,7 @@ monarch/future.py,sha256=lcdFEe7m1shYPPuvZ1RkS6JUIChEKGBWe3v7x_nu4Hg,731
  monarch/gradient_generator.py,sha256=Rl3dmXGceTdCc1mYBg2JciR88ywGPnW7TVkL86KwqEA,6366
  monarch/memory.py,sha256=ol86dBhFAJqg78iF25-BuK0wuwj1onR8FIioZ_B0gjw,1377
  monarch/mesh_controller.py,sha256=am1QP7dvn0OH1z9ADSKm41APs1HY_dHcBAhOVP-QDmE,10427
- monarch/monarch_controller,sha256=yEs4PlEWgSMnRUSNWyFKvT5LmpkJ9p7GRi6WF-nsdM0,20347496
+ monarch/monarch_controller,sha256=sWOUMClz3JPUjZbppDWgdrPOAjbydygdRPDZ1kaAVC4,20328464
  monarch/notebook.py,sha256=zu9MKDFKf1-rCM2TqFSRJjMBeiWuKcJSyUFLvoZRQzs,25949
  monarch/opaque_module.py,sha256=oajOu_WD1hD4hxE8HDdO-tvWY7KDHWd7VaAhJEa5L2I,10446
  monarch/opaque_object.py,sha256=IVpll4pyuKZMo_EnPh4s0qnx8RlAcJrJ1yoLX6E75wQ,2782
@@ -57,7 +57,7 @@ monarch/common/function_caching.py,sha256=HVdbWtv6Eea7ENMWi8iv36w1G1TaVuUJhkUX_J
  monarch/common/future.py,sha256=D1UJ_8Rvb8-VG9vNE-z7xz2m2otMd2HgB0rnA02nlvA,4681
  monarch/common/invocation.py,sha256=L4mSmzqlHMxo1Tb71hBU_M8aBZCRCOcb6vvPhvvewec,4195
  monarch/common/mast.py,sha256=XTzYljGR0aZ7GjmNMPgU2HyuL4HWSAy4IwE3kEDqdOw,7735
- monarch/common/messages.py,sha256=El7BoGZ2jlP8HyyE-S8wkiG9W8Ciw3_5JERnNrgOYHU,18278
+ monarch/common/messages.py,sha256=OFMd_4yBoMIHjdXcKcJDG88iERfViLG3QxTqzwV4Gnw,18289
  monarch/common/mock_cuda.py,sha256=x6ho1Ton6BbKjBZ5ZxnFOUaQM032X70wnpoUNB7Ci2w,1039
  monarch/common/opaque_ref.py,sha256=tWNvOC6CsjNPKD1JDx-8PSaeXqZC3eermgBExUPKML4,2871
  monarch/common/pickle_flatten.py,sha256=2mc-dPiZy7kRqAstyfMLnPuoGJwsBftYYEHyF_HOZw4,1313
@@ -106,9 +106,10 @@ monarch/timer/example_spmd.py,sha256=p8i3_tO1AmpwSkZryiSjgkh7qaEZ6QXp2Fy1qtPpECA
  monarch/timer/execution_timer.py,sha256=1YsrLIZirdohKOeFAU2H4UcONhQXHuctJbYcoX8I6gY,6985
  monarch/timer/execution_timer_test.py,sha256=CSxTv44fFZQURJlCBmYvysQI1aS_zEGZs_uxl9SOHak,4486
  monarch/tools/__init__.py,sha256=J8qjUOysmcMAek2KFN13mViOXZxTYc5vCrF02t3VuFU,223
- monarch/tools/cli.py,sha256=66F7dr90bh27P3kOCmxwJkVmWv2v4wBrkifvwqwUwFE,4967
- monarch/tools/commands.py,sha256=BfmXndJmU_cZP4cMPlknkxGca1NjqYd8_ReDePWksXw,6908
- monarch/tools/mesh_spec.py,sha256=JLykhgy1dClXiNbH1Qsl2fX5MbqplQAhl8LGoragvbo,3702
+ monarch/tools/cli.py,sha256=EIdarsfuFX0WqRCe29_5GNKWJBhxx0lABalw3zPSagw,4977
+ monarch/tools/commands.py,sha256=OuFDVAcl5LvBdBZ-HyemErR0IiDtiMMNgmGPD4MWTHY,8996
+ monarch/tools/mesh_spec.py,sha256=3Qp7Lu3pAa9tfaG-METsCmj-QXECQ6OsrPWiLydWvKc,3914
+ monarch/tools/network.py,sha256=bRj-jOs5qDqnM3BcE9MSXCLS01hiMN4YSWfKZ_d7bc4,2182
  monarch/tools/components/__init__.py,sha256=J8qjUOysmcMAek2KFN13mViOXZxTYc5vCrF02t3VuFU,223
  monarch/tools/components/hyperactor.py,sha256=Ryi1X07VLcaQVlpc4af65JNBbZtOb9IAlKxSKMZ1AW4,2120
  monarch/tools/config/__init__.py,sha256=OPSflEmJB2zxAaRVzzWSWXV5M5vlknLgpulGdW1ze5U,510
@@ -131,11 +132,11 @@ monarch_supervisor/python_executable.py,sha256=WfCiK3wdAvm9Jxx5jgjGF991NgGc9-oHU
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  tests/dispatch_bench.py,sha256=sU_m-8KAjQgYTsxI5khV664NdgLLutidni69Rtowk98,3933
  tests/dispatch_bench_helper.py,sha256=1ORgAMrRgjAjmmWeCHLLQd_bda9mJk0rS2ucEbRu28s,633
- tests/error_test_binary.py,sha256=64H-ucdkQ2i7GD8sidStl227cOy7gyeqvO4kTm1y7Ic,4817
+ tests/error_test_binary.py,sha256=BRj13wAROsUWx4jcxc07HYN2n-xyBNhnnRAhjqah-A0,5582
  tests/sleep_binary.py,sha256=XfLYaAfwm9xgzM-svs8fhAeFhwYIg6SyVEnx4e6wbUw,1009
- tests/test_actor_error.py,sha256=z3Sf4lteUggTryPLOhRKJ55v0MwVK3a7QN7-U2U9iJg,7484
+ tests/test_actor_error.py,sha256=-0UJCEpyzsBh-RdbGhDiG1-sRtu7bJPQWmtjUD0ad48,8526
  tests/test_alloc.py,sha256=D6DdQbtOZEvvnnc7LV-WyWFMk0Xb77eblH6Oz90zJTA,745
- tests/test_allocator.py,sha256=P11sQ95ADjzC_-CfPs3CEP80nP8sn7wW8vVPsmpSVoM,8164
+ tests/test_allocator.py,sha256=jaYWPVEFdcK0XmmEA1Y9uwkeBjhxb2iI1GUL6IZKh4s,8305
  tests/test_coalescing.py,sha256=JZ4YgQNlWWs7N-Z8KCCXQPANcuyyXEKjeHIXYbPnQhk,15606
  tests/test_controller.py,sha256=Rp_kW20zYT8ocsK5LX0Ha3LB9azS2LSKpp8n_dBlzVU,31384
  tests/test_device_mesh.py,sha256=DrbezYOM0thfP9MgLXb5-F0VoLOmSz5GR0GwjR_3bE4,5290
@@ -144,7 +145,7 @@ tests/test_future.py,sha256=cXzaNi2YDwVyjR541ScXmgktX1YFsKzbl8wep0DMVbk,3032
  tests/test_grad_generator.py,sha256=p4Pm4kMEeGldt2jUVAkGKCB0mLccKI28pltH6OTGbQA,3412
  tests/test_mock_cuda.py,sha256=5hisElxeLJ5MHw3KM9gwxBiXiMaG-Rm382u3AsQcDOI,3068
  tests/test_pdb_actor.py,sha256=5KJhuhcZDPWMdjC6eAtDdwnz1W7jNFXvIrMSFaCWaPw,3858
- tests/test_python_actors.py,sha256=YiDJaMFoQ3xPGq602QTuhRM8CsgZo5pttKMKAnLm6io,17773
+ tests/test_python_actors.py,sha256=3ru2JsPQmaO7ppVX3-ls7JcvIeOgEmWWUsYKZCuBXPg,19256
  tests/test_remote_functions.py,sha256=5nxYB8dfA9NT9f9Od9O3htgQtPbiRNiXZ1Kgtn75sOQ,50056
  tests/test_rust_backend.py,sha256=94S3R995ZkyIhEiBsM5flcjf5X7bscEAHBtInbTRFe8,7776
  tests/test_signal_safe_block_on.py,sha256=bmal0XgzJowZXJV6T1Blow5a-vZluYWusCThLMGxyTE,3336
@@ -154,9 +155,9 @@ tests/simulator/test_profiling.py,sha256=TGYCfzTLdkpIwnOuO6KApprmrgPIRQe60KRX3wk
  tests/simulator/test_simulator.py,sha256=LO8lA0ssY-OGEBL5ipEu74f97Y765TEwfUOv-DtIptM,14568
  tests/simulator/test_task.py,sha256=ipqBDuDAysuo1xOB9S5psaFvwe6VATD43IovCTSs0t4,2327
  tests/simulator/test_worker.py,sha256=QrWWIJ3HDgDLkBPRc2mwYPlOQoXQcj1qRfc0WUfKkFY,3507
- torchmonarch_nightly-2025.6.17.dist-info/licenses/LICENSE,sha256=e0Eotbf_rHOYPuEUlppIbvwy4SN98CZnl_hqwvbDA4Q,1530
- torchmonarch_nightly-2025.6.17.dist-info/METADATA,sha256=xnYwQ3UlDfJcHRWA86w2X71Fzl0Eddvs4u4UKveyIuo,2772
- torchmonarch_nightly-2025.6.17.dist-info/WHEEL,sha256=_wZSFk0d90K9wOBp8Q-UGxshyiJ987JoPiyUBNC6VLk,104
- torchmonarch_nightly-2025.6.17.dist-info/entry_points.txt,sha256=sqfQ16oZqjEvttUI-uj9BBXIIE6jt05bYFSmy-2hyXI,106
- torchmonarch_nightly-2025.6.17.dist-info/top_level.txt,sha256=E-ZssZzyM17glpVrh-S9--qJ-w9p2EjuYOuNw9tQ4Eg,33
- torchmonarch_nightly-2025.6.17.dist-info/RECORD,,
+ torchmonarch_nightly-2025.6.19.dist-info/licenses/LICENSE,sha256=e0Eotbf_rHOYPuEUlppIbvwy4SN98CZnl_hqwvbDA4Q,1530
+ torchmonarch_nightly-2025.6.19.dist-info/METADATA,sha256=2XYBEhTb9iSTFKhAGmq2Bg_AXwjQvcPj6CQmG4bBiLE,2772
+ torchmonarch_nightly-2025.6.19.dist-info/WHEEL,sha256=_wZSFk0d90K9wOBp8Q-UGxshyiJ987JoPiyUBNC6VLk,104
+ torchmonarch_nightly-2025.6.19.dist-info/entry_points.txt,sha256=sqfQ16oZqjEvttUI-uj9BBXIIE6jt05bYFSmy-2hyXI,106
+ torchmonarch_nightly-2025.6.19.dist-info/top_level.txt,sha256=E-ZssZzyM17glpVrh-S9--qJ-w9p2EjuYOuNw9tQ4Eg,33
+ torchmonarch_nightly-2025.6.19.dist-info/RECORD,,