torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,160 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import importlib
|
9
|
+
import logging
|
10
|
+
import sys
|
11
|
+
import warnings
|
12
|
+
from logging import Logger
|
13
|
+
|
14
|
+
# pyre-ignore
|
15
|
+
from pickle import _getattribute, PickleError, whichmodule
|
16
|
+
from types import BuiltinFunctionType, FunctionType
|
17
|
+
from typing import (
|
18
|
+
Any,
|
19
|
+
Callable,
|
20
|
+
Dict,
|
21
|
+
NamedTuple,
|
22
|
+
Optional,
|
23
|
+
Protocol,
|
24
|
+
runtime_checkable,
|
25
|
+
)
|
26
|
+
|
27
|
+
import cloudpickle
|
28
|
+
|
29
|
+
logger: Logger = logging.getLogger(__name__)
|
30
|
+
|
31
|
+
|
32
|
+
@runtime_checkable
class ResolvableFunction(Protocol):
    """Structural type for objects that can produce the callable they stand for.

    The protocol is runtime-checkable, so ``isinstance(obj, ResolvableFunction)``
    succeeds for any object that exposes a ``resolve`` attribute.
    """

    def resolve(self) -> Callable:
        """Return the concrete callable represented by this object."""
        ...
|
35
|
+
|
36
|
+
|
37
|
+
# Alias used to annotate the argument of `resolvable_function`; currently
# unconstrained (`Any`).
ConvertsToResolvable = Any
|
38
|
+
|
39
|
+
|
40
|
+
def _string_resolver(arg: Any) -> Optional[ResolvableFunction]:
    """Interpret a dotted string as an import path.

    Returns a ``ResolvableFunctionFromPath`` wrapping ``arg`` when it is a
    ``str`` containing at least one ``"."``; returns ``None`` otherwise.
    """
    if not isinstance(arg, str) or "." not in arg:
        return None
    return ResolvableFunctionFromPath(arg)
|
43
|
+
|
44
|
+
|
45
|
+
def _torch_resolver(arg: Any) -> Optional[ResolvableFunction]:
    """Map a ``torch._ops.OpOverload`` to its ``torch.ops.*`` import path.

    Returns ``None`` for any other kind of argument. ``torch`` is imported
    lazily inside the function.
    """
    import torch

    if not isinstance(arg, torch._ops.OpOverload):
        return None
    op_path = "torch.ops." + str(arg)
    return ResolvableFunctionFromPath(op_path)
|
50
|
+
|
51
|
+
|
52
|
+
def function_to_import_path(arg: BuiltinFunctionType | FunctionType) -> Optional[str]:
|
53
|
+
# code replicated from pickler to check if we
|
54
|
+
# would successfully be able to pickle this function.
|
55
|
+
name = getattr(arg, "__qualname__", None)
|
56
|
+
if name is None:
|
57
|
+
name = arg.__name__
|
58
|
+
try:
|
59
|
+
# pyre-ignore
|
60
|
+
module_name = whichmodule(arg, name)
|
61
|
+
__import__(module_name, level=0)
|
62
|
+
module = sys.modules[module_name]
|
63
|
+
if module_name == "__main__":
|
64
|
+
return None # the workers will not have the same main
|
65
|
+
|
66
|
+
# pytest installs its own custom loaders that do not
|
67
|
+
# survive process creation
|
68
|
+
try:
|
69
|
+
if "pytest" in module.__loader__.__class__.__module__:
|
70
|
+
return None
|
71
|
+
except AttributeError:
|
72
|
+
pass
|
73
|
+
|
74
|
+
# pyre-ignore
|
75
|
+
obj2, parent = _getattribute(module, name)
|
76
|
+
# support annotations that cover up the global impl
|
77
|
+
if obj2 is arg or getattr(obj2, "_remote_impl", None) is arg:
|
78
|
+
return f"{module_name}.{name}"
|
79
|
+
except (PickleError, ImportError, KeyError, AttributeError):
|
80
|
+
pass
|
81
|
+
return None
|
82
|
+
|
83
|
+
|
84
|
+
def _function_resolver(arg: Any):
    """Resolve plain Python or builtin functions through their import path.

    Returns a ``ResolvableFunctionFromPath`` when ``function_to_import_path``
    yields a non-empty path for ``arg``; returns ``None`` otherwise.
    """
    if not isinstance(arg, (FunctionType, BuiltinFunctionType)):
        return None
    path = function_to_import_path(arg)
    if not path:
        return None
    return ResolvableFunctionFromPath(path)
