torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,44 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from typing import NamedTuple
|
9
|
+
|
10
|
+
from monarch_supervisor import get_message_queue
|
11
|
+
|
12
|
+
|
13
|
+
class Reply(NamedTuple):
    """Three-integer record that ``reply_hello`` sends over the message queue."""

    # Fields are filled positionally from reply_hello's (a, b, x) arguments.
    a: int
    b: int
    x: int
|
17
|
+
|
18
|
+
|
19
|
+
def reply_hello(a, b, x):
    """Send ``Reply(a, b, x)`` on the queue returned by ``get_message_queue()``."""
    queue = get_message_queue()
    reply = Reply(a, b, x)
    queue.send(reply)
|
22
|
+
|
23
|
+
|
24
|
+
def echo():
    """Echo received messages back until the string ``"exit"`` arrives.

    Each message must equal the number of messages echoed so far
    (0, 1, 2, ...); anything else trips the assertion.
    """
    queue = get_message_queue()
    expected = 0
    while True:
        # recv() yields a (sender, message) pair; the sender is not used here.
        _sender, message = queue.recv()
        if message == "exit":
            return
        assert message == expected
        queue.send(message)
        expected += 1
|
34
|
+
|
35
|
+
|
36
|
+
class Mapper:
    """Tiny map/reduce/finish helper used for testing."""

    def map(self, items):
        """Return the sum of every item multiplied by two."""
        total = 0
        for item in items:
            total = total + item * 2
        return total

    def reduce(self, items):
        """Return the sum of ``items``."""
        total = 0
        for item in items:
            total = total + item
        return total

    def finish(self, result):
        """Return ``result`` unchanged."""
        return result
|
@@ -0,0 +1,30 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import importlib.util
|
9
|
+
import sys
|
10
|
+
|
11
|
+
from monarch_supervisor import _FunctionCall, get_message_queue
|
12
|
+
|
13
|
+
if __name__ == "__main__":
    # Child-side entry point: receive one pickled _FunctionCall from the
    # message queue, resolve its target and invoke it.
    q = get_message_queue()
    _, call = q.recv()
    assert isinstance(call, _FunctionCall)
    # The target is either "pkg.module.func" (importable module path) or
    # "path/to/file.py:func" (function loaded straight from a source file).
    filename, *rest = call.target.split(":", 1)
    if not rest:
        modulename, funcname = filename.rsplit(".", 1)
        module = importlib.import_module(modulename)
    else:
        spec = importlib.util.spec_from_file_location("__entry__", filename)
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        # Register the module *before* executing it, as the importlib recipe
        # for loading a source file does: code that runs during exec_module
        # and looks up sys.modules["__entry__"] must find the module there.
        sys.modules["__entry__"] = module
        # pyre-ignore[16]
        spec.loader.exec_module(module)
        funcname = rest[0]
    func = getattr(module, funcname)
    func(*call.args, **call.kwargs)
|
@@ -0,0 +1,386 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import ctypes
|
9
|
+
import io
|
10
|
+
import logging
|
11
|
+
import os
|
12
|
+
import signal
|
13
|
+
import socket
|
14
|
+
import subprocess
|
15
|
+
import sys
|
16
|
+
import time
|
17
|
+
import traceback
|
18
|
+
import uuid
|
19
|
+
from contextlib import nullcontext
|
20
|
+
from pathlib import Path
|
21
|
+
from random import random
|
22
|
+
from string import Template
|
23
|
+
from typing import Any, Callable, Dict, List, Mapping, Optional
|
24
|
+
|
25
|
+
import zmq
|
26
|
+
from monarch_supervisor import (
|
27
|
+
_FunctionCall,
|
28
|
+
HEARTBEAT_INTERVAL,
|
29
|
+
HEARTBEAT_LIVENESS,
|
30
|
+
pickle_dumps,
|
31
|
+
pickle_loads,
|
32
|
+
ProcessFailedToStart,
|
33
|
+
)
|
34
|
+
from monarch_supervisor.logging import gethostname, initialize_logging
|
35
|
+
from monarch_supervisor.python_executable import PYTHON_EXECUTABLE
|
36
|
+
|
37
|
+
logger: logging.Logger = logging.getLogger(__name__)
# Time budget, in seconds, that Host.shutdown() uses when waiting for
# SIGTERM'd processes to exit.
ABORT_INTERVAL = 5
# Seconds between runs of the monarch_supervisor.log_pstree helper.
LOG_PSTREE_INTERVAL: int = 60 * 10
# Syscall number handed to syscall(2) for pidfd_open.
__NR_pidfd_open = 434
libc = ctypes.CDLL(None)


