torchmonarch-nightly 2025.6.11__cp310-cp310-manylinux2014_x86_64.whl → 2025.6.13__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,6 @@
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
7
  # pyre-strict
8
- import abc
9
8
 
10
9
  from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
11
10
 
@@ -29,21 +28,6 @@ from monarch._rust_bindings.monarch_hyperactor.shape import ( # @manual=//monar
29
28
  Shape,
30
29
  )
31
30
 
32
-
33
- class Actor(abc.ABC):
34
- @abc.abstractmethod
35
- async def handle(self, mailbox: Mailbox, message: PythonMessage) -> None: ...
36
-
37
- async def handle_cast(
38
- self,
39
- mailbox: Mailbox,
40
- rank: int,
41
- coordinates: list[tuple[str, int]],
42
- message: PythonMessage,
43
- ) -> None:
44
- await self.handle(mailbox, message)
45
-
46
-
47
31
  __all__ = [
48
32
  "init_proc",
49
33
  "Actor",
monarch/_rust_bindings.so CHANGED
Binary file
monarch/_testing.py CHANGED
@@ -10,7 +10,7 @@ import logging
10
10
  import tempfile
11
11
  import time
12
12
  from contextlib import contextmanager, ExitStack
13
- from typing import Callable, Generator, Optional
13
+ from typing import Any, Callable, Dict, Generator, Literal, Optional
14
14
 
15
15
  import monarch_supervisor
16
16
  from monarch.common.client import Client
@@ -18,6 +18,8 @@ from monarch.common.device_mesh import DeviceMesh
18
18
  from monarch.common.invocation import DeviceException, RemoteException
19
19
  from monarch.common.shape import NDSlice
20
20
  from monarch.controller.backend import ProcessBackend
21
+ from monarch.mesh_controller import spawn_tensor_engine
22
+ from monarch.proc_mesh import proc_mesh, ProcMesh
21
23
  from monarch.python_local_mesh import PythonLocalContext
22
24
  from monarch.rust_local_mesh import (
23
25
  local_mesh,
@@ -50,6 +52,7 @@ class TestingContext:
50
52
  self.cleanup = ExitStack()
51
53
  self._py_process_cache = {}
52
54
  self._rust_process_cache = None
55
+ self._proc_mesh_cache: Dict[Any, ProcMesh] = {}
53
56
 
54
57
  @contextmanager
55
58
  def _get_context(self, num_hosts, gpu_per_host):
@@ -75,16 +78,14 @@ class TestingContext:
75
78
 
76
79
  @contextmanager
77
80
  def local_py_device_mesh(
78
- self, num_hosts, gpu_per_host, activate=True
81
+ self,
82
+ num_hosts,
83
+ gpu_per_host,
79
84
  ) -> Generator[DeviceMesh, None, None]:
80
85
  ctx, hosts, processes = self._processes(num_hosts, gpu_per_host)
81
86
  dm = world_mesh(ctx, hosts, gpu_per_host, _processes=processes)
82
87
  try:
83
- if activate:
84
- with dm.activate():
85
- yield dm
86
- else:
87
- yield dm
88
+ yield dm
88
89
  dm.client.shutdown(destroy_pg=False)
89
90
  except Exception:
90
91
  # abnormal exit, so we just make sure we do not try to communicate in destructors,
@@ -97,7 +98,6 @@ class TestingContext:
97
98
  self,
98
99
  num_hosts,
99
100
  gpu_per_host,
100
- activate: bool = True,
101
101
  controller_params=None,
102
102
  ) -> Generator[DeviceMesh, None, None]:
103
103
  # Create a new system and mesh for test.
@@ -115,11 +115,7 @@ class TestingContext:
115
115
  controller_params=controller_params,
116
116
  ) as dm:
117
117
  try:
118
- if activate:
119
- with dm.activate():
120
- yield dm
121
- else:
122
- yield dm
118
+ yield dm
123
119
  dm.exit()
124
120
  except Exception:
125
121
  dm.client._shutdown = True
@@ -129,21 +125,57 @@ class TestingContext:
129
125
  # pyre-ignore: Undefined attribute
130
126
  dm.client.inner._actor.stop()
131
127
 