|
88
|
+
|
89
|
+
|
90
|
+
def _cloudpickle_resolver(arg: Any):
    """Serialize ``arg`` with cloudpickle and wrap the resulting bytes."""
    # @lint-ignore PYTHONPICKLEISBAD
    payload = cloudpickle.dumps(arg)
    return ResolvableFromCloudpickle(payload)
|
93
|
+
|
94
|
+
|
95
|
+
# Resolver functions tried in list order by `maybe_resolvable_function`; the
# first one that returns a non-None value wins. `_cloudpickle_resolver` comes
# last and always returns a value, so it acts as the fallback.
resolvers = [
    _torch_resolver,
    _string_resolver,
    _function_resolver,
    _cloudpickle_resolver,
]


# Memo of previously resolved arguments, keyed by the argument itself.
_cached_resolvers = {}
|
104
|
+
|
105
|
+
|
106
|
+
def maybe_resolvable_function(arg: Any) -> Optional[ResolvableFunction]:
    """Convert ``arg`` to a ``ResolvableFunction`` using the registered resolvers.

    The literal string ``"__test_panic"`` is special-cased and wrapped directly.
    Otherwise the resolvers in ``resolvers`` are tried in order and the first
    non-``None`` result is returned. Results are memoized in
    ``_cached_resolvers`` when ``arg`` is hashable.

    Returns ``None`` if no resolver accepts ``arg``.
    """
    # Only compare strings: `==` on arbitrary objects may be overloaded.
    if isinstance(arg, str) and arg == "__test_panic":
        return ResolvableFunctionFromPath("__test_panic")

    try:
        cached = _cached_resolvers.get(arg)
        cacheable = True
    except TypeError:
        # Unhashable arguments cannot be used as a dict key; resolve them
        # every time instead of failing before the resolvers get a chance.
        cached = None
        cacheable = False
    if cached is not None:
        return cached

    for resolver in resolvers:
        resolved = resolver(arg)
        if resolved is not None:
            if cacheable:
                _cached_resolvers[arg] = resolved
            return resolved
    return None
|
118
|
+
|
119
|
+
|
120
|
+
def resolvable_function(arg: ConvertsToResolvable) -> ResolvableFunction:
    """Return ``arg`` as a ``ResolvableFunction``.

    Objects that already satisfy the protocol are returned unchanged; anything
    else is converted with ``maybe_resolvable_function``.

    Raises:
        ValueError: if ``arg`` cannot be converted.
    """
    if isinstance(arg, ResolvableFunction):
        return arg
    resolved = maybe_resolvable_function(arg)
    if resolved is not None:
        return resolved
    raise ValueError(f"Unsupported target for a remote call: {arg!r}")
