torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/future.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import asyncio
+from functools import partial
+from typing import Generator, Generic, Optional, TypeVar
+
+R = TypeVar("R")
+
+
+def _incomplete(impl, self):
+    try:
+        return self._set_result(impl())
+    except Exception as e:
+        self._set_exception(e)
+        raise
+
+
+async def _aincomplete(impl, self):
+    try:
+        return self._set_result(await impl())
+    except Exception as e:
+        self._set_exception(e)
+        raise
+
+
+# TODO: consolidate with monarch.common.future
+class ActorFuture(Generic[R]):
+    def __init__(self, impl, blocking_impl=None):
+        if blocking_impl is None:
+            blocking_impl = partial(asyncio.run, impl())
+        self._get = partial(_incomplete, blocking_impl)
+        self._aget = partial(_aincomplete, impl)
+
+    def get(self, timeout: Optional[float] = None) -> R:
+        if timeout is not None:
+            return asyncio.run(asyncio.wait_for(self._aget(self), timeout))
+        return self._get(self)
+
+    def __await__(self) -> Generator[R, None, R]:
+        return self._aget(self).__await__()
+
+    def _set_result(self, result):
+        def f(self):
+            return result
+
+        async def af(self):
+            return result
+
+        self._get, self._aget = f, af
+        return result
+
+    def _set_exception(self, e):
+        def f(self):
+            raise e
+
+        async def af(self):
+            raise e
+
+        self._get, self._aget = f, af
+
+    # compatibility with old tensor engine Future objects
+    # hopefully we do not need done(), add_callback because
+    # they are harder to implement right.
+    def result(self, timeout: Optional[float] = None) -> R:
+        return self.get(timeout)
+
+    def exception(self, timeout: Optional[float] = None):
+        try:
+            self.get(timeout)
+            return None
+        except Exception as e:
+            return e
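As an aside on how this API composes: ActorFuture exposes both a blocking get()/result() path and an awaitable path, and whichever completes first caches its outcome for the other. A minimal usage sketch, not part of the diff; the async compute() callable is a made-up stand-in for the impl argument the runtime normally supplies:

    import asyncio

    from monarch.future import ActorFuture

    async def compute():
        # Hypothetical impl; the runtime normally passes its own coroutine factory.
        await asyncio.sleep(0.01)
        return 42

    fut = ActorFuture(compute)
    print(fut.get())      # blocking path: runs the coroutine via asyncio.run and caches 42

    async def consume(f):
        return await f    # awaitable path: returns the cached result without re-running impl

    print(asyncio.run(consume(fut)))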
monarch/gradient/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from ._gradient_generator import GradientGenerator
+
+__all__ = ["GradientGenerator"]
monarch/gradient/_gradient_generator.pyi
ADDED
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import Any, Optional
+
+import torch
+
+class GradientGenerator:
+    def __init__(
+        self,
+        roots_list: Any,
+        with_respect_to: Any,
+        grad_roots: Any,
+        context_restorer: Any,
+    ): ...
+    # pyre-ignore[11]: Annotation `torch.Tensor` is not defined as a type.
+    def __next__(self) -> Optional[torch.Tensor]: ...
+    def __iter__(self) -> "GradientGenerator": ...
monarch/gradient/_gradient_generator.so
Binary file
monarch/gradient_generator.py
ADDED
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import math
+from contextlib import nullcontext
+from functools import partial
+from types import CellType, FunctionType
+from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.autograd.graph
+
+from monarch.common import device_mesh, stream
+from monarch.common.tensor import Tensor
+from monarch.common.tree import flatten
+from monarch.gradient import GradientGenerator as _GradientGenerator
+from torch._C._autograd import _get_sequence_nr  # @manual
+from torch.autograd.graph import get_gradient_edge, GradientEdge
+
+TensorOrEdge = Union[torch.Tensor, GradientEdge]
+
+
+class Context(NamedTuple):
+    device_mesh: "Optional[device_mesh.DeviceMesh]"
+    stream: "stream.Stream"
+
+    def enable(self):
+        if device_mesh is None:
+            activate_mesh = device_mesh.no_mesh.activate()
+        elif self.device_mesh is not device_mesh._active:
+            # XXX: something about activating device meshes from this object
+            # doesn't work correctly and somehow inactivates the device mesh
+            # if it is already enabled. This is a temporary workaround for
+            # the demo.
+            activate_mesh = self.device_mesh.activate()
+        else:
+            activate_mesh = nullcontext()
+        with activate_mesh, self.stream.activate(), torch.no_grad():
+            yield
+
+
+_sequence_nr_to_context: Dict[int, Context] = {}
+_sequence_nr_end = 0
+
+
+def restore_context(t: Optional[Tensor], sn: Optional[int], last: bool):
+    if sn is not None:
+        _update_context_map(Context(device_mesh._active, stream._active))
+        ctx = _sequence_nr_to_context.pop(sn) if last else _sequence_nr_to_context[sn]
+        return ctx.enable()
+    if t is not None:
+        return Context(t.mesh, t.stream).enable()
+    return Context(device_mesh._active, stream._active).enable()
+
+
+def _update_context_map(ctx: Context):
+    global _sequence_nr_end
+    next_sequence_nr = _get_sequence_nr()
+    for i in range(_sequence_nr_end, next_sequence_nr):
+        _sequence_nr_to_context[i] = ctx
+    _sequence_nr_end = _get_sequence_nr()
+
+
+device_mesh._on_change.append(
+    lambda old, mesh: _update_context_map(Context(old, stream._active))
+)
+stream._on_change.append(
+    lambda old, stream: _update_context_map(Context(device_mesh._active, old))
+)
+
+
+def grad_generator(
+    roots: Union[torch.Tensor, Sequence[TensorOrEdge]] = (),
+    with_respect_to: Sequence[TensorOrEdge] = (),
+    grad_roots: Sequence[Optional[torch.Tensor]] = (),
+):
+    if isinstance(roots, torch.Tensor):
+        roots = [roots]
+    return _GradientGenerator(
+        list(roots), list(with_respect_to), list(grad_roots), restore_context
+    )
+
+
+def _gradient_edge(a: TensorOrEdge) -> GradientEdge:
+    if isinstance(a, GradientEdge):
+        return a
+    return get_gradient_edge(a)
+
+
+class GradGenerator:
+    def __init__(self):
+        self.roots: List[torch.autograd.graph.GradientEdge] = []
+        self.with_respect_to: List[torch.autograd.graph.GradientEdge] = []
+        self.grad_roots: List[Optional[torch.Tensor]] = []
+        self.unflattens: List[Tuple[int, Any]] = []
+
+    def grad(self, tree: Any):
+        tensors, unflatten = flatten(tree, lambda x: isinstance(x, torch.Tensor))
+        self.unflattens.append((len(tensors), unflatten))
+        self.with_respect_to.extend(_gradient_edge(t) for t in tensors)
+
+    def root(self, r: TensorOrEdge, grad: Optional[torch.Tensor] = None):
+        self.roots.append(_gradient_edge(r))
+        self.grad_roots.append(grad)
+
+    def __iter__(self):
+        gi = _GradientGenerator(
+            self.roots,
+            list(reversed(self.with_respect_to)),
+            self.grad_roots,
+            restore_context,
+        )
+        for n, unflatten in reversed(self.unflattens):
+            yield unflatten(reversed([next(gi) for _ in range(n)]))
+
+
+class GradFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, fn, *args, **kwargs):
+        result, backward_continuation = fn(*args, **kwargs)
+        ctx.backward_continuation = backward_continuation
+        values = []
+        if backward_continuation.__closure__ is not None:
+            for cell in backward_continuation.__closure__:
+                values.append(cell.cell_contents)
+                cell.cell_contents = None
+        tensors, ctx.unflatten = flatten(values, lambda x: isinstance(x, torch.Tensor))
+        ctx.save_for_backward(*tensors)
+        return result
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        closure = tuple(CellType(v) for v in ctx.unflatten(ctx.saved_tensors))
+        orig = ctx.backward_continuation
+        fn = FunctionType(
+            orig.__code__, orig.__globals__, orig.__name__, orig.__defaults__, closure
+        )
+        output = fn(*args, **kwargs)
+        if isinstance(output, tuple):
+            return None, *output
+        else:
+            return None, output
+
+
+def grad_function(fn):
+    return partial(GradFunction.apply, fn)
+
+
+def gradient_execution_order(
+    roots: Sequence[TensorOrEdge], with_respect_to: Sequence[TensorOrEdge]
+) -> List[int]:
+    """
+    Returns the order in which the gradients for `with_respect_to` would become available
+    if autograd were run on `roots`. This is the reverse order of each tensors
+    first use in the gradient computation.
+    """
+    with_respect_to = [_gradient_edge(g) for g in with_respect_to]
+    min_sequence_nr: Dict[Any, float] = {e: math.inf for e in with_respect_to}
+
+    to_scan = [_gradient_edge(r).node for r in roots]
+    scanned = set()
+    while to_scan:
+        node = to_scan.pop()
+        if node in scanned:
+            continue
+        scanned.add(node)
+        for key in node.next_functions:
+            nnode = key[0]
+            if nnode is None:
+                continue
+            to_scan.append(nnode)
+            value = min_sequence_nr.get(key)
+            if value is not None:
+                # pyre-ignore
+                min_sequence_nr[key] = min(node._sequence_nr(), value)
+
+    return sorted(
+        range(len(with_respect_to)),
+        key=lambda i: min_sequence_nr[with_respect_to[i]],
+        reverse=True,
+    )
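For orientation, gradient_execution_order only inspects the local autograd graph, so it can be exercised without a device mesh. A small sketch, assuming this wheel and a compatible torch build are installed; the tensors are illustrative:

    import torch

    from monarch.gradient_generator import gradient_execution_order

    w1 = torch.randn(4, requires_grad=True)
    w2 = torch.randn(4, requires_grad=True)
    # w1 is used earlier in the forward than w2, so its gradient finishes last
    # and should sort after w2 in the returned order.
    loss = (w1 * 2).sum() + (w1 * w2).sum()
    print(gradient_execution_order([loss], [w1, w2]))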
monarch/memory.py
ADDED
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import itertools
+import os
+from pathlib import Path
+
+import torch
+from monarch.common.remote import remote
+
+
+PATH_KEY = "dir_snapshots"
+_counter = itertools.count()
+
+
+@remote(propagate="inspect")
+def record_memory_history() -> None:
+    torch.cuda.memory._record_memory_history()
+
+
+def dump_memory_snapshot(*args, **kwargs) -> None:
+    """
+    This function wraps torch.cuda.memory._dump_snapshot() to dump memory snapshot remotely.
+    """
+    assert isinstance(
+        kwargs.get(PATH_KEY, None), str
+    ), f"{PATH_KEY} must be passed and must be a string to represent the path to save the memory snapshots."
+    id = next(_counter)
+    _memory_controller_dump(id, *args, **kwargs)
+
+
+@remote(propagate="inspect")
+def _memory_controller_dump(ident, *args, **kwargs) -> None:
+    dir_path = Path(kwargs[PATH_KEY]).absolute()
+    os.makedirs(dir_path, exist_ok=True)
+    # This is not a synchronized call, so it is okay to call without device mesh.
+    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+    snapshot_path = f"{dir_path}/snapshot_{rank}.pickle"
+    torch.cuda.memory._dump_snapshot(filename=snapshot_path)
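Both helpers are @remote endpoints, so they are intended to be called from a driver while a device mesh is active; the calls propagate to every worker. A hedged usage sketch (the mesh setup is omitted and the snapshot directory is illustrative):

    from monarch.memory import dump_memory_snapshot, record_memory_history

    # Assumes a monarch device mesh is already activated on the driver side.
    record_memory_history()                                    # start CUDA allocator history on each worker
    # ... run the workload whose memory behavior you want to inspect ...
    dump_memory_snapshot(dir_snapshots="/tmp/monarch_mem")     # each worker writes snapshot_<rank>.pickle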
monarch/mesh_controller.py
ADDED
@@ -0,0 +1,271 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import atexit
+import logging
+import os
+import time
+import traceback
+from collections import deque
+from logging import Logger
+from typing import List, NamedTuple, Optional, TYPE_CHECKING, Union
+
+import torch.utils._python_dispatch
+
+from monarch import NDSlice
+from monarch._rust_bindings.monarch_extension import client, debugger
+from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+    WorldState,
+)
+from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller
+from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+    ActorId,
+)
+
+if TYPE_CHECKING:
+    from monarch._rust_bindings.monarch_hyperactor.proc_mesh import (
+        ProcMesh as HyProcMesh,
+    )
+    from monarch.proc_mesh import ProcMesh
+
+from monarch._rust_bindings.monarch_hyperactor.shape import Point
+
+from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+from monarch.common.client import Client
+from monarch.common.controller_api import LogMessage, MessageResult
+from monarch.common.device_mesh import DeviceMesh, no_mesh
+from monarch.common.invocation import DeviceException, RemoteException
+from monarch.controller.debugger import read as debugger_read, write as debugger_write
+from monarch.rust_local_mesh import _get_worker_exec_info
+from pyre_extensions import none_throws
+
+logger: Logger = logging.getLogger(__name__)
+
+
+class Controller(_Controller):
+    def __init__(self, workers: "HyProcMesh") -> None:
+        super().__init__()
+        # Buffer for messages unrelated to debugging that are received while a
+        # debugger session is active.
+        self._non_debugger_pending_messages: deque[
+            Optional[client.LogMessage | client.WorkerResponse]
+        ] = deque()
+        self._pending_debugger_sessions: deque[ActorId] = deque()
+
+    def next_message(
+        self, timeout: Optional[float]
+    ) -> Optional[LogMessage | MessageResult]:
+        if self._non_debugger_pending_messages:
+            msg = self._non_debugger_pending_messages.popleft()
+        else:
+            msg = self._get_next_message(timeout_msec=int((timeout or 0.0) * 1000.0))
+        if msg is None:
+            return None
+
+        if isinstance(msg, client.WorkerResponse):
+            return _worker_response_to_result(msg)
+        elif isinstance(msg, client.LogMessage):
+            return LogMessage(msg.level, msg.message)
+        elif isinstance(msg, client.DebuggerMessage):
+            self._run_debugger_loop(msg)
+
+    def send(
+        self,
+        ranks: Union[NDSlice, List[NDSlice]],
+        msg: NamedTuple,
+    ) -> None:
+        with torch.utils._python_dispatch._disable_current_modes():
+            return super().send(ranks, msg)
+
+    def drain_and_stop(
+        self,
+    ) -> List[LogMessage | MessageResult | client.DebuggerMessage]:
+        self._drain_and_stop()
+        return []
+
+    def _run_debugger_loop(self, message: client.DebuggerMessage) -> None:
+        if not isinstance(message.action, DebuggerAction.Paused):
+            raise RuntimeError(
+                f"Unexpected debugger message {message} when no debugger session is running"
+            )
+
+        self._pending_debugger_sessions.append(message.debugger_actor_id)
+        while self._pending_debugger_sessions:
+            debugger_actor_id = self._pending_debugger_sessions.popleft()
+            rank = debugger_actor_id.rank
+            proc_id = debugger_actor_id.proc_id
+            debugger_write(
+                f"pdb attached to proc {proc_id} with rank {rank}, debugger actor {debugger_actor_id} \n"
+            )
+
+            self._debugger_attach(debugger_actor_id)
+            while True:
+                # TODO: Add appropriate timeout.
+                msg = self._get_next_message(timeout_msec=None)
+
+                if not isinstance(msg, client.DebuggerMessage):
+                    self._non_debugger_pending_messages.append(msg)
+                    continue
+
+                if msg.debugger_actor_id != debugger_actor_id:
+                    if isinstance(msg.action, DebuggerAction.Paused):
+                        self._pending_debugger_sessions.append(msg.debugger_actor_id)
+                        continue
+                    else:
+                        raise RuntimeError(
+                            f"unexpected debugger message {msg} from rank {msg.debugger_actor_id.rank} "
+                            f"when debugging rank {debugger_actor_id.rank}"
+                        )
+
+                action = msg.action
+                if isinstance(action, DebuggerAction.Detach):
+                    break
+                elif isinstance(action, DebuggerAction.Read):
+                    self._debugger_write(
+                        debugger_actor_id, debugger_read(action.requested_size)
+                    )
+                elif isinstance(action, DebuggerAction.Write):
+                    debugger_write(
+                        debugger.get_bytes_from_write_action(action).decode()
+                    )
+                else:
+                    raise RuntimeError(
+                        f"unexpected debugger message {msg} when debugging rank {debugger_actor_id.rank}"
+                    )
+
+    def worker_world_state(self) -> WorldState:
+        raise NotImplementedError("worker world state")
+
+    def stop_mesh(self):
+        # I think this is a noop?
+
+        pass
+
+
+# TODO: Handling conversion of the response can move to a separate module over time
+# especially as we have structured error messages.
+def _worker_response_to_result(result: client.WorkerResponse) -> MessageResult:
+    if not result.is_exception():
+        # The result of the message needs to be unwrapped on a real device.
+        # Staying as a fake tensor will fail the tensor deserialization.
+        with no_mesh.activate():
+            return MessageResult(result.seq, result.result(), None)
+    exc = none_throws(result.exception())
+    if isinstance(exc, client.Error):
+        worker_frames = [
+            traceback.FrameSummary("<unknown>", None, frame)
+            for frame in exc.backtrace.split("\\n")
+        ]
+        return MessageResult(
+            seq=result.seq,
+            result=None,
+            error=RemoteException(
+                seq=exc.caused_by_seq,
+                exception=RuntimeError(exc.backtrace),
+                controller_frame_index=0,  # TODO: T225205291 fix this once we have recording support in rust
+                controller_frames=None,
+                worker_frames=worker_frames,
+                source_actor_id=exc.actor_id,
+                message=f"Remote function in {exc.actor_id} errored.",
+            ),
+        )
+    elif isinstance(exc, client.Failure):
+        frames = [
+            traceback.FrameSummary("<unknown>", None, frame)
+            for frame in exc.backtrace.split("\n")
+        ]
+        reason = f"Actor {exc.actor_id} crashed on {exc.address}, check the host log for details"
+        logger.error(reason)
+        return MessageResult(
+            seq=0,  # seq is not consumed for DeviceException; it will be directly thrown by the client
+            result=None,
+            error=DeviceException(
+                exception=RuntimeError(reason),
+                frames=frames,
+                source_actor_id=exc.actor_id,
+                message=reason,
+            ),
+        )
+    else:
+        raise RuntimeError(f"Unknown exception type: {type(exc)}")
+
+
+def _initialize_env(worker_point: Point, proc_id: str) -> None:
+    worker_rank = worker_point.rank
+    try:
+        _, worker_env = _get_worker_exec_info()
+        local_rank = worker_point["gpus"]
+        gpus_per_host = worker_point.size("gpus")
+        num_worker_procs = len(worker_point.shape)
+        process_env = {
+            **worker_env,
+            "HYPERACTOR_MANAGED_SUBPROCESS": "1",
+            "CUDA_VISIBLE_DEVICES": str(local_rank),
+            "NCCL_HOSTID": f"{proc_id}_host_{worker_rank // gpus_per_host}",
+            # This is needed to avoid a hard failure in ncclx when we do not
+            # have backend topology info (eg. on RE).
+            "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
+            "LOCAL_RANK": str(local_rank),
+            "RANK": str(worker_rank),
+            "WORLD_SIZE": str(num_worker_procs),
+            "LOCAL_WORLD_SIZE": str(gpus_per_host),
+        }
+        os.environ.update(process_env)
+    except Exception:
+        traceback.print_exc()
+        raise
+
+
+class MeshClient(Client):
+    def shutdown(
+        self,
+        destroy_pg: bool = True,
+        error_reason: Optional[RemoteException | DeviceException | Exception] = None,
+    ):
+        # return
+        if self.has_shutdown:
+            return
+        logger.info("shutting down the client gracefully")
+
+        atexit.unregister(self._atexit)
+        self._shutdown = True
+
+        # ensure all pending work is finished.
+        # all errors must be messaged back at this point
+        self.new_node_nocoalesce([], [], None, [])
+        self._request_status()
+
+        ttl = 60
+        start_time = time.time()
+        end_time = start_time + ttl
+        while ttl > 0 and self.last_assigned_seq > self.last_processed_seq:
+            ttl = end_time - time.time()
+            self.handle_next_message(ttl)
+            if self._pending_shutdown_error:
+                raise self._pending_shutdown_error
+
+        if ttl <= 0:
+            raise RuntimeError("shutdown timed out")
+
+        # we are not expecting anything more now, because we already
+        # waited for the responses
+        self.inner.drain_and_stop()
+
+
+def spawn_tensor_engine(proc_mesh: "ProcMesh") -> DeviceMesh:
+    # This argument to Controller
+    # is currently only used for debug printing. It should be fixed to
+    # report the proc ID instead of the rank it currently does.
+    gpus = proc_mesh.sizes.get("gpus", 1)
+    backend_ctrl = Controller(proc_mesh._proc_mesh)
+    client = MeshClient(backend_ctrl, proc_mesh.size(), gpus)
+    dm = DeviceMesh(
+        client,
+        NDSlice.new_row_major(list(proc_mesh.sizes.values())),
+        tuple(proc_mesh.sizes.keys()),
+    )
+    dm.exit = lambda: client.shutdown()
+    return dm
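spawn_tensor_engine is the bridge from an actor-style ProcMesh to the tensor engine's DeviceMesh: it wraps the Rust _Controller in a MeshClient and returns a DeviceMesh whose exit() shuts that client down. A rough sketch of the intended call pattern, hedged: `procs` is assumed to be a ProcMesh already allocated via monarch.proc_mesh, which this snippet does not show:

    from monarch.mesh_controller import spawn_tensor_engine

    # `procs` is a previously allocated monarch.proc_mesh.ProcMesh (allocation not shown).
    dm = spawn_tensor_engine(procs)
    try:
        with dm.activate():
            pass  # tensor-engine operations issued here run on the proc mesh's workers
    finally:
        dm.exit()  # shuts down the MeshClient created above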
monarch/monarch_controller
Binary file