128
+ @contextmanager
129
+ def local_engine_on_proc_mesh(
130
+ self,
131
+ num_hosts,
132
+ gpu_per_host,
133
+ ) -> Generator[DeviceMesh, None, None]:
134
+ key = (num_hosts, gpu_per_host)
135
+ if key not in self._proc_mesh_cache:
136
+ self._proc_mesh_cache[key] = proc_mesh(
137
+ hosts=num_hosts, gpus=gpu_per_host
138
+ ).get()
139
+
140
+ dm = spawn_tensor_engine(self._proc_mesh_cache[key])
141
+ dm = dm.rename(hosts="host", gpus="gpu")
142
+ try:
143
+ yield dm
144
+ dm.exit()
145
+ except Exception as e:
146
+ # abnormal exit, so we just make sure we do not try to communicate in destructors,
147
+ # but we do not wait for workers to exit since we do not know what state they are in.
148
+ dm.client._shutdown = True
149
+ raise
150
+
132
151
  @contextmanager
133
152
  def local_device_mesh(
134
- self, num_hosts, gpu_per_host, activate=True, rust=False, controller_params=None
153
+ self,
154
+ num_hosts,
155
+ gpu_per_host,
156
+ activate=True,
157
+ backend: Literal["py", "rs", "mesh"] = "py",
158
+ controller_params=None,
135
159
  ) -> Generator[DeviceMesh, None, None]:
136
160
  start = time.time()
137
- if rust:
161
+ if backend == "rs":
138
162
  generator = self.local_rust_device_mesh(
139
- num_hosts, gpu_per_host, activate, controller_params=controller_params
163
+ num_hosts, gpu_per_host, controller_params=controller_params
140
164
  )
165
+ elif backend == "py":
166
+ generator = self.local_py_device_mesh(num_hosts, gpu_per_host)
167
+ elif backend == "mesh":
168
+ generator = self.local_engine_on_proc_mesh(num_hosts, gpu_per_host)
141
169
  else:
142
- generator = self.local_py_device_mesh(num_hosts, gpu_per_host, activate)
170
+ raise ValueError(f"invalid backend: {backend}")
143
171
  with generator as dm:
144
172
  end = time.time()
145
173
  logging.info("initialized mesh in {:.2f}s".format(end - start))
146
- yield dm
174
+ if activate:
175
+ with dm.activate():
176
+ yield dm
177
+ else:
178
+ yield dm
147
179
  start = time.time()
148
180
  end = time.time()
149
181
  logging.info("shutdown mesh in {:.2f}s".format(end - start))
monarch/actor_mesh.py CHANGED
@@ -4,6 +4,8 @@
4
4
  # This source code is licensed under the BSD-style license found in the
5
5
  # LICENSE file in the root directory of this source tree.
6
6
 
7
+ # pyre-unsafe
8
+
7
9
  import asyncio
8
10
  import collections
9
11
  import contextvars
@@ -13,6 +15,7 @@ import inspect
13
15
  import itertools
14
16
  import logging
15
17
  import random
18
+ import sys
16
19
  import traceback
17
20
 
18
21
  from dataclasses import dataclass
@@ -20,6 +23,7 @@ from traceback import extract_tb, StackSummary
20
23
  from typing import (
21
24
  Any,
22
25
  AsyncGenerator,
26
+ Awaitable,
23
27
  Callable,
24
28
  cast,
25
29
  Concatenate,
@@ -34,6 +38,7 @@ from typing import (
34
38
  ParamSpec,
35
39
  Tuple,
36
40
  Type,
41
+ TYPE_CHECKING,
37
42
  TypeVar,
38
43
  )
39
44
 
@@ -54,8 +59,12 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Sh
54
59
 
55
60
  from monarch.common.pickle_flatten import flatten, unflatten
56
61
  from monarch.common.shape import MeshTrait, NDSlice
62
+ from monarch.pdb_wrapper import remote_breakpointhook
63
+
64
+ if TYPE_CHECKING:
65
+ from monarch.debugger import DebugClient
57
66
 
58
- logger = logging.getLogger(__name__)
67
+ logger: logging.Logger = logging.getLogger(__name__)
59
68
 
60
69
  Allocator = monarch.ProcessAllocator | monarch.LocalAllocator
61
70
 
@@ -92,7 +101,7 @@ _context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
92
101
 
93
102
  # this was implemented in python 3.12 as an argument to task
94
103
  # but I have to backport to 3.10/3.11.
95
- def create_eager_task(coro: Coroutine[Any, None, Any]) -> asyncio.Future:
104
+ def create_eager_task(coro: Awaitable[None]) -> asyncio.Future:
96
105
  iter = coro.__await__()
97
106
  try:
98
107
  first_yield = next(iter)
@@ -235,7 +244,7 @@ class Endpoint(Generic[P, R]):
235
244
  self,
236
245
  actor_mesh_ref: _ActorMeshRefImpl,
237
246
  name: str,
238
- impl: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]],
247
+ impl: Callable[Concatenate[Any, P], Awaitable[R]],
239
248
  mailbox: Mailbox,