|
127
|
+
|
128
|
+
|
129
|
+
class ResolvableFunctionFromPath(NamedTuple):
    """A callable identified by a dotted import path."""

    # "<module>.<qualified name>", or "torch.<attr>[.<attr>...]" for torch.
    path: str

    def resolve(self):
        """Import and return the callable named by ``path``.

        Paths starting with ``torch`` are resolved by attribute access on the
        ``torch`` module. Any other path is split into the longest importable
        module prefix and a remaining attribute chain, so nested qualified
        names such as ``"pkg.mod.Class.func"`` resolve as well. If the
        resolved object has a non-``None`` ``_remote_impl`` attribute, that
        attribute is returned instead.

        Raises:
            ValueError: if a non-torch path contains no ``"."``.
            ModuleNotFoundError: if no prefix of the path is importable.
            AttributeError: if the attribute chain does not exist.
        """
        first, *parts = self.path.split(".")
        if first == "torch":
            function = importlib.import_module("torch")
            for p in parts:
                function = getattr(function, p)
            assert callable(function)
            return function

        if not parts:
            raise ValueError(f"expected a dotted import path, got {self.path!r}")

        segments = [first, *parts]
        first_import_error = None
        # Try the longest module prefix first so "a.b.func" still imports
        # "a.b"; only shorten the prefix when that exact module is missing.
        for split in range(len(segments) - 1, 0, -1):
            modulename = ".".join(segments[:split])
            try:
                module = importlib.import_module(modulename)
            except ModuleNotFoundError as e:
                missing = e.name
                probing_failed = missing is not None and (
                    modulename == missing or modulename.startswith(missing + ".")
                )
                if not probing_failed:
                    # a dependency of an existing module is missing: surface it
                    raise
                if first_import_error is None:
                    first_import_error = e
                continue
            attrs = segments[split:]
            break
        else:
            raise first_import_error

        function = module
        try:
            for attr in attrs:
                function = getattr(function, attr)
        except AttributeError as attr_error:
            if first_import_error is not None:
                # keep reporting the original import failure when the
                # shortened prefix did not help either
                raise first_import_error from attr_error
            raise

        # support annotations that cover up the global impl
        actual = getattr(function, "_remote_impl", None)
        return function if actual is None else actual

    def __str__(self):
        return self.path
|
150
|
+
|
151
|
+
|
152
|
+
class ResolvableFromCloudpickle(NamedTuple):
    """A callable carried as cloudpickle-serialized bytes."""

    # serialized payload handed to `cloudpickle.loads` by `resolve`
    data: bytes

    def resolve(self):
        """Deserialize ``data`` and return the resulting object."""
        payload = self.data
        # @lint-ignore PYTHONPICKLEISBAD
        return cloudpickle.loads(payload)
|
158
|
+
|
159
|
+
|
160
|
+
# NOTE(review): untyped alias; its users are not visible in this file —
# confirm the intended type with the callers.
Propagator = Any
|
@@ -0,0 +1,164 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import itertools
|
8
|
+
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Sequence, Tuple
|
9
|
+
|
10
|
+
import torch
|
11
|
+
from torch import autograd
|
12
|
+
from torch.utils._pytree import tree_flatten, TreeSpec
|
13
|
+
|
14
|
+
|
15
|
+
class AliasOf(NamedTuple):
    """Storage descriptor for a tensor that reuses another tensor's storage.

    `TensorGroupPattern.empty` looks the base tensor up as `tensors[offset]`
    when `group` is 0 and as `parents[group][offset]` otherwise.
    """

    group: int  # 0 -this group, -1 - the parent, -2 - parent's parent, etc.
    offset: int  # index of the aliased tensor within the selected group
|
18
|
+
|
19
|
+
|
20
|
+
class Storage(NamedTuple):
    """Storage descriptor for a tensor that gets a freshly allocated buffer."""

    numel: int  # element count passed to `torch.empty` when recreating the buffer