# Older libc versions do not expose pidfd_open, so it is invoked by number
# through the generic syscall() entry point.
def pidfd_open(pid: int) -> int:
    """Return the result of the ``pidfd_open`` syscall for ``pid`` (flags=0).

    NOTE(review): the raw syscall result is returned unchecked, so a failure
    (presumably -1) is passed straight to the caller.
    """
    flags = 0
    return libc.syscall(__NR_pidfd_open, pid, flags)
|
47
|
+
|
48
|
+
|
49
|
+
# objects in this file represent Host/Process
|
50
|
+
# on the host machine itself.
|
51
|
+
|
52
|
+
# main package has Host/Process used by
|
53
|
+
# the supervisor.
|
54
|
+
|
55
|
+
|
56
|
+
class Process:
    """A subprocess launched on this host, plus the channel used to message it.

    Messages sent before the child has been heard from are buffered in
    ``deferred_sends`` and flushed by ``_notify_connected``.
    """

    def __init__(
        self,
        name: str,
        logfilename: Optional[str],
        proc_comm: zmq.Socket,
        proc_id: int,
        rank: int,
        processes_per_host: int,
        world_size: int,
        popen: Mapping[str, Any],
        proc_addr: str,
        start_new_session: bool,
    ) -> None:
        """Start the subprocess described by ``popen``.

        ``popen`` supplies ``subprocess.Popen`` keyword arguments and must
        contain ``"env"`` (a mapping or None) and ``"args"``. When ``"args"``
        is a ``_FunctionCall`` it is pickled and queued for the child, and the
        child runs ``monarch_supervisor.function_call`` instead. ``name`` is
        accepted but not used here.

        Raises:
            ProcessFailedToStart: if opening the log file or spawning the
                subprocess fails; the message is the formatted traceback.
        """
        self.proc_id = proc_id
        self.proc_comm = proc_comm
        # A list means "not connected yet, buffer"; None means "send directly".
        self.deferred_sends: Optional[List[bytes]] = []
        rank_vars = {
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_RANK": str(rank % processes_per_host),
            "LOCAL_WORLD_SIZE": str(processes_per_host),
            "SUPERVISOR_PIPE": proc_addr,
            "SUPERVISOR_IDENT": str(proc_id),
        }

        child_env = dict(os.environ)
        requested_env = popen["env"]
        if requested_env is not None:
            # Caller-supplied values may reference $RANK, $WORLD_SIZE, ...;
            # unknown placeholders are left untouched by safe_substitute.
            child_env.update(
                {
                    key: Template(value).safe_substitute(rank_vars)
                    for key, value in requested_env.items()
                }
            )
        command = popen["args"]
        if isinstance(command, _FunctionCall):
            # proc_id_bytes is not set yet; _send only buffers at this point.
            self._send(pickle_dumps(command))
            command = [PYTHON_EXECUTABLE, "-m", "monarch_supervisor.function_call"]

        # The rank variables always win over inherited or caller-supplied values.
        child_env.update(rank_vars)
        popen_kwargs = {**popen, "env": child_env, "args": command}
        try:
            if logfilename is None:
                log_target = nullcontext()
            else:
                try:
                    log_target = open(logfilename, "a")
                except FileNotFoundError:
                    # Missing parent directory: create it and retry once.
                    Path(logfilename).parent.mkdir(exist_ok=True, parents=True)
                    log_target = open(logfilename, "a")
            with log_target as logfile:
                self.subprocess: subprocess.Popen[str] = subprocess.Popen(
                    **popen_kwargs,
                    start_new_session=start_new_session,
                    stdout=logfile,
                    stderr=logfile,
                )
        except Exception:
            details = traceback.format_exc()
            logger.warning(f"Process failed to start: {details}\n")
            raise ProcessFailedToStart(details)
        self.proc_id_bytes: bytes = proc_id.to_bytes(8, byteorder="little")

    def _send(self, msg: bytes) -> None:
        """Send ``msg`` to the child, or buffer it while sends are deferred."""
        pending = self.deferred_sends
        if pending is None:
            self.proc_comm.send_multipart([self.proc_id_bytes, msg])
            return
        pending.append(msg)

    def _notify_connected(self) -> None:
        """Flush buffered messages and switch to direct sends."""
        pending = self.deferred_sends
        if pending is None:
            return
        for msg in pending:
            self.proc_comm.send_multipart([self.proc_id_bytes, msg])
        self.deferred_sends = None