240
249
  ) -> None:
241
250
  self._actor_mesh = actor_mesh_ref
@@ -267,14 +276,16 @@ class Endpoint(Generic[P, R]):
267
276
  return self.choose(*args, **kwargs)
268
277
 
269
278
  def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
279
+ p: PortId
280
+ r: PortReceiver[R]
270
281
  p, r = port(self)
271
282
  # pyre-ignore
272
283
  send(self, args, kwargs, port=p, rank_in_response=True)
273
284
 
274
- async def process():
275
- results = [None] * len(self._actor_mesh)
285
+ async def process() -> ValueMesh[R]:
286
+ results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
276
287
  for _ in range(len(self._actor_mesh)):
277
- rank, value = await r.recv()
288
+ rank, value = await r.recv() # pyre-fixme[23]
278
289
  results[rank] = value
279
290
  call_shape = Shape(
280
291
  self._actor_mesh._shape.labels,
@@ -312,15 +323,15 @@ class Endpoint(Generic[P, R]):
312
323
  class Accumulator(Generic[P, R, A]):
313
324
  def __init__(
314
325
  self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
315
- ):
316
- self._endpoint = endpoint
317
- self._identity = identity
318
- self._combine = combine
326
+ ) -> None:
327
+ self._endpoint: Endpoint[P, R] = endpoint
328
+ self._identity: A = identity
329
+ self._combine: Callable[[A, R], A] = combine
319
330
 
320
331
  def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
321
- gen = self._endpoint.stream(*args, **kwargs)
332
+ gen: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs)
322
333
 
323
- async def impl():
334
+ async def impl() -> A:
324
335
  value = self._identity
325
336
  async for x in gen:
326
337
  value = self._combine(value, x)
@@ -337,7 +348,7 @@ class ValueMesh(MeshTrait, Generic[R]):
337
348
  def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
338
349
  return ValueMesh(shape, self._values)
339
350
 
340
- def item(self, **kwargs):
351
+ def item(self, **kwargs) -> R:
341
352
  coordinates = [kwargs.pop(label) for label in self._labels]
342
353
  if kwargs:
343
354
  raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}")
@@ -348,7 +359,7 @@ class ValueMesh(MeshTrait, Generic[R]):
348
359
  for rank in self._shape.ranks():
349
360
  yield Point(rank, self._shape), self._values[rank]
350
361
 
351
- def __len__(self):
362
+ def __len__(self) -> int:
352
363
  return len(self._shape)
353
364
 
354
365
  @property
@@ -381,7 +392,7 @@ def send(
381
392
 
382
393
 
383
394
  class EndpointProperty(Generic[P, R]):
384
- def __init__(self, method: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]]):
395
+ def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None:
385
396
  self._method = method
386
397
 
387
398
  def __get__(self, instance, owner) -> Endpoint[P, R]:
@@ -392,7 +403,7 @@ class EndpointProperty(Generic[P, R]):
392
403
 
393
404
 
394
405
  def endpoint(
395
- method: Callable[Concatenate[Any, P], Coroutine[Any, Any, R]],
406
+ method: Callable[Concatenate[Any, P], Awaitable[R]],
396
407
  ) -> EndpointProperty[P, R]:
397
408
  return EndpointProperty(method)
398
409
 
@@ -415,7 +426,9 @@ class Port:
415
426
  # advanced lower-level API for sending messages. This is intentionally
416
427
  # not part of the Endpoint API because the way it accepts arguments
417
428
  # and handles concerns is different.
418
- def port(endpoint: Endpoint[P, R], once=False) -> Tuple["PortId", "PortReceiver[R]"]:
429
+ def port(
430
+ endpoint: Endpoint[P, R], once: bool = False
431
+ ) -> Tuple["PortId", "PortReceiver[R]"]:
419
432
  handle, receiver = (
420
433
  endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port()
421
434
  )
@@ -428,9 +441,9 @@ class PortReceiver(Generic[R]):
428
441
  self,
429
442
  mailbox: Mailbox,
430
443
  receiver: HyPortReceiver | OncePortReceiver,
431
- ):
432
- self._mailbox = mailbox
433
- self._receiver = receiver
444
+ ) -> None:
445
+ self._mailbox: Mailbox = mailbox
446
+ self._receiver: HyPortReceiver | OncePortReceiver = receiver
434
447
 
435
448
  async def _recv(self) -> R:
436
449
  return self._process(await self._receiver.recv())
