torchmonarch-nightly 2025.6.11__cp310-cp310-manylinux2014_x86_64.whl → 2025.6.13__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/_monarch/hyperactor/__init__.py +0 -16
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +50 -18
- monarch/actor_mesh.py +74 -30
- monarch/bootstrap_main.py +1 -20
- monarch/builtins/random.py +4 -5
- monarch/common/client.py +17 -5
- monarch/common/stream.py +3 -0
- monarch/debugger.py +377 -0
- monarch/mesh_controller.py +72 -15
- monarch/monarch_controller +0 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +9 -5
- monarch/telemetry.py +19 -0
- tests/test_allocator.py +3 -3
- tests/test_coalescing.py +1 -1
- tests/test_controller.py +12 -2
- tests/test_python_actors.py +150 -0
- tests/test_remote_functions.py +1 -1
- {torchmonarch_nightly-2025.6.11.dist-info → torchmonarch_nightly-2025.6.13.dist-info}/METADATA +1 -1
- {torchmonarch_nightly-2025.6.11.dist-info → torchmonarch_nightly-2025.6.13.dist-info}/RECORD +25 -22
- {torchmonarch_nightly-2025.6.11.dist-info → torchmonarch_nightly-2025.6.13.dist-info}/WHEEL +0 -0
- {torchmonarch_nightly-2025.6.11.dist-info → torchmonarch_nightly-2025.6.13.dist-info}/entry_points.txt +0 -0
- {torchmonarch_nightly-2025.6.11.dist-info → torchmonarch_nightly-2025.6.13.dist-info}/licenses/LICENSE +0 -0
- {torchmonarch_nightly-2025.6.11.dist-info → torchmonarch_nightly-2025.6.13.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,6 @@
|
|
5
5
|
# LICENSE file in the root directory of this source tree.
|
6
6
|
|
7
7
|
# pyre-strict
|
8
|
-
import abc
|
9
8
|
|
10
9
|
from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage
|
11
10
|
|
@@ -29,21 +28,6 @@ from monarch._rust_bindings.monarch_hyperactor.shape import ( # @manual=//monar
|
|
29
28
|
Shape,
|
30
29
|
)
|
31
30
|
|
32
|
-
|
33
|
-
class Actor(abc.ABC):
|
34
|
-
@abc.abstractmethod
|
35
|
-
async def handle(self, mailbox: Mailbox, message: PythonMessage) -> None: ...
|
36
|
-
|
37
|
-
async def handle_cast(
|
38
|
-
self,
|
39
|
-
mailbox: Mailbox,
|
40
|
-
rank: int,
|
41
|
-
coordinates: list[tuple[str, int]],
|
42
|
-
message: PythonMessage,
|
43
|
-
) -> None:
|
44
|
-
await self.handle(mailbox, message)
|
45
|
-
|
46
|
-
|
47
31
|
__all__ = [
|
48
32
|
"init_proc",
|
49
33
|
"Actor",
|
monarch/_rust_bindings.so
CHANGED
Binary file
|
monarch/_testing.py
CHANGED
@@ -10,7 +10,7 @@ import logging
|
|
10
10
|
import tempfile
|
11
11
|
import time
|
12
12
|
from contextlib import contextmanager, ExitStack
|
13
|
-
from typing import Callable, Generator, Optional
|
13
|
+
from typing import Any, Callable, Dict, Generator, Literal, Optional
|
14
14
|
|
15
15
|
import monarch_supervisor
|
16
16
|
from monarch.common.client import Client
|
@@ -18,6 +18,8 @@ from monarch.common.device_mesh import DeviceMesh
|
|
18
18
|
from monarch.common.invocation import DeviceException, RemoteException
|
19
19
|
from monarch.common.shape import NDSlice
|
20
20
|
from monarch.controller.backend import ProcessBackend
|
21
|
+
from monarch.mesh_controller import spawn_tensor_engine
|
22
|
+
from monarch.proc_mesh import proc_mesh, ProcMesh
|
21
23
|
from monarch.python_local_mesh import PythonLocalContext
|
22
24
|
from monarch.rust_local_mesh import (
|
23
25
|
local_mesh,
|
@@ -50,6 +52,7 @@ class TestingContext:
|
|
50
52
|
self.cleanup = ExitStack()
|
51
53
|
self._py_process_cache = {}
|
52
54
|
self._rust_process_cache = None
|
55
|
+
self._proc_mesh_cache: Dict[Any, ProcMesh] = {}
|
53
56
|
|
54
57
|
@contextmanager
|
55
58
|
def _get_context(self, num_hosts, gpu_per_host):
|
@@ -75,16 +78,14 @@ class TestingContext:
|
|
75
78
|
|
76
79
|
@contextmanager
|
77
80
|
def local_py_device_mesh(
|
78
|
-
self,
|
81
|
+
self,
|
82
|
+
num_hosts,
|
83
|
+
gpu_per_host,
|
79
84
|
) -> Generator[DeviceMesh, None, None]:
|
80
85
|
ctx, hosts, processes = self._processes(num_hosts, gpu_per_host)
|
81
86
|
dm = world_mesh(ctx, hosts, gpu_per_host, _processes=processes)
|
82
87
|
try:
|
83
|
-
|
84
|
-
with dm.activate():
|
85
|
-
yield dm
|
86
|
-
else:
|
87
|
-
yield dm
|
88
|
+
yield dm
|
88
89
|
dm.client.shutdown(destroy_pg=False)
|
89
90
|
except Exception:
|
90
91
|
# abnormal exit, so we just make sure we do not try to communicate in destructors,
|
@@ -97,7 +98,6 @@ class TestingContext:
|
|
97
98
|
self,
|
98
99
|
num_hosts,
|
99
100
|
gpu_per_host,
|
100
|
-
activate: bool = True,
|
101
101
|
controller_params=None,
|
102
102
|
) -> Generator[DeviceMesh, None, None]:
|
103
103
|
# Create a new system and mesh for test.
|
@@ -115,11 +115,7 @@ class TestingContext:
|
|
115
115
|
controller_params=controller_params,
|
116
116
|
) as dm:
|
117
117
|
try:
|
118
|
-
|
119
|
-
with dm.activate():
|
120
|
-
yield dm
|
121
|
-
else:
|
122
|
-
yield dm
|
118
|
+
yield dm
|
123
119
|
dm.exit()
|
124
120
|
except Exception:
|
125
121
|
dm.client._shutdown = True
|
@@ -129,21 +125,57 @@ class TestingContext:
|
|
129
125
|
# pyre-ignore: Undefined attribute
|
130
126
|
dm.client.inner._actor.stop()
|
131
127
|
|
128
|
+
@contextmanager
|
129
|
+
def local_engine_on_proc_mesh(
|
130
|
+
self,
|
131
|
+
num_hosts,
|
132
|
+
gpu_per_host,
|
133
|
+
) -> Generator[DeviceMesh, None, None]:
|
134
|
+
key = (num_hosts, gpu_per_host)
|
135
|
+
if key not in self._proc_mesh_cache:
|
136
|
+
self._proc_mesh_cache[key] = proc_mesh(
|
137
|
+
hosts=num_hosts, gpus=gpu_per_host
|
138
|
+
).get()
|
139
|
+
|
140
|
+
dm = spawn_tensor_engine(self._proc_mesh_cache[key])
|
141
|
+
dm = dm.rename(hosts="host", gpus="gpu")
|
142
|
+
try:
|
143
|
+
yield dm
|
144
|
+
dm.exit()
|
145
|
+
except Exception as e:
|
146
|
+
# abnormal exit, so we just make sure we do not try to communicate in destructors,
|
147
|
+
# but we do not wait for workers to exit since we do not know what state they are in.
|
148
|
+
dm.client._shutdown = True
|
149
|
+
raise
|
150
|
+
|
132
151
|
@contextmanager
|
133
152
|
def local_device_mesh(
|
134
|
-
self,
|
153
|
+
self,
|
154
|
+
num_hosts,
|
155
|
+
gpu_per_host,
|
156
|
+
activate=True,
|
157
|
+
backend: Literal["py", "rs", "mesh"] = "py",
|
158
|
+
controller_params=None,
|
135
159
|
) -> Generator[DeviceMesh, None, None]:
|
136
160
|
start = time.time()
|
137
|
-
if
|
161
|
+
if backend == "rs":
|
138
162
|
generator = self.local_rust_device_mesh(
|
139
|
-
num_hosts, gpu_per_host,
|
163
|
+
num_hosts, gpu_per_host, controller_params=controller_params
|
140
164
|
)
|
165
|
+
elif backend == "py":
|
166
|
+
generator = self.local_py_device_mesh(num_hosts, gpu_per_host)
|
167
|
+
elif backend == "mesh":
|
168
|
+
generator = self.local_engine_on_proc_mesh(num_hosts, gpu_per_host)
|
141
169
|
else:
|
142
|
-
|
170
|
+
raise ValueError(f"invalid backend: {backend}")
|
143
171
|
with generator as dm:
|
144
172
|
end = time.time()
|
145
173
|
logging.info("initialized mesh in {:.2f}s".format(end - start))
|
146
|
-
|
174
|
+
if activate:
|
175
|
+
with dm.activate():
|
176
|
+
yield dm
|
177
|
+
else:
|
178
|
+
yield dm
|
147
179
|
start = time.time()
|
148
180
|
end = time.time()
|
149
181
|
logging.info("shutdown mesh in {:.2f}s".format(end - start))
|
monarch/actor_mesh.py
CHANGED
@@ -4,6 +4,8 @@
|
|
4
4
|
# This source code is licensed under the BSD-style license found in the
|
5
5
|
# LICENSE file in the root directory of this source tree.
|
6
6
|
|
7
|
+
# pyre-unsafe
|
8
|
+
|
7
9
|
import asyncio
|
8
10
|
import collections
|
9
11
|
import contextvars
|
@@ -13,6 +15,7 @@ import inspect
|
|
13
15
|
import itertools
|
14
16
|
import logging
|
15
17
|
import random
|
18
|
+
import sys
|
16
19
|
import traceback
|
17
20
|
|
18
21
|
from dataclasses import dataclass
|
@@ -20,6 +23,7 @@ from traceback import extract_tb, StackSummary
|
|
20
23
|
from typing import (
|
21
24
|
Any,
|
22
25
|
AsyncGenerator,
|
26
|
+
Awaitable,
|
23
27
|
Callable,
|
24
28
|
cast,
|
25
29
|
Concatenate,
|
@@ -34,6 +38,7 @@ from typing import (
|
|
34
38
|
ParamSpec,
|
35
39
|
Tuple,
|
36
40
|
Type,
|
41
|
+
TYPE_CHECKING,
|
37
42
|
TypeVar,
|
38
43
|
)
|
39
44
|
|
@@ -54,8 +59,12 @@ from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Sh
|
|
54
59
|
|
55
60
|
from monarch.common.pickle_flatten import flatten, unflatten
|
56
61
|
from monarch.common.shape import MeshTrait, NDSlice
|
62
|
+
from monarch.pdb_wrapper import remote_breakpointhook
|
63
|
+
|
64
|
+
if TYPE_CHECKING:
|
65
|
+
from monarch.debugger import DebugClient
|
57
66
|
|
58
|
-
logger = logging.getLogger(__name__)
|
67
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
59
68
|
|
60
69
|
Allocator = monarch.ProcessAllocator | monarch.LocalAllocator
|
61
70
|
|
@@ -92,7 +101,7 @@ _context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar(
|
|
92
101
|
|
93
102
|
# this was implemented in python 3.12 as an argument to task
|
94
103
|
# but I have to backport to 3.10/3.11.
|
95
|
-
def create_eager_task(coro:
|
104
|
+
def create_eager_task(coro: Awaitable[None]) -> asyncio.Future:
|
96
105
|
iter = coro.__await__()
|
97
106
|
try:
|
98
107
|
first_yield = next(iter)
|
@@ -235,7 +244,7 @@ class Endpoint(Generic[P, R]):
|
|
235
244
|
self,
|
236
245
|
actor_mesh_ref: _ActorMeshRefImpl,
|
237
246
|
name: str,
|
238
|
-
impl: Callable[Concatenate[Any, P],
|
247
|
+
impl: Callable[Concatenate[Any, P], Awaitable[R]],
|
239
248
|
mailbox: Mailbox,
|
240
249
|
) -> None:
|
241
250
|
self._actor_mesh = actor_mesh_ref
|
@@ -267,14 +276,16 @@ class Endpoint(Generic[P, R]):
|
|
267
276
|
return self.choose(*args, **kwargs)
|
268
277
|
|
269
278
|
def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
|
279
|
+
p: PortId
|
280
|
+
r: PortReceiver[R]
|
270
281
|
p, r = port(self)
|
271
282
|
# pyre-ignore
|
272
283
|
send(self, args, kwargs, port=p, rank_in_response=True)
|
273
284
|
|
274
|
-
async def process():
|
275
|
-
results = [None] * len(self._actor_mesh)
|
285
|
+
async def process() -> ValueMesh[R]:
|
286
|
+
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
|
276
287
|
for _ in range(len(self._actor_mesh)):
|
277
|
-
rank, value = await r.recv()
|
288
|
+
rank, value = await r.recv() # pyre-fixme[23]
|
278
289
|
results[rank] = value
|
279
290
|
call_shape = Shape(
|
280
291
|
self._actor_mesh._shape.labels,
|
@@ -312,15 +323,15 @@ class Endpoint(Generic[P, R]):
|
|
312
323
|
class Accumulator(Generic[P, R, A]):
|
313
324
|
def __init__(
|
314
325
|
self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
|
315
|
-
):
|
316
|
-
self._endpoint = endpoint
|
317
|
-
self._identity = identity
|
318
|
-
self._combine = combine
|
326
|
+
) -> None:
|
327
|
+
self._endpoint: Endpoint[P, R] = endpoint
|
328
|
+
self._identity: A = identity
|
329
|
+
self._combine: Callable[[A, R], A] = combine
|
319
330
|
|
320
331
|
def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]":
|
321
|
-
gen = self._endpoint.stream(*args, **kwargs)
|
332
|
+
gen: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs)
|
322
333
|
|
323
|
-
async def impl():
|
334
|
+
async def impl() -> A:
|
324
335
|
value = self._identity
|
325
336
|
async for x in gen:
|
326
337
|
value = self._combine(value, x)
|
@@ -337,7 +348,7 @@ class ValueMesh(MeshTrait, Generic[R]):
|
|
337
348
|
def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]":
|
338
349
|
return ValueMesh(shape, self._values)
|
339
350
|
|
340
|
-
def item(self, **kwargs):
|
351
|
+
def item(self, **kwargs) -> R:
|
341
352
|
coordinates = [kwargs.pop(label) for label in self._labels]
|
342
353
|
if kwargs:
|
343
354
|
raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}")
|
@@ -348,7 +359,7 @@ class ValueMesh(MeshTrait, Generic[R]):
|
|
348
359
|
for rank in self._shape.ranks():
|
349
360
|
yield Point(rank, self._shape), self._values[rank]
|
350
361
|
|
351
|
-
def __len__(self):
|
362
|
+
def __len__(self) -> int:
|
352
363
|
return len(self._shape)
|
353
364
|
|
354
365
|
@property
|
@@ -381,7 +392,7 @@ def send(
|
|
381
392
|
|
382
393
|
|
383
394
|
class EndpointProperty(Generic[P, R]):
|
384
|
-
def __init__(self, method: Callable[Concatenate[Any, P],
|
395
|
+
def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None:
|
385
396
|
self._method = method
|
386
397
|
|
387
398
|
def __get__(self, instance, owner) -> Endpoint[P, R]:
|
@@ -392,7 +403,7 @@ class EndpointProperty(Generic[P, R]):
|
|
392
403
|
|
393
404
|
|
394
405
|
def endpoint(
|
395
|
-
method: Callable[Concatenate[Any, P],
|
406
|
+
method: Callable[Concatenate[Any, P], Awaitable[R]],
|
396
407
|
) -> EndpointProperty[P, R]:
|
397
408
|
return EndpointProperty(method)
|
398
409
|
|
@@ -415,7 +426,9 @@ class Port:
|
|
415
426
|
# advanced lower-level API for sending messages. This is intentionally
|
416
427
|
# not part of the Endpoint API because the way it accepts arguments
|
417
428
|
# and handles concerns is different.
|
418
|
-
def port(
|
429
|
+
def port(
|
430
|
+
endpoint: Endpoint[P, R], once: bool = False
|
431
|
+
) -> Tuple["PortId", "PortReceiver[R]"]:
|
419
432
|
handle, receiver = (
|
420
433
|
endpoint._mailbox.open_once_port() if once else endpoint._mailbox.open_port()
|
421
434
|
)
|
@@ -428,9 +441,9 @@ class PortReceiver(Generic[R]):
|
|
428
441
|
self,
|
429
442
|
mailbox: Mailbox,
|
430
443
|
receiver: HyPortReceiver | OncePortReceiver,
|
431
|
-
):
|
432
|
-
self._mailbox = mailbox
|
433
|
-
self._receiver = receiver
|
444
|
+
) -> None:
|
445
|
+
self._mailbox: Mailbox = mailbox
|
446
|
+
self._receiver: HyPortReceiver | OncePortReceiver = receiver
|
434
447
|
|
435
448
|
async def _recv(self) -> R:
|
436
449
|
return self._process(await self._receiver.recv())
|
@@ -438,7 +451,7 @@ class PortReceiver(Generic[R]):
|
|
438
451
|
def _blocking_recv(self) -> R:
|
439
452
|
return self._process(self._receiver.blocking_recv())
|
440
453
|
|
441
|
-
def _process(self, msg: PythonMessage):
|
454
|
+
def _process(self, msg: PythonMessage) -> R:
|
442
455
|
# TODO: Try to do something more structured than a cast here
|
443
456
|
payload = cast(R, _unpickle(msg.message, self._mailbox))
|
444
457
|
if msg.method == "result":
|
@@ -485,7 +498,9 @@ class _Actor:
|
|
485
498
|
else None
|
486
499
|
)
|
487
500
|
try:
|
488
|
-
ctx = MonarchContext(
|
501
|
+
ctx: MonarchContext = MonarchContext(
|
502
|
+
mailbox, mailbox.actor_id.proc_id, Point(rank, shape)
|
503
|
+
)
|
489
504
|
_context.set(ctx)
|
490
505
|
|
491
506
|
args, kwargs = _unpickle(message.message, mailbox)
|
@@ -510,7 +525,14 @@ class _Actor:
|
|
510
525
|
enter_span(
|
511
526
|
the_method.__module__, message.method, str(ctx.mailbox.actor_id)
|
512
527
|
)
|
513
|
-
|
528
|
+
try:
|
529
|
+
result = await the_method(self.instance, *args, **kwargs)
|
530
|
+
except Exception as e:
|
531
|
+
logging.critical(
|
532
|
+
"Unahndled exception in actor endpoint",
|
533
|
+
exc_info=e,
|
534
|
+
)
|
535
|
+
raise e
|
514
536
|
exit_span()
|
515
537
|
return result
|
516
538
|
|
@@ -532,14 +554,19 @@ class _Actor:
|
|
532
554
|
async def run_async(
|
533
555
|
self,
|
534
556
|
ctx: MonarchContext,
|
535
|
-
coroutine:
|
557
|
+
coroutine: Awaitable[None],
|
536
558
|
) -> None:
|
537
559
|
_context.set(ctx)
|
538
560
|
if self.complete_task is None:
|
539
561
|
self.complete_task = asyncio.create_task(self._complete())
|
540
562
|
await self.active_requests.put(create_eager_task(coroutine))
|
541
563
|
|
542
|
-
async def run_task(
|
564
|
+
async def run_task(
|
565
|
+
self,
|
566
|
+
port: Port | None,
|
567
|
+
coroutine: Awaitable[Any],
|
568
|
+
panic_flag: PanicFlag,
|
569
|
+
) -> None:
|
543
570
|
try:
|
544
571
|
result = await coroutine
|
545
572
|
if port is not None:
|
@@ -610,15 +637,28 @@ class Actor(MeshTrait):
|
|
610
637
|
"actor implementations are not meshes, but we can't convince the typechecker of it..."
|
611
638
|
)
|
612
639
|
|
640
|
+
@endpoint
|
641
|
+
async def _set_debug_client(self, client: "DebugClient") -> None:
|
642
|
+
point = MonarchContext.get().point
|
643
|
+
# For some reason, using a lambda instead of functools.partial
|
644
|
+
# confuses the pdb wrapper implementation.
|
645
|
+
sys.breakpointhook = functools.partial( # pyre-ignore
|
646
|
+
remote_breakpointhook,
|
647
|
+
point.rank,
|
648
|
+
point.shape.coordinates(point.rank),
|
649
|
+
MonarchContext.get().mailbox.actor_id,
|
650
|
+
client,
|
651
|
+
)
|
652
|
+
|
613
653
|
|
614
654
|
class ActorMeshRef(MeshTrait):
|
615
655
|
def __init__(
|
616
656
|
self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox
|
617
657
|
) -> None:
|
618
|
-
self.__name__ = Class.__name__
|
619
|
-
self._class = Class
|
620
|
-
self._actor_mesh_ref = actor_mesh_ref
|
621
|
-
self._mailbox = mailbox
|
658
|
+
self.__name__: str = Class.__name__
|
659
|
+
self._class: Type[T] = Class
|
660
|
+
self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref
|
661
|
+
self._mailbox: Mailbox = mailbox
|
622
662
|
for attr_name in dir(self._class):
|
623
663
|
attr_value = getattr(self._class, attr_name, None)
|
624
664
|
if isinstance(attr_value, EndpointProperty):
|
@@ -659,7 +699,11 @@ class ActorMeshRef(MeshTrait):
|
|
659
699
|
f"'{self.__class__.__name__}' object has no attribute '{name}'"
|
660
700
|
)
|
661
701
|
|
662
|
-
def _create(
|
702
|
+
def _create(
|
703
|
+
self,
|
704
|
+
args: Iterable[Any],
|
705
|
+
kwargs: Dict[str, Any],
|
706
|
+
) -> None:
|
663
707
|
async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
|
664
708
|
return None
|
665
709
|
|
monarch/bootstrap_main.py
CHANGED
@@ -30,28 +30,9 @@ def invoke_main():
|
|
30
30
|
# behavior of std out as if it were a terminal.
|
31
31
|
sys.stdout.reconfigure(line_buffering=True)
|
32
32
|
global bootstrap_main
|
33
|
-
from monarch._rust_bindings.hyperactor_extension.telemetry import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension
|
34
|
-
forward_to_tracing,
|
35
|
-
)
|
36
33
|
|
37
34
|
# TODO: figure out what from worker_main.py we should reproduce here.
|
38
|
-
|
39
|
-
class TracingForwarder(logging.Handler):
|
40
|
-
def emit(self, record: logging.LogRecord) -> None:
|
41
|
-
try:
|
42
|
-
forward_to_tracing(
|
43
|
-
record.getMessage(),
|
44
|
-
record.filename or "",
|
45
|
-
record.lineno or 0,
|
46
|
-
record.levelno,
|
47
|
-
)
|
48
|
-
except AttributeError:
|
49
|
-
forward_to_tracing(
|
50
|
-
record.__str__(),
|
51
|
-
record.filename or "",
|
52
|
-
record.lineno or 0,
|
53
|
-
record.levelno,
|
54
|
-
)
|
35
|
+
from monarch.telemetry import TracingForwarder
|
55
36
|
|
56
37
|
if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1":
|
57
38
|
raise RuntimeError("Error during bootstrap for testing")
|
monarch/builtins/random.py
CHANGED
@@ -16,11 +16,6 @@ def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
|
|
16
16
|
torch.manual_seed(seed ^ process_idx)
|
17
17
|
|
18
18
|
|
19
|
-
@remote(propagate=lambda: 0)
|
20
|
-
def initial_seed_remote() -> int:
|
21
|
-
return torch.initial_seed()
|
22
|
-
|
23
|
-
|
24
19
|
@remote(propagate=lambda: torch.zeros(1))
|
25
20
|
def get_rng_state_remote() -> torch.Tensor:
|
26
21
|
return torch.get_rng_state()
|
@@ -67,3 +62,7 @@ def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
|
|
67
62
|
@remote(propagate="inspect")
|
68
63
|
def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
|
69
64
|
torch.cuda.set_rng_state_all(states)
|
65
|
+
|
66
|
+
|
67
|
+
# initial_seed may sometimes return a uint64 which currently can't be unwrapped by the framework
|
68
|
+
# def initial_seed_remote() -> int: ...
|
monarch/common/client.py
CHANGED
@@ -103,6 +103,13 @@ class Client:
|
|
103
103
|
# workers.
|
104
104
|
self.last_processed_seq = -1
|
105
105
|
|
106
|
+
# an error that we have received but know for certain has not
|
107
|
+
# been propagated to a future. This will be reported on shutdown
|
108
|
+
# to avoid hiding the error. This is best effort: we only keep
|
109
|
+
# the error until the point that a future is dependent on
|
110
|
+
# _any_ error, not particularly the tracked one.
|
111
|
+
self._pending_shutdown_error = None
|
112
|
+
|
106
113
|
self.recorder = Recorder()
|
107
114
|
|
108
115
|
self.pending_results: Dict[
|
@@ -174,6 +181,8 @@ class Client:
|
|
174
181
|
destroy_pg: bool = True,
|
175
182
|
error_reason: Optional[RemoteException | DeviceException | Exception] = None,
|
176
183
|
) -> None:
|
184
|
+
if self.has_shutdown:
|
185
|
+
return
|
177
186
|
logger.info("shutting down the client gracefully")
|
178
187
|
|
179
188
|
atexit.unregister(self._atexit)
|
@@ -302,7 +311,8 @@ class Client:
|
|
302
311
|
self.last_processed_seq = max(self.last_processed_seq, seq)
|
303
312
|
|
304
313
|
if error is not None:
|
305
|
-
logging.
|
314
|
+
logging.info("Received error for seq %s: %s", seq, error)
|
315
|
+
self._pending_shutdown_error = error
|
306
316
|
# We should not have set result if we have an error.
|
307
317
|
assert result is None
|
308
318
|
if not isinstance(error, RemoteException):
|
@@ -326,15 +336,17 @@ class Client:
|
|
326
336
|
|
327
337
|
fut, _ = self.pending_results[seq]
|
328
338
|
if fut is not None:
|
329
|
-
|
339
|
+
if error is None:
|
340
|
+
fut._set_result(result)
|
341
|
+
else:
|
342
|
+
fut._set_result(error)
|
343
|
+
self._pending_shutdown_error = None
|
330
344
|
elif result is not None:
|
331
345
|
logger.debug(f"{seq}: unused result {result}")
|
332
346
|
elif error is not None:
|
333
347
|
# errors get reported as results even if they
|
334
348
|
# do not have futures attached.
|
335
|
-
|
336
|
-
f"Error encountered for this instruction {seq}. Proceeding forward because error is unused and unhandled. Error details:\n{error}."
|
337
|
-
)
|
349
|
+
pass
|
338
350
|
|
339
351
|
# We can safely delete the seq as tracebacks have been saved to the remote failure itself.
|
340
352
|
del self.pending_results[seq]
|
monarch/common/stream.py
CHANGED
@@ -82,6 +82,9 @@ class StreamRef(Referenceable):
|
|
82
82
|
messages.CreateStream(self, self.default),
|
83
83
|
)
|
84
84
|
|
85
|
+
def __repr__(self):
|
86
|
+
return f"<StreamRef {repr(self.name)} {self.ref}>"
|
87
|
+
|
85
88
|
def delete_ref(self, ref):
|
86
89
|
client = self.client()
|
87
90
|
if client is not None and not client._shutdown:
|