|
132
|
+
|
133
|
+
|
134
|
+
class Host:
    """
    Represents a host (Host Manager) that can be supervised.
    Starts an event loop listening for commands from the supervisor, including launching/killing processes.

    Supervisor commands arrive as pickled ``(method_name, *args)`` tuples and
    are dispatched to the method of that name on this object.
    """

    def __init__(self, supervisor_port: str, start_new_session: bool = True) -> None:
        """Connect to the supervisor at ``supervisor_port`` and set up polling.

        ``start_new_session`` is forwarded to every launched ``Process`` and
        selects whether processes are signalled via ``os.killpg`` or
        ``os.kill``.
        """
        self.context: zmq.Context = zmq.Context(1)
        self.supervisor_comm: zmq.Socket = self._socket(zmq.DEALER)
        self.supervisor_comm.setsockopt(zmq.IPV6, True)
        logger.info("Connecting to %s", supervisor_port)
        self.supervisor_comm.connect(supervisor_port)

        # tell the supervisor we exist, and provide
        # hostname for debugging.
        self.supervisor_comm.send(
            pickle_dumps(("_hostname", None, socket.gethostname()))
        )

        self.poller = zmq.Poller()
        self.poller.register(self.supervisor_comm, zmq.POLLIN)

        # optional way to send messages to processes.
        # all processes on this host will use the same
        # socket.
        self.proc_comm: zmq.Socket = self._socket(zmq.ROUTER)

        self.proc_addr = f"ipc:///tmp/proc-{uuid.uuid4()}"
        self.proc_comm.bind(self.proc_addr)
        self.poller.register(self.proc_comm, zmq.POLLIN)

        # proc_id (8 little-endian bytes) -> live Process
        self.process_table: Dict[bytes, Process] = {}
        # pidfd -> callback to run when that fd becomes readable
        self.fd_to_on_exit: Dict[int, Callable[[], None]] = {}
        self._launches = 0
        self.has_shutdown = False
        # pickled "_exited" messages waiting to be forwarded to the supervisor
        self.exits: List[bytes] = []
        self.start_new_session = start_new_session

    def _socket(self, kind: int) -> zmq.Socket:
        """Create a socket of ``kind`` with send/receive high-water marks set to 0."""
        sock = self.context.socket(kind)
        sock.setsockopt(zmq.SNDHWM, 0)
        sock.setsockopt(zmq.RCVHWM, 0)
        return sock

    def heartbeat(self) -> None:
        """Send an empty message to the supervisor."""
        self.supervisor_comm.send(b"")

    # TODO: validate these are valid messages to send

    def launch(
        self,
        proc_id: int,
        rank: int,
        processes_per_rank: int,
        world_size: int,
        popen: Mapping[str, object],
        name: str,
        simulate: bool,
        log_file: Optional[str],
    ) -> None:
        """Start a process and report ``("_started", proc_id, reply)``.

        ``reply`` is the child's pid on success, or the failure text when
        ``Process`` raises ``ProcessFailedToStart``. With ``simulate`` set,
        nothing is started: a ``_started`` message (with the value 2) and an
        ``_exited`` message (with 0) are sent immediately.
        """
        self._launches += 1
        if simulate:
            self.supervisor_comm.send(pickle_dumps(("_started", proc_id, 2)))
            self.supervisor_comm.send(pickle_dumps(("_exited", proc_id, 0)))
            return
        try:
            logger.info(f"starting new process proc_id: {proc_id}")
            process = Process(
                name,
                log_file,
                self.proc_comm,
                proc_id,
                rank,
                processes_per_rank,
                world_size,
                popen,
                self.proc_addr,
                self.start_new_session,
            )
            self.process_table[process.proc_id_bytes] = process
            self.on_subprocess_exit(
                process.subprocess, lambda: self.process_exit(process)
            )
            reply = process.subprocess.pid
        except ProcessFailedToStart as e:
            reply = str(e)
        self.supervisor_comm.send(pickle_dumps(("_started", proc_id, reply)))

    def process_exit(self, process: Process) -> None:
        """Handle the exit of ``process``: drop it, reap it, queue ``_exited``."""
        self.process_table.pop(process.proc_id_bytes)
        # we do not allow descendents to outlive the parent
        # if any remain this kill will clean them up
        self.kill(process.subprocess.pid, signal.SIGKILL)
        returncode = process.subprocess.wait()
        if not self.has_shutdown:
            self.exits.append(pickle_dumps(("_exited", process.proc_id, returncode)))

    def kill(self, pid: int, sig: int) -> None:
        """Send ``sig`` to ``pid``: via ``os.killpg`` when ``start_new_session`` is set, else ``os.kill``."""
        if self.start_new_session:
            os.killpg(pid, sig)
        else:
            os.kill(pid, sig)

    def on_subprocess_exit(
        self, subprocess: subprocess.Popen, on_exit: Callable[[], Any]
    ) -> None:
        """Register ``on_exit`` to run when the pidfd of ``subprocess`` becomes readable."""
        fd: int = pidfd_open(subprocess.pid)
        self.fd_to_on_exit[fd] = on_exit
        self.poller.register(fd, zmq.POLLIN)

    def send(self, _proc_id: int, msg: bytes) -> None:
        """Forward ``msg`` to the process with id ``_proc_id``; unknown ids are ignored."""
        proc_id = _proc_id.to_bytes(8, byteorder="little")
        if proc_id in self.process_table:
            process = self.process_table[proc_id]
            process._send(msg)

    def signal(self, _proc_id: int, sig: int, group: bool) -> None:
        """Signal the process with id ``_proc_id``; unknown ids are ignored.

        The whole process group is signalled only when both ``group`` and
        ``start_new_session`` are true.
        """
        proc_id = _proc_id.to_bytes(8, byteorder="little")
        if proc_id in self.process_table:
            process = self.process_table[proc_id]
            if group and self.start_new_session:
                os.killpg(process.subprocess.pid, sig)
            else:
                process.subprocess.send_signal(sig)

    def _fd_exit(self, fd: int) -> None:
        """Unregister and close a ready pidfd, then run its exit callback."""
        on_exit = self.fd_to_on_exit.pop(fd)
        self.poller.unregister(fd)
        os.close(fd)
        on_exit()

    def shutdown(self) -> None:
        """Terminate all processes and close the sockets; idempotent.

        Sends SIGTERM to every process, waits up to ``ABORT_INTERVAL`` seconds
        for them to exit, then SIGKILLs whatever is left.
        """
        if self.has_shutdown:
            return
        self.has_shutdown = True
        for proc in self.process_table.values():
            self.kill(proc.subprocess.pid, signal.SIGTERM)
        # From here on only process exits matter. The sockets are closed
        # below without being read again, so stop polling them: a pending
        # message would otherwise make every poll() return immediately and
        # turn the grace period into a busy loop.
        self.poller.unregister(self.supervisor_comm)
        self.poller.unregister(self.proc_comm)
        expiry = time.time() + ABORT_INTERVAL
        ttl = ABORT_INTERVAL
        while ttl > 0 and self.process_table:
            for s, _ in self.poller.poll(timeout=int(1000 * ttl)):
                if isinstance(s, int):
                    self._fd_exit(s)
            # Remaining grace time. (This used to be `time.time() - expiry`,
            # which is negative and ended the wait after a single poll.)
            ttl = expiry - time.time()
        if self.process_table:
            for proc in self.process_table.values():
                self.kill(proc.subprocess.pid, signal.SIGKILL)

        self.proc_comm.close(linger=0)
        self.supervisor_comm.close(linger=0)
        self.context.term()

    def abort(self, with_error: Optional[str] = None) -> None:
        """Shut down, then leave the process.

        Raises:
            ConnectionAbortedError: when ``with_error`` is given.
            SystemExit: with code 0 otherwise.
        """
        self.shutdown()
        if with_error:
            logger.error("exiting with error: %s", with_error)
            raise ConnectionAbortedError(with_error)
        else:
            logger.warning("exiting cleanly.")
            sys.exit(0)

    def run_event_loop_forever(self) -> None:
        """Poll the supervisor socket, the process socket and pidfds forever.

        Returns only by raising (for example through ``abort``).
        """
        log_pstree_info_at = time.time() + LOG_PSTREE_INTERVAL
        # None until the first message from the supervisor arrives.
        supervisor_expiry = None
        heartbeat_at = 0
        while True:
            # Block indefinitely until connected; afterwards wake up in time
            # for the next heartbeat.
            timeout = (
                -1
                if supervisor_expiry is None
                else int(max(1000 * (heartbeat_at - time.time()) + 1, 0))
            )
            proc_comm_processed = False
            for s, _ in self.poller.poll(timeout=timeout):
                if isinstance(s, int):
                    # we register a file descriptor to the poller, which is a raw
                    # int file description that becomes ready when the a subprocess exits
                    # see pidfd_open.
                    self._fd_exit(s)
                elif s is self.supervisor_comm:
                    if supervisor_expiry is None:
                        logger.info("Connected to supervisor")
                        # first heartbeat is set to between 0 to HEARTBEAT_INTERVAL
                        # to spread out the heartbeats from hosts that all start
                        # at the same time.
                        heartbeat_at = time.time() + HEARTBEAT_INTERVAL * random()

                    supervisor_expiry = (
                        time.time() + HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS
                    )
                    # pyre-ignore[29]
                    msg = self.supervisor_comm.recv()
                    if msg:
                        # NOTE(security): the message is unpickled and its
                        # first element is used as a method name on this
                        # object without validation, so the supervisor
                        # connection must be trusted.
                        cmd, *args = pickle_loads(msg)
                        getattr(self, cmd)(*args)
                elif s is self.proc_comm:
                    proc_comm_processed = True
                    proc_id_bytes, msg = self.proc_comm.recv_multipart()
                    process = self.process_table.get(proc_id_bytes)
                    # it is possible for the process to have already exited before
                    # we get its messages, so process_table will be empty
                    if process is not None:
                        process._notify_connected()
                    if len(msg):
                        proc_id = int.from_bytes(proc_id_bytes, byteorder="little")
                        self.supervisor_comm.send(
                            pickle_dumps(("_response", proc_id, msg))
                        )
            # Exit notifications are held back while process messages were
            # handled in this iteration, so they are forwarded only after a
            # pass in which no process message was read.
            if not proc_comm_processed and self.exits:
                for exit in self.exits:
                    self.supervisor_comm.send(exit)
                self.exits.clear()

            if supervisor_expiry is not None:
                t = time.time()
                if t > heartbeat_at:
                    heartbeat_at = t + HEARTBEAT_INTERVAL
                    self.heartbeat()
                if t > supervisor_expiry:
                    self.abort(
                        f"No messages from supervisor for {HEARTBEAT_INTERVAL*HEARTBEAT_LIVENESS} seconds, aborting."
                    )
                if t > log_pstree_info_at:
                    log_pstree = subprocess.Popen(
                        [
                            os.getenv("FB_XAR_INVOKED_NAME", default=sys.executable),
                            "-m",
                            "monarch_supervisor.log_pstree",
                            str(os.getpid()),
                        ]
                    )
                    # Reap the helper when it exits so it does not linger.
                    self.on_subprocess_exit(log_pstree, log_pstree.wait)
                    log_pstree_info_at = t + LOG_PSTREE_INTERVAL
|
366
|
+
|
367
|
+
|
368
|
+
def main(addr: str) -> None:
    """Run a host manager connected to the supervisor at ``addr``.

    SIGINT and SIGTERM shut the manager down and exit with status 1. The
    manager is also shut down whenever the event loop is left.
    """
    manager: Host = Host(addr)

    def handler(signum: int, _frame: object) -> None:
        manager.shutdown()
        sys.exit(1)

    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, handler)

    try:
        manager.run_event_loop_forever()
    finally:
        manager.shutdown()
|
381
|
+
|
382
|
+
|
383
|
+
if __name__ == "__main__":
    # Exactly one command-line argument is expected: the supervisor address.
    (addr,) = sys.argv[1:]
    log_prefix = f"{gethostname()} pid {os.getpid()} host-manager"
    initialize_logging(log_prefix)
    main(addr)
|
@@ -0,0 +1,145 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-strict
|
8
|
+
|
9
|
+
import io
|
10
|
+
import json
|
11
|
+
import logging
|
12
|
+
import os
|
13
|
+
import signal
|
14
|
+
import socket
|
15
|
+
import subprocess
|
16
|
+
import sys
|
17
|
+
import time
|
18
|
+
import traceback
|
19
|
+
from typing import Callable, Optional, Sequence, Tuple
|
20
|
+
|
21
|
+
from . import Context, Host
|
22
|
+
|
23
|
+
from .host import main
|
24
|
+
from .logging import gethostname, initialize_logging
|
25
|
+
from .python_executable import PYTHON_EXECUTABLE
|
26
|
+
|
27
|
+
# Default port leveraging one from the reserved range for torchelastic
PORT: str = os.environ.get("SUPERVISOR_PORT", "29401")