@@ -438,7 +451,7 @@ class PortReceiver(Generic[R]):
438
451
  def _blocking_recv(self) -> R:
439
452
  return self._process(self._receiver.blocking_recv())
440
453
 
441
- def _process(self, msg: PythonMessage):
454
+ def _process(self, msg: PythonMessage) -> R:
442
455
  # TODO: Try to do something more structured than a cast here
443
456
  payload = cast(R, _unpickle(msg.message, self._mailbox))
444
457
  if msg.method == "result":
@@ -485,7 +498,9 @@ class _Actor:
485
498
  else None
486
499
  )
487
500
  try:
488
- ctx = MonarchContext(mailbox, mailbox.actor_id.proc_id, Point(rank, shape))
501
+ ctx: MonarchContext = MonarchContext(
502
+ mailbox, mailbox.actor_id.proc_id, Point(rank, shape)
503
+ )
489
504
  _context.set(ctx)
490
505
 
491
506
  args, kwargs = _unpickle(message.message, mailbox)
@@ -510,7 +525,14 @@ class _Actor:
510
525
  enter_span(
511
526
  the_method.__module__, message.method, str(ctx.mailbox.actor_id)
512
527
  )
513
- result = await the_method(self.instance, *args, **kwargs)
528
+ try:
529
+ result = await the_method(self.instance, *args, **kwargs)
530
+ except Exception as e:
531
+ logging.critical(
532
+ "Unahndled exception in actor endpoint",
533
+ exc_info=e,
534
+ )
535
+ raise e
514
536
  exit_span()
515
537
  return result
516
538
 
@@ -532,14 +554,19 @@ class _Actor:
532
554
  async def run_async(
533
555
  self,
534
556
  ctx: MonarchContext,
535
- coroutine: Coroutine[Any, None, Any],
557
+ coroutine: Awaitable[None],
536
558
  ) -> None:
537
559
  _context.set(ctx)
538
560
  if self.complete_task is None:
539
561
  self.complete_task = asyncio.create_task(self._complete())
540
562
  await self.active_requests.put(create_eager_task(coroutine))
541
563
 
542
- async def run_task(self, port, coroutine, panic_flag):
564
+ async def run_task(
565
+ self,
566
+ port: Port | None,
567
+ coroutine: Awaitable[Any],
568
+ panic_flag: PanicFlag,
569
+ ) -> None:
543
570
  try:
544
571
  result = await coroutine
545
572
  if port is not None:
@@ -610,15 +637,28 @@ class Actor(MeshTrait):
610
637
  "actor implementations are not meshes, but we can't convince the typechecker of it..."
611
638
  )
612
639
 
640
+ @endpoint
641
+ async def _set_debug_client(self, client: "DebugClient") -> None:
642
+ point = MonarchContext.get().point
643
+ # For some reason, using a lambda instead of functools.partial
644
+ # confuses the pdb wrapper implementation.
645
+ sys.breakpointhook = functools.partial( # pyre-ignore
646
+ remote_breakpointhook,
647
+ point.rank,
648
+ point.shape.coordinates(point.rank),
649
+ MonarchContext.get().mailbox.actor_id,
650
+ client,
651
+ )
652
+
613
653
 
614
654
  class ActorMeshRef(MeshTrait):
615
655
  def __init__(
616
656
  self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
617
657
  ) -> None:
618
- self.__name__ = Class.__name__
619
- self._class = Class
620
- self._actor_mesh_ref = actor_mesh_ref
621
- self._mailbox = mailbox
658
+ self.__name__: str = Class.__name__
659
+ self._class: Type[T] = Class
660
+ self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref
661
+ self._mailbox: Mailbox = mailbox
622
662
  for attr_name in dir(self._class):
623
663
  attr_value = getattr(self._class, attr_name, None)
624
664
  if isinstance(attr_value, EndpointProperty):
@@ -659,7 +699,11 @@ class ActorMeshRef(MeshTrait):
659
699
  f"'{self.__class__.__name__}' object has no attribute '{name}'"
660
700
  )
661
701
 
662
- def _create(self, args: Iterable[Any], kwargs: Dict[str, Any]) -> None:
702
+ def _create(
703
+ self,
704
+ args: Iterable[Any],
705
+ kwargs: Dict[str, Any],
706
+ ) -> None:
663
707
  async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
664
708
  return None
665
709
 
monarch/bootstrap_main.py CHANGED
@@ -30,28 +30,9 @@ def invoke_main():
30
30
  # behavior of std out as if it were a terminal.
31
31
  sys.stdout.reconfigure(line_buffering=True)