|
22
|
+
|
23
|
+
|
24
|
+
# Hashable pattern for recreating tensors
|
25
|
+
# Each tensor either creates its own Storage
|
26
|
+
# or is an AliasOf another tensor either earlier in this list,
|
27
|
+
# or in one of the parent lists.
|
28
|
+
# parent lists are used to represent other collections of tensors
|
29
|
+
# for instance if this pattern is for outputs of a function
|
30
|
+
# parents might contains lists of inputs to the function and captured
|
31
|
+
# globals as two separate lists.
|
32
|
+
class TensorGroupPattern(NamedTuple):
    """Hashable recipe for recreating a list of tensors and their aliasing."""

    entries: Tuple["PatternEntry", ...]

    def empty(self, parents: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Materialize uninitialized tensors laid out as this pattern describes.

        Entries are processed in order. An ``AliasOf`` entry takes its base
        from an earlier tensor of this call (group 0) or from
        ``parents[group]``; a ``Storage`` entry allocates a new flat buffer.
        Every result is an ``as_strided`` view of its base.

        Raises:
            ValueError: if an entry's storage is neither ``AliasOf`` nor ``Storage``.
        """
        rebuilt = []
        for spec in self.entries:
            source = spec.storage
            if isinstance(source, AliasOf):
                owner = rebuilt if source.group == 0 else parents[source.group]
                base = owner[source.offset]
            elif isinstance(source, Storage):
                base = torch.empty(
                    (source.numel,),
                    dtype=spec.dtype,
                    layout=spec.layout,
                    device=spec.device,
                )
            else:
                raise ValueError("unexpected storage")
            view = torch.as_strided(base, spec.size, spec.stride, spec.storage_offset)
            rebuilt.append(view)
        return rebuilt
|
53
|
+
|
54
|
+
|
55
|
+
class PatternEntry(NamedTuple):
    """Layout and storage description of one tensor inside a TensorGroupPattern.

    Populated by `TensorGroup` from a live tensor and consumed by
    `TensorGroupPattern.empty` to rebuild an equivalent tensor.
    """

    size: Tuple[int, ...]  # tuple(t.size())
    stride: Tuple[int, ...]  # tuple(t.stride())
    storage_offset: int  # int(t.storage_offset())
    dtype: torch.dtype
    layout: torch.layout
    device: torch.device
    storage: AliasOf | Storage  # alias of an earlier tensor, or a fresh buffer
|
63
|
+
|
64
|
+
|
65
|
+
# Takes a list of tensors and computes the pattern of aliasing that
|
66
|
+
# would reconstruct the group. If `parent` is specified aliases
|
67
|
+
# are also computed with respect to that group and its parents.
|
68
|
+
# new storage is only specified is a tensor's storage was not
|
69
|
+
# seen in any parent or previously in a group.
|
70
|
+
class TensorGroup:
    """Computes the aliasing pattern of a sequence of tensors.

    Each tensor is recorded either as owning new storage or as an alias of a
    tensor seen earlier in this group or in one of the ancestor groups reached
    through ``parent``. The result is stored in ``self.pattern``.
    """

    def __init__(
        self,
        tensors: Sequence[torch.Tensor],
        parent: Optional["TensorGroup"] = None,
    ):
        self.parent = parent
        self.tensors = tensors
        # maps a storage to the index of the first tensor in this group using it
        self.storage_dict: Dict[torch.UntypedStorage, int] = {}

        entries = []
        for index, tensor in enumerate(tensors):
            storage = tensor.untyped_storage()
            numel = storage.size() // tensor.element_size()
            origin = self._find_alias(storage)
            if origin is None:
                # first sighting of this storage anywhere in the chain
                self.storage_dict[storage] = index
                origin = Storage(numel)
            entries.append(
                PatternEntry(
                    size=tuple(tensor.size()),
                    stride=tuple(tensor.stride()),
                    storage_offset=int(tensor.storage_offset()),
                    dtype=tensor.dtype,
                    layout=tensor.layout,
                    device=tensor.device,
                    storage=origin,
                )
            )

        self.pattern = TensorGroupPattern(tuple(entries))

    def _find_alias(self, storage: torch.UntypedStorage) -> Optional[AliasOf]:
        """Search this group, then its ancestors, for ``storage``.

        Returns an ``AliasOf`` whose group is the negated number of parent
        hops taken, or ``None`` if no group in the chain knows the storage.
        """
        group, depth = self, 0
        while storage not in group.storage_dict:
            if group.parent is None:
                return None
            group, depth = group.parent, depth + 1
        return AliasOf(-depth, group.storage_dict[storage])
|
110
|
+
|
111
|
+
|
112
|
+
class TensorPlaceholder:
    """Marker type that stands in for a tensor inside a flattened key."""


# singleton to represent where tensors go in a pytree
tensor_placeholder = TensorPlaceholder()
|
118
|
+
|
119
|
+
|
120
|
+
def _to_placeholder(x):
|
121
|
+
if isinstance(x, torch.Tensor):
|
122
|
+
return tensor_placeholder
|
123
|
+
return x
|
124
|
+
|
125
|
+
|
126
|
+
def _remove_ctx(x):
|
127
|
+
if isinstance(x, autograd.function.FunctionCtx):
|
128
|
+
return None
|
129
|
+
return x
|
130
|
+
|
131
|
+
|
132
|
+
# customizable set of filters to handle data types that appear
|
133
|
+
# in functions that one wants to support in cached functions
|
134
|
+
# `_filter_key` applies these in list order, feeding each result to the next.
key_filters = [_to_placeholder, _remove_ctx]
|
135
|
+
|
136
|
+
|
137
|
+
def _filter_key(v: Any):
    """Run ``v`` through every function in ``key_filters`` and return the result."""
    filtered = v
    for apply_filter in key_filters:
        filtered = apply_filter(filtered)
    return filtered
|
141
|
+
|
142
|
+
|
143
|
+
class HashableTreeSpec(NamedTuple):
    """Tuple-based copy of a pytree ``TreeSpec``."""

    type: Any
    context: Any
    children_specs: Tuple["HashableTreeSpec", ...]

    @staticmethod
    def from_treespec(t: "TreeSpec"):
        """Recursively convert ``t`` and its children.

        A ``list`` context is converted to a ``tuple``; any other context
        value is kept as is.
        """
        context = t.context
        if isinstance(context, list):
            context = tuple(context)
        children = tuple(map(HashableTreeSpec.from_treespec, t.children_specs))
        return HashableTreeSpec(t.type, context, children)
|
155
|
+
|
156
|
+
|
157
|
+
def hashable_tensor_flatten(args, kwargs) -> Tuple[List[torch.Tensor], Hashable]:
    """Flatten ``(args, kwargs)`` into its tensors plus a key for the rest.

    Returns:
        A pair ``(tensors, key)``. ``tensors`` lists the tensor leaves in
        flattening order. ``key`` combines every leaf after ``_filter_key``
        with the tree structure converted by ``HashableTreeSpec``.
    """
    leaves, treespec = tree_flatten((args, kwargs))
    tensors = list(filter(lambda leaf: isinstance(leaf, torch.Tensor), leaves))
    filtered_leaves = tuple(map(_filter_key, leaves))
    key: Hashable = (filtered_leaves, HashableTreeSpec.from_treespec(treespec))
    return tensors, key
|
monarch/common/future.py
ADDED
@@ -0,0 +1,168 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import logging
|
9
|
+
import math
|
10
|
+
import os
|
11
|
+
import subprocess
|
12
|
+
from typing import (
|
13
|
+
Any,
|
14
|
+
Callable,
|
15
|
+
cast,
|
16
|
+
Generic,
|
17
|
+
Optional,
|
18
|
+
Sequence,
|
19
|
+
TYPE_CHECKING,
|
20
|
+
TypeVar,
|
21
|
+
)
|
22
|
+
|
23
|
+
from monarch_supervisor import TTL
|
24
|
+
|
25
|
+
if TYPE_CHECKING:
|
26
|
+
from monarch.common.client import Client
|
27
|
+
|
28
|
+
from .invocation import RemoteException
|
29
|
+
|
30
|
+
logger = logging.getLogger(__name__)
|
31
|
+
|
32
|
+
# Value of CONTROLLER_PYSPY_REPORT_INTERVAL parsed as a float, or None when
# the environment variable is not set.
PYSPY_REPORT_INTERVAL: Optional[float] = (
    float(os.environ["CONTROLLER_PYSPY_REPORT_INTERVAL"])
    if "CONTROLLER_PYSPY_REPORT_INTERVAL" in os.environ
    else None
)
|
38
|
+
|
39
|
+
|
40
|
+
def _split(elems, cond):
|
41
|
+
trues = []
|
42
|
+
falses = []
|
43
|
+
for elem in elems:
|
44
|
+
if cond(elem):
|
45
|
+
trues.append(elem)
|
46
|
+
else:
|
47
|
+
falses.append(elem)
|
48
|
+
return trues, falses
|
49
|
+
|
50
|
+
|
51
|
+
def _periodic_TTL(interval: Optional[float]) -> Callable[[], float]:
|
52
|
+
if interval is None:
|
53
|
+
return lambda: math.inf
|
54
|
+
|
55
|
+
ttl = TTL(interval)
|
56
|
+
|
57
|
+
def _remaining():
|
58
|
+
nonlocal ttl
|
59
|
+
rem = ttl()
|
60
|
+
if rem == 0:
|
61
|
+
ttl = TTL(interval)
|
62
|
+
return rem
|
63
|
+
|
64
|
+
return _remaining
|
65
|
+
|
66
|
+
|
67
|
+
# Type variable parameterizing `Future`; `Future.result` is annotated to return it.
T = TypeVar("T")
|
68
|
+
|
69
|
+
|
70
|
+
class Future(Generic[T]):
    """Handle for a value that the client delivers later.

    The status starts as ``"incomplete"`` and becomes ``"complete"`` or
    ``"exception"`` once ``_set_result`` is called.
    """

    def __init__(self, client: "Client"):
        self._client = client
        self._status = "incomplete"
        self._callbacks = None
        self._result: T | Exception | None = None

    def _set_result(self, r):
        """Store ``r``, update the status and run the registered callbacks.

        A ``RemoteException`` result marks the future as failed. Exceptions
        raised by callbacks are logged and otherwise ignored. The client
        reference is dropped afterwards.
        """
        assert self._status == "incomplete"
        self._result = r
        if isinstance(r, RemoteException):
            self._status = "exception"
        else:
            self._status = "complete"
        for cb in self._callbacks or ():
            try:
                cb(self)
            except Exception:
                logger.exception("exception in controller's Future callback")
        self._callbacks = None
        self._client = None

    def _wait(self, timeout: Optional[float]):
        """Process client messages until done or ``timeout`` expires.

        Returns ``True`` if the future is no longer incomplete.
        """
        if self._status != "incomplete":
            return True

        client = self._client
        assert client is not None

        # see if the future is done already
        # and we just haven't processed the messages
        while client.handle_next_message(0):
            if self._status != "incomplete":
                return True

        ttl = TTL(timeout)
        ttl_pyspy = _periodic_TTL(PYSPY_REPORT_INTERVAL)
        while self._status == "incomplete":
            if not _wait(client, ttl, ttl_pyspy):
                break

        return self._status != "incomplete"

    def result(self, timeout: Optional[float] = None) -> T:
        """Return the value, raising the stored exception on failure.

        Raises:
            TimeoutError: if the future is still incomplete after ``timeout``.
        """
        if not self._wait(timeout):
            raise TimeoutError()
        if self._status == "exception":
            raise cast(Exception, self._result)
        return cast(T, self._result)

    def done(self) -> bool:
        """Report completion after a wait with a timeout of 0."""
        return self._wait(0)

    def exception(self, timeout: Optional[float] = None):
        """Return the stored exception, or ``None`` if the future succeeded.

        Raises:
            TimeoutError: if the future is still incomplete after ``timeout``.
        """
        if not self._wait(timeout):
            raise TimeoutError()
        if self._status == "exception":
            return self._result
        return None

    def add_callback(self, callback):
        """Register ``callback`` to be called with this future by ``_set_result``.

        NOTE(review): a callback added after ``_set_result`` has run is stored
        but nothing in this class invokes it — confirm this is intended.
        """
        if self._callbacks:
            self._callbacks.append(callback)
        else:
            self._callbacks = [callback]
|
129
|
+
|
130
|
+
|
131
|
+
def _wait(client: "Client", ttl: Callable[[], float], ttl_pyspy: Callable[[], float]):
|
132
|
+
remaining = ttl()
|
133
|
+
pyspy_remaining = ttl_pyspy()
|
134
|
+
if pyspy_remaining == 0:
|
135
|
+
try:
|
136
|
+
logging.warning(
|
137
|
+
f"future has not finished in {PYSPY_REPORT_INTERVAL} seconds (remaining time to live is {remaining}), py-spying process to debug."
|
138
|
+
)
|
139
|
+
subprocess.run(["py-spy", "dump", "-s", "-p", str(os.getpid())])
|
140
|
+
except FileNotFoundError:
|
141
|
+
logging.warning("py-spy is not installed.")
|
142
|
+
timeout = min(remaining, pyspy_remaining)
|
143
|
+
client.handle_next_message(timeout=None if timeout == math.inf else timeout)
|
144
|
+
return remaining > 0
|
145
|
+
|
146
|
+
|
147
|
+
def stream(futures: Sequence[Future], timeout: Optional[float] = None):
    """Yield the given futures as they complete.

    ``timeout``, when given, bounds the completion of the whole set rather
    than of each individual future. The generator stops once every future
    has been yielded or the time budget is used up.
    """
    assert len(futures) > 0

    ttl = TTL(timeout)
    pyspy_ttl = _periodic_TTL(PYSPY_REPORT_INTERVAL)

    clients = {f._client for f in futures if f._client is not None}
    assert len(clients) <= 1, "all futures must be from the same controller"

    pending = futures
    while True:
        finished, pending = _split(pending, lambda fut: fut._status != "incomplete")
        for fut in finished:
            yield fut

        if not pending:
            break
        # any pending future's client will do: they all share one controller
        if not _wait(pending[0]._client, ttl, pyspy_ttl):
            break
|
@@ -0,0 +1,125 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import traceback
|
9
|
+
from typing import Any, List, Optional, Tuple
|
10
|
+
|
11
|
+
from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
|
12
|
+
ActorId,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
# Integer alias used to annotate sequence numbers (`Invocation.seq`,
# `RemoteException.seq`).
Seq = int
|
17
|
+
|
18
|
+
|
19
|
+
class DeviceException(Exception):
    """
    Non-deterministic failure in the underlying worker, controller or its infrastructure.
    For example, a worker may enter a crash loop, or its GPU may be lost
    """

    def __init__(
        self,
        exception: Exception,
        frames: List[traceback.FrameSummary],
        source_actor_id: "ActorId",
        message: str,
    ):
        """Record the underlying failure.

        Args:
            exception: the original exception.
            frames: traceback frames shown as the worker-side traceback.
            source_actor_id: actor the failure is attributed to.
            message: headline printed before the traceback.
        """
        self.exception = exception
        self.frames = frames
        self.source_actor_id = source_actor_id
        self.message = message

    def __str__(self):
        """Return the message followed by the worker traceback and exception.

        Never raises: if formatting fails, the message is still returned
        together with the reason the details could not be rendered.
        """
        try:
            exe = str(self.exception)
            worker_tb = "".join(traceback.format_list(self.frames))
            return (
                f"{self.message}\n"
                f"Traceback of the failure on worker (most recent call last):\n{worker_tb}{type(self.exception).__name__}: {exe}"
            )
        except Exception as e:
            # Keep the headline and say why the details are missing, instead
            # of printing to stdout and returning an uninformative string.
            return f"{self.message}\n<failed to format failure details: {e!r}>"
|
48
|
+
|
49
|
+
|
50
|
+
class RemoteException(Exception):
    """
    Deterministic problem with the user's code.
    For example, an OOM resulting in trying to allocate too much GPU memory, or violating
    some invariant enforced by the various APIs.
    """

    def __init__(
        self,
        seq: "Seq",
        exception: Exception,
        controller_frame_index: Optional[int],
        controller_frames: Optional[List[traceback.FrameSummary]],
        worker_frames: List[traceback.FrameSummary],
        source_actor_id: "ActorId",
        message="A remote function has failed asynchronously.",
    ):
        """Record the failure of a remote function.

        Args:
            seq: sequence number associated with the failure.
            exception: the original exception.
            controller_frame_index: stored as given; not used by this class.
            controller_frames: controller-side frames, or ``None`` when the
                failure is not tied to a specific invocation.
            worker_frames: worker-side frames of the failure.
            source_actor_id: actor the failure is attributed to.
            message: headline printed before the tracebacks.
        """
        self.exception = exception
        self.worker_frames = worker_frames
        self.message = message
        self.seq = seq
        self.controller_frame_index = controller_frame_index
        self.source_actor_id = source_actor_id
        self.controller_frames = controller_frames

    def __str__(self):
        """Return the message followed by controller and worker tracebacks.

        Never raises: if formatting fails, the message is still returned
        together with the reason the details could not be rendered.
        """
        try:
            exe = str(self.exception)
            worker_tb = "".join(traceback.format_list(self.worker_frames))
            controller_tb = (
                "".join(traceback.format_list(self.controller_frames))
                if self.controller_frames is not None
                else " <not related to a specific invocation>\n"
            )
            return (
                f"{self.message}\n"
                f"Traceback of where the remote function was issued on controller (most recent call last):\n{controller_tb}"
                f"Traceback of where the remote function failed on worker (most recent call last):\n{worker_tb}{type(self.exception).__name__}: {exe}"
            )
        except Exception as e:
            # Keep the headline and say why the details are missing, instead
            # of printing to stdout and returning an uninformative string.
            return f"{self.message}\n<failed to format failure details: {e!r}>"
|
92
|
+
|
93
|
+
|
94
|
+
class Invocation:
    """Bookkeeping record for one invocation identified by ``seq``.

    Tracks the invocations registered through ``add_user``, the failure
    recorded by ``fail`` and the value held in ``fut_value``.
    """

    def __init__(self, seq: Seq):
        self.seq = seq
        # set to None by complete(); add_user() stops recording users then
        self.users: Optional[set["Invocation"]] = set()
        self.failure: Optional[RemoteException] = None
        self.fut_value: Any = None

    def __repr__(self):
        return f"<Invocation {self.seq}>"

    def fail(self, remote_exception: RemoteException):
        """Record ``remote_exception`` unless a failure with a seq that is not larger is already stored.

        Returns ``True`` if the stored failure was replaced, ``False`` otherwise.
        """
        current = self.failure
        if current is not None and not (current.seq > remote_exception.seq):
            return False
        self.failure = remote_exception
        return True

    def add_user(self, r: "Invocation"):
        """Register ``r`` as a user and hand it the current failure, if any."""
        users = self.users
        if users is not None:
            users.add(r)
        failure = self.failure
        if failure is not None:
            r.fail(failure)

    def complete(self) -> Tuple[Any, Optional[RemoteException]]:
        """
        Complete the current invocation.
        Return the result and exception tuple.
        """
        # after completion we no longer need to inform users of failures
        # since they will just immediately get the value during add_user
        self.users = None

        failure = self.failure
        value = self.fut_value if failure is None else None
        return (value, failure)
|