logger: logging.Logger = logging.getLogger(__name__)


# Error codes written to the reply file.
NON_RETRYABLE_FAILURE: int = 100
JOB_RESTART_SCOPE_ESCALATION: int = 101
# Names of the environment variables that hold the metadata key / file path.
TW_USER_METADATA_HOSTNAMES_LIST_KEY: str = "TW_USER_METADATA_HOSTNAMES_LIST_KEY"
TW_USER_METADATA_FILE_PATH: str = "TW_USER_METADATA_FILE_PATH"


def _write_reply_file(msg: str, reply_file: Optional[str] = None) -> None:
    """Write a JSON failure reply carrying ``JOB_RESTART_SCOPE_ESCALATION``.

    Args:
        msg: Text stored under ``"message"``.
        reply_file: Destination path; defaults to the
            ``MAST_HPC_TASK_FAILURE_REPLY_FILE`` environment variable.

    Raises:
        KeyError: if a required environment variable is missing
            (``MAST_HPC_JOB_ATTEMPT_INDEX`` is always read).
    """
    if reply_file is None:
        reply_file = os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]
    job_attempt = int(os.environ["MAST_HPC_JOB_ATTEMPT_INDEX"])
    logger.info(
        f"Supervisor writing a reply file with JOB_RESTART_SCOPE_ESCALATION to {reply_file} (attempt {job_attempt})."
    )
    with open(reply_file, "w") as f:
        timestamp_ns = time.time_ns()
        error_data = {
            "message": msg,
            "errorCode": JOB_RESTART_SCOPE_ESCALATION,
            # Integer floor division keeps the conversion exact; dividing by
            # the floats 1e9 / 1e3 would round the nanosecond count first.
            "timestamp": timestamp_ns // 1_000_000_000,
            "timestamp_us": timestamp_ns // 1_000,
        }
        json.dump(error_data, f)
|
55
|
+
|
56
|
+
|
57
|
+
def mast(supervise: Callable[[Context, Sequence[Host]], None]) -> None:
    """
    Entrypoint for starting the supervisor when running on MAST.

    Each host should call `mast(supervise)` where `supervise` is the
    supervisor policy function for the job. The host whose name matches the
    first entry of `get_hostnames()` acts as the supervisor machine: it
    starts a local host manager subprocess, creates a `Context` on `PORT`,
    requests one host per task-group hostname and calls
    `supervise(ctx, hosts)`. Every other host runs a host manager connected
    to the supervisor.

    On the supervisor a "deadman's switch" reply file is written at startup
    and removed only after a fully successful run. If `supervise` raises, the
    reply file is rewritten with the exception and the error is re-raised.
    """
    hostnames = get_hostnames()
    supervisor_host = hostnames[0]
    my_host_name = (os.environ.get("HOSTNAME") or socket.getfqdn()).removesuffix(
        ".facebook.com"
    )
    # The first host in the task group is the supervisor.
    is_supervisor = my_host_name == supervisor_host
    if is_supervisor:
        log_name = "supervisor"
    else:
        log_name = f"{gethostname()} host-manager"
    initialize_logging(log_name)

    supervisor_addr = f"tcp://{socket.getfqdn(supervisor_host)}:{PORT}"
    logger.info(
        "hostname %s, supervisor host is %s, supervisor=%s",
        my_host_name,
        supervisor_host,
        is_supervisor,
    )

    if not is_supervisor:
        # host manager on non-supervisor machine
        main(supervisor_addr)
        return

    _write_reply_file(
        "Supervisor deadman's switch. "
        "This reply file is written when the supervisor starts and deleted right before a successful exit. "
        "It is used to cause the whole job to restart if for some reason the "
        "supervisor host is unscheduled without it throwing an exception itself."
    )
    # local host manager on supervisor machine
    host_process = subprocess.Popen(
        [PYTHON_EXECUTABLE, "-m", "monarch_supervisor.host", supervisor_addr]
    )
    try:
        ctx = Context(port=int(PORT))
        hosts: Tuple[Host, ...] = ctx.request_hosts(n=len(hostnames))
        supervise(ctx, hosts)
        ctx.shutdown()
        logger.info("Supervisor shutdown complete.")
    except BaseException:
        ty, e, st = sys.exc_info()
        stack_text = "".join(traceback.format_tb(st))
        _write_reply_file(
            f"{ty.__name__ if ty is not None else 'None'}: {str(e)}\n{stack_text}"
        )
        host_process.send_signal(signal.SIGINT)
        raise

    return_code = host_process.wait(timeout=10)
    if return_code == 0:
        # successful exit, so we remove the deadman's switch reply file we wrote earlier.
        os.unlink(os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"])
        return
    # Host manager may have been instructed to write a reply file, so
    # we do not write a reply file here which would clobber it.
    logger.warning(
        f"Host manager on supervisor returned non-zero code: {return_code}."
    )
    sys.exit(return_code)
|
131
|
+
|
132
|
+
|
133
|
+
def get_hostnames() -> Sequence[str]:
    """
    Get the list of hostnames for the current task group.

    When both environment variables named by `TW_USER_METADATA_FILE_PATH`
    and `TW_USER_METADATA_HOSTNAMES_LIST_KEY` are set and non-empty, the
    comma-separated list is read from `userAttributes[<key>]` in that JSON
    file. Otherwise it comes from `MAST_HPC_TASK_GROUP_HOSTNAMES`.
    """
    metadata_path = os.environ.get(TW_USER_METADATA_FILE_PATH)
    hostnames_key = os.environ.get(TW_USER_METADATA_HOSTNAMES_LIST_KEY)
    if not (metadata_path and hostnames_key):
        return os.environ["MAST_HPC_TASK_GROUP_HOSTNAMES"].split(",")

    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    return metadata["userAttributes"][hostnames_key].split(",")
|
@@ -0,0 +1,48 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-strict
|
8
|
+
import logging
|
9
|
+
import subprocess
|
10
|
+
import sys
|
11
|
+
from typing import Optional
|
12
|
+
|
13
|
+
from monarch_supervisor.logging import gethostname, initialize_logging
|
14
|
+
|
15
|
+
pid: str
logger: logging.Logger = logging.getLogger(__name__)


def extract_pss(pid: str) -> Optional[str]:
    """Return the ``Pss:`` value and unit from ``/proc/<pid>/smaps_rollup``.

    The result is the second and third whitespace-separated fields of the
    first line starting with ``Pss:``, joined by a space. Returns None when
    there is no such line or the file cannot be read; all errors are
    swallowed because this is best-effort diagnostics.
    """
    try:
        with open(f"/proc/{pid}/smaps_rollup", "r") as f:
            for line in f:
                if not line.startswith("Pss:"):
                    continue
                fields = line.split()
                return " ".join(fields[1:3])
    except Exception:
        pass
    return None
|
28
|
+
|
29
|
+
|
30
|
+
def log_pstree_output(pid: int) -> None:
    """Log the ``pstree -Tap`` output for ``pid``, one line per log record.

    Each logged line is followed by the value ``extract_pss`` returns for the
    pid parsed from that line, or ``None`` when no pid can be parsed.

    Raises:
        subprocess.CalledProcessError: if ``pstree`` exits non-zero.
        FileNotFoundError: if ``pstree`` is not installed.
    """
    # errors="replace": command lines may contain bytes that are not valid
    # UTF-8, and a diagnostic dump should not fail on them.
    pstree_output = subprocess.check_output(["pstree", "-Tap", str(pid)]).decode(
        "utf-8", errors="replace"
    )
    logger.info("Process Info")
    for line in pstree_output.split("\n"):
        if not line.strip():
            continue
        # The pid is taken as the first token after the first comma. Lines
        # without a comma, or with nothing after it, are still logged instead
        # of raising IndexError.
        parts = line.split(",")
        pid_fields = parts[1].split() if len(parts) > 1 else []
        mem = extract_pss(pid_fields[0]) if pid_fields else None
        logger.info(f"{line} {mem}")
|
43
|
+
|
44
|
+
|
45
|
+
if __name__ == "__main__":
    # Exactly one command-line argument is expected: the pid to inspect.
    [pid] = sys.argv[1:]
    log_prefix = f"{gethostname()} host-manager"
    initialize_logging(log_prefix)
    log_pstree_output(int(pid))
|