32
32
  global bootstrap_main
33
- from monarch._rust_bindings.hyperactor_extension.telemetry import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
34
- forward_to_tracing,
35
- )
36
33
 
37
34
  # TODO: figure out what from worker_main.py we should reproduce here.
38
-
39
- class TracingForwarder(logging.Handler):
40
- def emit(self, record: logging.LogRecord) -> None:
41
- try:
42
- forward_to_tracing(
43
- record.getMessage(),
44
- record.filename or "",
45
- record.lineno or 0,
46
- record.levelno,
47
- )
48
- except AttributeError:
49
- forward_to_tracing(
50
- record.__str__(),
51
- record.filename or "",
52
- record.lineno or 0,
53
- record.levelno,
54
- )
35
+ from monarch.telemetry import TracingForwarder
55
36
 
56
37
  if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
57
38
  raise RuntimeError("Error during bootstrap for testing")
@@ -16,11 +16,6 @@ def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
16
16
  torch.manual_seed(seed ^ process_idx)
17
17
 
18
18
 
19
- @remote(propagate=lambda: 0)
20
- def initial_seed_remote() -> int:
21
- return torch.initial_seed()
22
-
23
-
24
19
  @remote(propagate=lambda: torch.zeros(1))
25
20
  def get_rng_state_remote() -> torch.Tensor:
26
21
  return torch.get_rng_state()
@@ -67,3 +62,7 @@ def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
67
62
  @remote(propagate="inspect")
68
63
  def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
69
64
  torch.cuda.set_rng_state_all(states)
65
+
66
+
67
+ # initial_seed may sometimes return a uint64 which currently can't be unwrapped by the framework
68
+ # def initial_seed_remote() -> int: ...
monarch/common/client.py CHANGED
@@ -103,6 +103,13 @@ class Client:
103
103
  # workers.
104
104
  self.last_processed_seq = -1
105
105
 
106
+ # an error that we have received but know for certain has not
107
+ # been propagated to a future. This will be reported on shutdown
108
+ # to avoid hiding the error. This is best effort: we only keep
109
+ # the error until the point that a future is dependent on
110
+ # _any_ error, not particularly the tracked one.
111
+ self._pending_shutdown_error = None
112
+
106
113
  self.recorder = Recorder()
107
114
 
108
115
  self.pending_results: Dict[
@@ -174,6 +181,8 @@ class Client:
174
181
  destroy_pg: bool = True,
175
182
  error_reason: Optional[RemoteException | DeviceException | Exception] = None,
176
183
  ) -> None:
184
+ if self.has_shutdown:
185
+ return
177
186
  logger.info("shutting down the client gracefully")
178
187
 
179
188
  atexit.unregister(self._atexit)
@@ -302,7 +311,8 @@ class Client:
302
311
  self.last_processed_seq = max(self.last_processed_seq, seq)
303
312
 
304
313
  if error is not None:
305
- logging.error("Received error for seq %s: %s", seq, error)
314
+ logging.info("Received error for seq %s: %s", seq, error)
315
+ self._pending_shutdown_error = error
306
316
  # We should not have set result if we have an error.
307
317
  assert result is None
308
318
  if not isinstance(error, RemoteException):
@@ -326,15 +336,17 @@ class Client:
326
336
 
327
337
  fut, _ = self.pending_results[seq]
328
338
  if fut is not None:
329
- fut._set_result(result if error is None else error)
339
+ if error is None:
340
+ fut._set_result(result)
341
+ else:
342
+ fut._set_result(error)
343
+ self._pending_shutdown_error = None
330
344
  elif result is not None:
331
345
  logger.debug(f"{seq}: unused result {result}")
332
346
  elif error is not None:
333
347
  # errors get reported as results even if they
334
348
  # do not have futures attached.
335
- logger.warning(
336
- f"Error encountered for this instruction {seq}. Proceeding forward because error is unused and unhandled. Error details:\n{error}."
337
- )
349
+ pass
338
350
 
339
351
  # We can safely delete the seq as tracebacks have been saved to the remote failure itself.
340
352
  del self.pending_results[seq]
monarch/common/stream.py CHANGED
@@ -82,6 +82,9 @@ class StreamRef(Referenceable):
82
82
  messages.CreateStream(self, self.default),
83
83
  )
84
84
 
85
+ def __repr__(self):
86
+ return f"<StreamRef {repr(self.name)} {self.ref}>"
87
+
85
88
  def delete_ref(self, ref):
86
89
  client = self.client()
87
90
  if client is not None and not client._shutdown: