torchmonarch_nightly-2025.6.4-cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/common/tensor_factory.py
ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import NamedTuple, Tuple
+
+import torch
+
+
+class TensorFactory(NamedTuple):
+    size: Tuple[int, ...]
+    dtype: torch.dtype
+    layout: torch.layout
+    device: torch.device
+
+    @staticmethod
+    def from_tensor(t):
+        return TensorFactory(t.size(), t.dtype, t.layout, t.device)
+
+    def empty(self):
+        return torch.empty(
+            self.size, dtype=self.dtype, layout=self.layout, device=self.device
+        )
+
+    def zeros(self):
+        return torch.full(
+            self.size, 0, dtype=self.dtype, layout=self.layout, device=self.device
+        )
monarch/common/tree.py
ADDED
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import Any, Callable, Protocol, Sequence, Tuple
+
+import torch.utils._pytree as _pytree
+from torch.utils._pytree import (
+    _get_node_type,
+    register_pytree_node,
+    SUPPORTED_NODES,
+    tree_flatten,
+    tree_map,
+    tree_unflatten,
+)
+
+
+def flatten(tree, cond):
+    r, spec = tree_flatten(tree)
+
+    # be careful to not capture values we return in
+    # 'trues'. We do not need them to reconstruct and do not want to
+    # extend their lifetime.
+    trues = []
+    falses = []
+    conds = []
+    for e in r:
+        c = cond(e)
+        (trues if c else falses).append(e)
+        conds.append(c)
+
+    def unflatten(n):
+        n_it = iter(n)
+        falses_it = iter(falses)
+        return tree_unflatten([next(n_it if c else falses_it) for c in conds], spec)
+
+    return trues, unflatten
+
+
+def flattener(tree, cond=None):
+    """
+    Produce a _traceable_ flattener routine from tree. That is, it produces code that can
+    flatten another object shaped the same as tree, but whose structure cannot
+    be introspected because it might be (e.g.) an fx proxy value.
+    """
+    if isinstance(tree, (tuple, list)):
+        flattens = [flattener(t, cond) for t in tree]
+        return lambda obj: [
+            f for i, flatten in enumerate(flattens) for f in flatten(obj[i])
+        ]
+    elif isinstance(tree, dict):
+        keys = tuple(tree.keys())
+        flattens = [flattener(t, cond) for t in tree.values()]
+        return lambda obj: [
+            f for k, flatten in zip(keys, flattens) for f in flatten(obj[k])
+        ]
+    elif _get_node_type(tree) in SUPPORTED_NODES:
+        flatten_fn = SUPPORTED_NODES[_get_node_type(tree)].flatten_fn
+        trees, _ = flatten_fn(tree)
+        flattens = [flattener(t, cond) for t in trees]
+
+        def the_flattener(obj):
+            trees, _ = flatten_fn(obj)
+            return [f for i, flatten in enumerate(flattens) for f in flatten(trees[i])]
+
+        return the_flattener
+    elif cond is None or cond(tree):
+        return lambda obj: [obj]
+    else:
+        return lambda obj: []
monarch/controller/backend.py
ADDED
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+
+import os
+import socket
+
+from abc import ABC, abstractmethod
+from typing import List, NamedTuple, Optional, Sequence, Tuple
+
+from monarch.common import messages
+
+from monarch.common.shape import iter_ranks, Slices as Ranks
+from monarch_supervisor import (
+    Context,
+    FunctionCall,
+    Host,
+    Process,
+    ProcessExited as ProcessExitedMsg,
+)
+from torch.distributed import TCPStore
+
+
+logger = logging.getLogger(__name__)
+
+
+class Backend(ABC):
+    @abstractmethod
+    def send(self, ranks: Ranks, msg) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def recvready(self, timeout: Optional[float]) -> Sequence[Tuple[int, NamedTuple]]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def world_size(self):
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def gpu_per_host(self):
+        raise NotImplementedError()
+
+
+class ProcessBackend(Backend):
+    def __init__(
+        self,
+        ctx: Context,
+        hosts: List[Host],
+        gpu_per_host: int,
+        _processes=None,
+        _store=None,
+    ):
+        self.ctx = ctx
+        self.hosts = hosts
+        self.store = self._create_store() if _store is None else _store
+        self._gpu_per_host = gpu_per_host
+        self.worker_processes = (
+            self._create_pg(ctx, hosts, gpu_per_host, self.store)
+            if _processes is None
+            else _processes
+        )
+        self.exiting = False
+        self.process_to_rank = {p: p.rank for p in self.worker_processes}
+        self.live_processes_per_rank: List[List[Process]] = [
+            [p] for p in self.worker_processes
+        ]
+
+    @property
+    def world_size(self):
+        return len(self.worker_processes)
+
+    @property
+    def gpu_per_host(self) -> int:
+        return self._gpu_per_host
+
+    def send(self, ranks: Ranks, msg) -> None:
+        handler = getattr(self, msg.__class__.__name__, None)
+        if handler is not None:
+            handler(ranks, msg)
+        self._send(ranks, msg)
+
+    def _send(self, ranks: Ranks, msg):
+        # the intent is for this to be optimized as a tree broadcast,
+        # based on whether members of tree nodes overlap with a slice.
+        for rank in iter_ranks(ranks):
+            self.worker_processes[rank].send(msg)
+
+    def CommandGroup(self, ranks: Ranks, msg: messages.CommandGroup):
+        for command in msg.commands:
+            handler = getattr(self, command.__class__.__name__, None)
+            if handler is not None:
+                handler(ranks, command)
+
+    def CreatePipe(self, ranks: Ranks, msg: messages.CreatePipe):
+        pipe_ranks = list(enumerate(iter_ranks(ranks)))
+        for i, rank in pipe_ranks:
+            # In general, pipes on different workers may need to have different behavior.
+            # For example, two data loader pipes operating on the same dataset should
+            # load different shards of the dataset. In order to do this, each pipe process
+            # on the worker needs to know the number of instances of the pipe (e.g. len(pipe_ranks))
+            # and its unique rank among all instances of the pipe (e.g., i).
+            proc = self.worker_processes[rank].host.create_process(
+                FunctionCall(
+                    "monarch.worker.worker.pipe_main",
+                    f"{msg.key}-{rank}",
+                    msg.max_messages,
+                ),
+                env={"CUDA_VISIBLE_DEVICES": ""},
+                name=f"pipe-{rank}",
+            )
+            self.live_processes_per_rank[rank].append(proc)
+            self.process_to_rank[proc] = rank
+
+    def ProcessExited(
+        self, sender: Process, msg: ProcessExitedMsg
+    ) -> List[Tuple[int, NamedTuple]]:
+        return self._process_exited(sender, msg.result)
+
+    def Restarted(
+        self, sender: Process, restarted: messages.Restarted
+    ) -> List[Tuple[int, NamedTuple]]:
+        return self._process_exited(sender, restarted.result)
+
+    def _process_exited(
+        self, sender: Process, result: int | Exception
+    ) -> List[Tuple[int, NamedTuple]]:
+        rank = self.process_to_rank[sender]
+        if result != 0:
+            if not self.exiting or self.worker_processes[rank] is sender:
+                kind = (
+                    "worker"
+                    if self.worker_processes[rank] is sender
+                    else "pipe_process"
+                )
+                raise RuntimeError(f"Unexpected {kind} exit on rank {rank}")
+
+        live_procs = self.live_processes_per_rank[rank]
+        live_procs.remove(sender)
+        if len(live_procs) == 0:
+            return [(rank, ProcessExitedMsg(0))]
+        return []
+
+    def Exit(self, ranks: Ranks, msg: messages.Exit):
+        self.exiting = True
+        for rank in iter_ranks(ranks):
+            # ideally we are more kind to these processes.
+            # but first we need to develop the API for asking them
+            # to suspend, restore, fast forward, rewind, etc.
+            worker = self.worker_processes[rank]
+            for proc in self.live_processes_per_rank[rank]:
+                if worker is not proc:
+                    proc.signal()
+            self.worker_processes[rank].send(msg)
+
+    def recvready(self, timeout: Optional[float]) -> Sequence[Tuple[int, NamedTuple]]:
+        result = []
+        for sender, msg in self.ctx.recvready(timeout):
+            handler = getattr(self, msg.__class__.__name__, None)
+            if handler is not None:
+                result.extend(handler(sender, msg))
+                continue
+            elif isinstance(sender, Process):
+                result.append((sender.rank, msg))
+            else:
+                logger.warning("TODO: ignoring non-worker message: %s %s", sender, msg)
+        return result
+
+    @staticmethod
+    def _create_store():
+        if os.environ.get("INSIDE_RE_WORKER"):
+            hostname = "localhost"
+        else:
+            hostname = socket.gethostname()
+
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
+            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+            sock.bind(("::", 0))
+            port = sock.getsockname()[1]
+            store = TCPStore(
+                hostname,
+                port,
+                is_master=True,
+                use_libuv=False,
+                master_listen_fd=sock.detach(),
+            )
+        return store
+
+    @staticmethod
+    def _create_pg(
+        ctx: Context, hosts: List[Host], gpu_per_host: int, store, _restartable=False
+    ):
+        env = {
+            # cuda event cache disabled pending fix for:
+            # https://github.com/pytorch/pytorch/issues/143470
+            "TORCH_NCCL_CUDA_EVENT_CACHE": "0",
+            # disable nonblocking comm until D68727854 lands.
+            "TORCH_NCCL_USE_COMM_NONBLOCKING": "0",
+            # supervisor_pipe is a unique ID per Host object,
+            # so it lets us put multiple processes on the same GPU.
+            "NCCL_HOSTID": "$SUPERVISOR_PIPE",
+            "STORE_HOSTNAME": store.host,
+            "STORE_PORT": str(store.port),
+        }
+        for name, value in os.environ.items():
+            if name.startswith("NCCL_") and name not in env:
+                env[name] = value
+        return ctx.create_process_group(
+            hosts,
+            FunctionCall(
+                "monarch.worker.worker.worker_main", _restartable=_restartable
+            ),
+            processes_per_host=gpu_per_host,
+            env=env,
+            name="worker",
+        )
monarch/controller/controller.py
ADDED
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+import traceback
+from collections import deque
+from typing import Generator, List, NamedTuple, Optional, Sequence, Tuple, Union
+
+from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+    DebuggerMessage,
+    WorldState,
+)
+
+from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+    ActorId,
+)
+
+from monarch.common import messages
+from monarch.common.controller_api import LogMessage, MessageResult
+from monarch.common.invocation import DeviceException, Seq
+from monarch.common.reference import Ref
+from monarch.common.shape import NDSlice
+from monarch.common.tensor import Tensor
+from monarch.controller import debugger
+
+from .backend import Backend
+from .history import History
+
+logger = logging.getLogger(__name__)
+
+
+class Controller:
+    def __init__(self, backend: Backend):
+        self._backend = backend
+        self._history = History(backend.world_size)
+        self._messages = deque()
+
+        self.exited = {}
+        self.active_debugger: Optional[Tuple[int, int]] = None
+        self.pending_debugger_sessions: deque[Tuple[int, int]] = deque()
+        # for current active session
+        self.pending_debugger_messages: deque[messages.DebuggerMessage] = deque()
+
+    def send(
+        self,
+        ranks: Union[NDSlice, List[NDSlice]],
+        msg: NamedTuple,
+    ) -> None:
+        self._backend.send(ranks, msg)
+
+    def next_message(
+        self, timeout: Optional[float]
+    ) -> Optional[MessageResult | LogMessage]:
+        if len(self._messages) == 0:
+            self._messages.extend(self._read_messages(timeout))
+        return self._messages.popleft() if len(self._messages) > 0 else None
+
+    def drop_refs(self, refs: Sequence[Ref]) -> None:
+        """
+        No-op; the Rust controller uses this to know when to gc
+        invocations_for_ref for failed invocations.
+        """
+        pass
+
+    def _read_messages(
+        self, timeout: Optional[float]
+    ) -> Generator[MessageResult, None, None]:
+        # XXX - how can we avoid always requesting status when waiting on futures?
+        # we need to figure out what submesh we need to hear from before a future
+        # is considered 'good'. This means not just waiting for the future value
+        # but also for a signal that any failures that could invalidate the future have
+        # not happened. We could do better if tensors/collectives had an invalid bit
+        # that we propagate. In real uses fetches might lag behind anyway so we would not
+        # have to send out so many requests for current status.
+        for rank, value in self._backend.recvready(timeout):
+            yield from self._handle_message(rank, value)
+
+    def drain_and_stop(self) -> List[MessageResult | LogMessage | DebuggerMessage]:
+        messages = []
+        while self._messages:
+            messages.append(self._messages.popleft())
+        while len(self.exited) < self._backend.world_size:
+            messages.extend(self._read_messages(None))
+        return messages
+
+    def stop_mesh(self) -> None:
+        pass
+
+    def node(
+        self,
+        seq: Seq,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+    ) -> None:
+        self._history.ident(seq, defs, uses)
+
+    def _handle_message(self, sender, value) -> Generator[MessageResult, None, None]:
+        yield from getattr(self, value.__class__.__name__)(sender, *value)
+
+    def worker_world_state(self) -> WorldState:
+        # Even though this is not implemented, the return is needed so the
+        # return value complies with type checking.
+        assert 1 == 2, "not implemented"
+        return WorldState()
+
+    def ProcessExited(self, proc, result) -> Generator[MessageResult, None, None]:
+        if result != 0:
+            # XXX - this should start the failure recovery process
+            raise RuntimeError("Unexpected worker process exit")
+        self.exited[proc] = result
+        yield from []
+
+    def ProcessStarted(self, proc, pid) -> Generator[MessageResult, None, None]:
+        yield from []
+
+    def FetchResult(self, proc, ident, value) -> Generator[MessageResult, None, None]:
+        self._history.future_completed(ident, value)
+        yield from []
+
+    def RemoteFunctionFailed(
+        self,
+        proc,
+        failing_ident,
+        traceback_index,
+        exception: Exception,
+        worker_frames: List[traceback.FrameSummary],
+    ) -> Generator[MessageResult, None, None]:
+        self._history.propagate_failure(
+            failing_ident, traceback_index, exception, worker_frames
+        )
+        yield from self._history.rank_completed(proc, failing_ident)
+
+    def InternalException(
+        self,
+        proc,
+        exception: Exception,
+        worker_frames: List[traceback.FrameSummary],
+    ) -> Generator[MessageResult, None, None]:
+        yield MessageResult(
+            seq=0,  # will not be used
+            result=None,
+            error=DeviceException(
+                exception,
+                worker_frames,
+                ActorId.from_string("unknown[0].unknown[0]"),
+                message="A worker experienced an internal error.",
+            ),
+        )
+
+    def RemoteGeneratorFailed(
+        self,
+        proc,
+        exception: Exception,
+        frames: List[traceback.FrameSummary],
+    ) -> Generator[MessageResult, None, None]:
+        yield MessageResult(
+            seq=0,  # will not be used
+            result=None,
+            error=DeviceException(
+                exception=exception,
+                frames=frames,
+                source_actor_id=ActorId.from_string("unknown[0].unknown[0]"),
+                message="A remote generator failed.",
+            ),
+        )
+
+    def Status(
+        self, proc, first_uncompleted_ident
+    ) -> Generator[MessageResult, None, None]:
+        yield from self._history.rank_completed(proc, first_uncompleted_ident)
+
+    def DebuggerMessage(
+        self, proc, stream_id: int, action
+    ) -> Generator[MessageResult, None, None]:
+        if action == "paused":
+            self.pending_debugger_sessions.append((proc, stream_id))
+        else:
+            assert self.active_debugger == (proc, stream_id)
+            self.pending_debugger_messages.append(action)
+
+        if self.active_debugger is None:
+            yield from self._run_debugger_loop()
+
+    def _run_debugger_loop(self) -> Generator[MessageResult, None, None]:
+        # debug loop
+        while self.pending_debugger_sessions:
+            yield from self._run_debugger_session(
+                *self.pending_debugger_sessions.popleft()
+            )
+
+    def _run_debugger_session(
+        self, proc_id: int, stream_id: int
+    ) -> Generator[MessageResult, None, None]:
+        debugger.write(f"pdb attached to rank {proc_id}, stream {stream_id}\n")
+        self.active_debugger = (proc_id, stream_id)
+        try:
+            rank = NDSlice(offset=proc_id, sizes=[], strides=[])
+            self.send(rank, messages.DebuggerMessage(stream_id, "attach"))
+            while True:
+                while not self.pending_debugger_messages:
+                    # todo: eventually we should timeout
+                    yield from self._read_messages(None)
+                message = self.pending_debugger_messages.popleft()
+                match message:
+                    case "detach":
+                        break
+                    case messages.DebuggerRead(requested):
+                        self.send(
+                            rank,
+                            messages.DebuggerMessage(
+                                stream_id,
+                                messages.DebuggerWrite(debugger.read(requested)),
+                            ),
+                        )
+                    case messages.DebuggerWrite(payload):
+                        debugger.write(payload.decode())
+                    case other:
+                        raise RuntimeError(f"unexpected debugger message: {other}")
+        finally:
+            self.active_debugger = None
+            self.pending_debugger_messages.clear()
monarch/controller/debugger.py
ADDED
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import sys
+from typing import Optional
+
+_is_ipython: Optional[bool] = None
+
+
+def is_ipython() -> bool:
+    global _is_ipython
+    if _is_ipython is not None:
+        return _is_ipython
+    try:
+        from IPython import get_ipython
+
+        _is_ipython = get_ipython() is not None
+    except ImportError:
+        _is_ipython = False
+    return _is_ipython
+
+
+def write(msg: str) -> None:
+    sys.stdout.write(msg)
+    sys.stdout.flush()
+
+
+def read(requested_size: int) -> bytes:
+    if not is_ipython():
+        b = bytearray(requested_size)
+        bytes_read = sys.stdin.buffer.raw.readinto(b)
+        return bytes(b[:bytes_read])
+
+    # ipython doesn't have stdin directly connected
+    # so we need to use input() instead.
+    user_input = input() + "\n"
+    input_bytes = user_input.encode("utf-8")
+    num_bytes_to_write = len(input_bytes)
+    if requested_size < num_bytes_to_write:
+        raise RuntimeError(
+            f"Debugger input line too long, max length is {requested_size}"
+        )
+    return input_bytes[:num_bytes_to_write]
monarch/controller/history.py
ADDED
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from collections import deque
+from typing import Generator, Sequence, TYPE_CHECKING
+
+from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+    ActorId,
+)
+
+from monarch.common.controller_api import MessageResult
+
+from monarch.common.invocation import Invocation, RemoteException, Seq
+
+if TYPE_CHECKING:
+    from monarch.common.tensor import Tensor
+
+
+class History:
+    def __init__(self, N):
+        self.first_uncompleted_ident = [0 for _ in range(N)]
+        self.min_first_uncompleted_ident = 0
+        self.invocations = deque[Invocation]()
+
+    def _invocation(
+        self,
+        seq: Seq,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+    ):
+        r = Invocation(seq)
+        for t in uses:
+            u = t._invocation
+            assert u is not None
+            u.add_user(r)
+        for t in defs:
+            t._invocation = r
+        return r
+
+    def ident(
+        self,
+        seq: Seq,
+        defs: Sequence["Tensor"],
+        uses: Sequence["Tensor"],
+    ):
+        invocation = self._invocation(seq, defs, uses)
+        self.invocations.append(invocation)
+
+    def propagate_failure(self, seq, traceback_index, exception, worker_frames):
+        invocation = self.invocations[seq - self.min_first_uncompleted_ident]
+        remote_exception = RemoteException(
+            seq,
+            exception,
+            traceback_index,
+            None,
+            worker_frames,
+            ActorId.from_string("unknown[0].unknown[0]"),
+        )
+        worklist = deque((invocation,))
+        while worklist:
+            invocation = worklist.popleft()
+            if invocation.fail(remote_exception):
+                worklist.extend(invocation.users)
+
+    def rank_completed(
+        self, rank, first_uncompleted_ident
+    ) -> Generator[MessageResult, None, None]:
+        # advance what our last completed action was, and
+        # trim the list of tracebacks if we no longer need them.
+        prev = self.first_uncompleted_ident[rank]
+        self.first_uncompleted_ident[rank] = first_uncompleted_ident
+        if prev == self.min_first_uncompleted_ident:
+            self.min_first_uncompleted_ident = min(self.first_uncompleted_ident)
+            for seq in range(prev, self.min_first_uncompleted_ident):
+                invocation = self.invocations.popleft()
+                assert seq == invocation.seq
+                result, error = invocation.complete()
+                yield MessageResult(
+                    seq=seq,
+                    result=result,
+                    error=error,
+                )
+
+    def future_completed(self, ident, value):
+        invocation = self.invocations[ident - self.min_first_uncompleted_ident]
+        invocation.fut_value = value