torchmonarch-nightly 2025.6.27__cp313-cp313-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ from typing import NamedTuple
9
+
10
+ from monarch_supervisor import get_message_queue
11
+
12
+
13
class Reply(NamedTuple):
    """Message sent back over the message queue by ``reply_hello``."""

    a: int  # first argument of reply_hello, echoed back
    b: int  # second argument of reply_hello, echoed back
    x: int  # third argument of reply_hello, echoed back
17
+
18
+
19
def reply_hello(a, b, x):
    """Send ``Reply(a, b, x)`` on this process's message queue."""
    get_message_queue().send(Reply(a, b, x))
22
+
23
+
24
def echo():
    """Send every received message straight back until the string "exit" arrives.

    The n-th message received (counting from 0) must equal n; otherwise an
    AssertionError is raised.
    """
    q = get_message_queue()
    expected = 0
    while True:
        _, msg = q.recv()
        if msg == "exit":
            return
        assert msg == expected
        q.send(msg)
        expected += 1
34
+
35
+
36
class Mapper:
    """Trivial map/reduce/finish helper used by supervisor tests."""

    def map(self, items):
        # Double every element, then add them up.
        doubled = (value * 2 for value in items)
        return sum(doubled)

    def reduce(self, items):
        # Add up the partial results.
        total = sum(items)
        return total

    def finish(self, result):
        # Identity: the reduced value is the final answer.
        return result
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import importlib.util
9
+ import sys
10
+
11
+ from monarch_supervisor import _FunctionCall, get_message_queue
12
+
13
if __name__ == "__main__":
    # Entry point for children launched from a _FunctionCall: the call object
    # is received on this process's message queue and its target is invoked.
    q = get_message_queue()
    _, call = q.recv()
    assert isinstance(call, _FunctionCall)
    # The target is either "package.module.func" or "path/to/file.py:func".
    filename, *rest = call.target.split(":", 1)
    if not rest:
        modulename, funcname = filename.rsplit(".", 1)
        module = importlib.import_module(modulename)
    else:
        spec = importlib.util.spec_from_file_location("__entry__", filename)
        assert spec is not None and spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        # Register before executing, so that code which looks the module up by
        # name while it is being imported (e.g. dataclasses resolving string
        # annotations, pickle) can find it.
        sys.modules["__entry__"] = module
        # pyre-ignore[16]
        spec.loader.exec_module(module)
        funcname = rest[0]
    func = getattr(module, funcname)
    func(*call.args, **call.kwargs)
@@ -0,0 +1,386 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-unsafe
8
+ import ctypes
9
+ import io
10
+ import logging
11
+ import os
12
+ import signal
13
+ import socket
14
+ import subprocess
15
+ import sys
16
+ import time
17
+ import traceback
18
+ import uuid
19
+ from contextlib import nullcontext
20
+ from pathlib import Path
21
+ from random import random
22
+ from string import Template
23
+ from typing import Any, Callable, Dict, List, Mapping, Optional
24
+
25
+ import zmq
26
+ from monarch_supervisor import (
27
+ _FunctionCall,
28
+ HEARTBEAT_INTERVAL,
29
+ HEARTBEAT_LIVENESS,
30
+ pickle_dumps,
31
+ pickle_loads,
32
+ ProcessFailedToStart,
33
+ )
34
+ from monarch_supervisor.logging import gethostname, initialize_logging
35
+ from monarch_supervisor.python_executable import PYTHON_EXECUTABLE
36
+
37
logger: logging.Logger = logging.getLogger(__name__)

# Seconds Host.shutdown waits after SIGTERM before sending SIGKILL.
ABORT_INTERVAL = 5
# Seconds between periodic process-tree dumps (ten minutes).
LOG_PSTREE_INTERVAL: int = 10 * 60
# Linux syscall number of pidfd_open(2), invoked through libc.syscall.
__NR_pidfd_open = 434
# Symbols already loaded into this process (dlopen(NULL)); used for libc.syscall.
libc = ctypes.CDLL(None)
42
+
43
+
44
# Raw syscall wrapper, because older libc builds do not expose pidfd_open().
def pidfd_open(pid: int) -> int:
    """Return a pidfd for ``pid`` that becomes readable when that process exits.

    Raises:
        OSError: if the syscall fails (for example on kernels without
            pidfd_open, or if ``pid`` no longer exists). Previously the raw
            ``-1`` return value was handed back and later registered with the
            poller.
    """
    fd = libc.syscall(__NR_pidfd_open, pid, 0)
    if fd < 0:
        raise OSError(f"pidfd_open({pid}) failed")
    return fd
47
+
48
+
49
+ # objects in this file represent Host/Process
50
+ # on the host machine itself.
51
+
52
+ # main package has Host/Process used by
53
+ # the supervisor.
54
+
55
+
56
class Process:
    """A child process started on this host for the supervisor.

    The child inherits this host's environment plus rank/world-size variables
    and the address of the host's shared process socket (``SUPERVISOR_PIPE``).
    Messages addressed to it are buffered in ``deferred_sends`` until
    ``_notify_connected`` is called, then sent on ``proc_comm`` prefixed with
    the child's 8-byte id.

    Raises:
        ProcessFailedToStart: if opening the log file or spawning the child fails.
    """

    def __init__(
        self,
        name: str,
        logfilename: Optional[str],
        proc_comm: zmq.Socket,
        proc_id: int,
        rank: int,
        processes_per_host: int,
        world_size: int,
        popen: Mapping[str, Any],
        proc_addr: str,
        start_new_session: bool,
    ) -> None:
        self.proc_id = proc_id
        self.proc_comm = proc_comm
        # Buffered messages; set to None by _notify_connected once flushed.
        self.deferred_sends: Optional[List[bytes]] = []

        local_config = dict(
            RANK=str(rank),
            WORLD_SIZE=str(world_size),
            LOCAL_RANK=str(rank % processes_per_host),
            LOCAL_WORLD_SIZE=str(processes_per_host),
            SUPERVISOR_PIPE=proc_addr,
            SUPERVISOR_IDENT=str(proc_id),
        )

        environ = dict(os.environ)
        user_env = popen["env"]
        if user_env is not None:
            # User values may reference $RANK, $LOCAL_RANK, ...; unknown
            # placeholders are left untouched by safe_substitute.
            for key, value in user_env.items():
                environ[key] = Template(value).safe_substitute(local_config)

        command = popen["args"]
        if isinstance(command, _FunctionCall):
            # The call is queued for delivery over proc_comm; the child runs the
            # generic function_call entry point, which receives it.
            self._send(pickle_dumps(command))
            command = [PYTHON_EXECUTABLE, "-m", "monarch_supervisor.function_call"]

        environ.update(local_config)
        launch_kwargs = {**popen, "env": environ, "args": command}
        try:
            if logfilename is None:
                log_ctx = nullcontext()
            else:
                try:
                    log_ctx = open(logfilename, "a")
                except FileNotFoundError:
                    # Create the log directory on demand, then retry once.
                    Path(logfilename).parent.mkdir(exist_ok=True, parents=True)
                    log_ctx = open(logfilename, "a")
            # stdout/stderr of the child go to the log file (or are inherited
            # when there is none); our handle is closed once Popen returns.
            with log_ctx as log_file:
                self.subprocess: subprocess.Popen[str] = subprocess.Popen(
                    **launch_kwargs,
                    start_new_session=start_new_session,
                    stdout=log_file,
                    stderr=log_file,
                )
        except Exception:
            details = traceback.format_exc()
            logger.warning(f"Process failed to start: {details}\n")
            raise ProcessFailedToStart(details)
        self.proc_id_bytes: bytes = proc_id.to_bytes(8, byteorder="little")

    def _send(self, msg: bytes) -> None:
        """Send ``msg`` to the child, or buffer it until the child connects."""
        pending = self.deferred_sends
        if pending is None:
            self.proc_comm.send_multipart([self.proc_id_bytes, msg])
            return
        pending.append(msg)

    def _notify_connected(self) -> None:
        """Flush buffered messages; later sends go out immediately."""
        pending = self.deferred_sends
        if pending is None:
            return
        for msg in pending:
            self.proc_comm.send_multipart([self.proc_id_bytes, msg])
        self.deferred_sends = None
132
+
133
+
134
class Host:
    """
    Represents a host (Host Manager) that can be supervised.
    Starts an event loop listening for commands from the supervisor, including launching/killing processes.

    Commands arrive on ``supervisor_comm`` as pickled ``(method_name, *args)``
    tuples and are dispatched with ``getattr(self, method_name)(*args)``.
    Child processes talk to the host over one shared ROUTER socket
    (``proc_comm``). Their exits are detected through pidfds registered in the
    same poller.
    """

    def __init__(self, supervisor_port: str, start_new_session: bool = True) -> None:
        """Connect to the supervisor and bind the per-host process socket.

        Args:
            supervisor_port: ZMQ endpoint of the supervisor (e.g. ``tcp://host:port``).
            start_new_session: start each child in its own session, so that
                ``kill``/``signal`` can target its whole process group.
        """
        self.context: zmq.Context = zmq.Context(1)
        self.supervisor_comm: zmq.Socket = self._socket(zmq.DEALER)
        self.supervisor_comm.setsockopt(zmq.IPV6, True)
        logger.info("Connecting to %s", supervisor_port)
        self.supervisor_comm.connect(supervisor_port)

        # tell the supervisor we exist, and provide
        # hostname for debugging.
        self.supervisor_comm.send(
            pickle_dumps(("_hostname", None, socket.gethostname()))
        )

        self.poller = zmq.Poller()
        self.poller.register(self.supervisor_comm, zmq.POLLIN)

        # optional way to send messages to processes.
        # all processes on this host will use the same
        # socket.
        self.proc_comm: zmq.Socket = self._socket(zmq.ROUTER)

        self.proc_addr = f"ipc:///tmp/proc-{uuid.uuid4()}"
        self.proc_comm.bind(self.proc_addr)
        self.poller.register(self.proc_comm, zmq.POLLIN)

        # 8-byte little-endian proc_id -> live Process
        self.process_table: Dict[bytes, Process] = {}
        # pidfd -> callback to run once that subprocess has exited
        self.fd_to_on_exit: Dict[int, Callable[[], None]] = {}
        self._launches = 0
        self.has_shutdown = False
        # pickled "_exited" notifications waiting to be forwarded to the supervisor
        self.exits: List[bytes] = []
        self.start_new_session = start_new_session

    def _socket(self, kind: int) -> zmq.Socket:
        """Create a ZMQ socket of ``kind`` with send/receive high-water marks set to 0 (no limit)."""
        sock = self.context.socket(kind)
        sock.setsockopt(zmq.SNDHWM, 0)
        sock.setsockopt(zmq.RCVHWM, 0)
        return sock

    def heartbeat(self) -> None:
        """Send an empty message to the supervisor as a liveness ping."""
        self.supervisor_comm.send(b"")

    # TODO: validate these are valid messages to send

    def launch(
        self,
        proc_id: int,
        rank: int,
        processes_per_rank: int,
        world_size: int,
        popen: Mapping[str, object],
        name: str,
        simulate: bool,
        log_file: Optional[str],
    ) -> None:
        """Start a child process and report ``("_started", proc_id, reply)``.

        ``reply`` is the child's pid, or the error text if it failed to start.
        With ``simulate`` set nothing is spawned: the supervisor gets a start
        reply of ``2`` followed immediately by an exit code of ``0``.
        ``processes_per_rank`` is passed to Process as ``processes_per_host``.
        """
        self._launches += 1
        if simulate:
            self.supervisor_comm.send(pickle_dumps(("_started", proc_id, 2)))
            self.supervisor_comm.send(pickle_dumps(("_exited", proc_id, 0)))
            return
        try:
            logger.info(f"starting new process proc_id: {proc_id}")
            process = Process(
                name,
                log_file,
                self.proc_comm,
                proc_id,
                rank,
                processes_per_rank,
                world_size,
                popen,
                self.proc_addr,
                self.start_new_session,
            )
            self.process_table[process.proc_id_bytes] = process
            self.on_subprocess_exit(
                process.subprocess, lambda: self.process_exit(process)
            )
            reply = process.subprocess.pid
        except ProcessFailedToStart as e:
            reply = str(e)
        self.supervisor_comm.send(pickle_dumps(("_started", proc_id, reply)))

    def process_exit(self, process: Process) -> None:
        """Reap an exited child and queue an ``_exited`` notice for the supervisor.

        Runs as the child's pidfd callback. No notice is queued once the host
        has started shutting down.
        """
        self.process_table.pop(process.proc_id_bytes)
        # we do not allow descendants to outlive the parent;
        # if any remain, this kill will clean them up (the whole process group
        # is only targeted when start_new_session is set, see kill()).
        self.kill(process.subprocess.pid, signal.SIGKILL)
        returncode = process.subprocess.wait()
        if not self.has_shutdown:
            self.exits.append(pickle_dumps(("_exited", process.proc_id, returncode)))

    def kill(self, pid: int, sig: int) -> None:
        """Send ``sig`` to the process group led by ``pid`` (new-session mode) or to ``pid`` alone."""
        if self.start_new_session:
            os.killpg(pid, sig)
        else:
            os.kill(pid, sig)

    def on_subprocess_exit(
        self, subprocess: subprocess.Popen, on_exit: Callable[[], Any]
    ) -> None:
        """Run ``on_exit`` from the event loop once ``subprocess`` has exited."""
        fd: int = pidfd_open(subprocess.pid)
        self.fd_to_on_exit[fd] = on_exit
        self.poller.register(fd, zmq.POLLIN)

    def send(self, _proc_id: int, msg: bytes) -> None:
        """Forward ``msg`` to a live child; messages for unknown ids are dropped."""
        proc_id = _proc_id.to_bytes(8, byteorder="little")
        if proc_id in self.process_table:
            process = self.process_table[proc_id]
            process._send(msg)

    def signal(self, _proc_id: int, sig: int, group: bool) -> None:
        """Deliver ``sig`` to a live child, to its whole group if ``group`` and new-session mode."""
        proc_id = _proc_id.to_bytes(8, byteorder="little")
        if proc_id in self.process_table:
            process = self.process_table[proc_id]
            if group and self.start_new_session:
                os.killpg(process.subprocess.pid, sig)
            else:
                process.subprocess.send_signal(sig)

    def _fd_exit(self, fd: int) -> None:
        """Handle a ready pidfd: stop polling it, close it, then run its exit callback."""
        on_exit = self.fd_to_on_exit.pop(fd)
        self.poller.unregister(fd)
        os.close(fd)
        on_exit()

    def shutdown(self) -> None:
        """Stop all children, then close the sockets and terminate the ZMQ context.

        Every child gets SIGTERM. Exits are processed for up to ABORT_INTERVAL
        seconds, and any child still running after that gets SIGKILL. Repeated
        calls are no-ops.
        """
        if self.has_shutdown:
            return
        self.has_shutdown = True
        # From here on only pidfd readiness matters, and both sockets are closed
        # below. Leaving them registered would make poll() return immediately
        # whenever a message is pending, turning the wait below into a busy spin.
        self.poller.unregister(self.supervisor_comm)
        self.poller.unregister(self.proc_comm)
        for proc in self.process_table.values():
            self.kill(proc.subprocess.pid, signal.SIGTERM)
        expiry = time.time() + ABORT_INTERVAL
        ttl = ABORT_INTERVAL
        while ttl > 0 and self.process_table:
            for s, _ in self.poller.poll(timeout=int(1000 * ttl)):
                if isinstance(s, int):
                    self._fd_exit(s)
            # Time remaining in the grace period.
            ttl = expiry - time.time()
        if self.process_table:
            for proc in self.process_table.values():
                self.kill(proc.subprocess.pid, signal.SIGKILL)

        self.proc_comm.close(linger=0)
        self.supervisor_comm.close(linger=0)
        self.context.term()

    def abort(self, with_error: Optional[str] = None) -> None:
        """Shut down, then raise ConnectionAbortedError(with_error), or exit 0 if no error."""
        self.shutdown()
        if with_error:
            logger.error("exiting with error: %s", with_error)
            raise ConnectionAbortedError(with_error)
        else:
            logger.warning("exiting cleanly.")
            sys.exit(0)

    def run_event_loop_forever(self) -> None:
        """Serve supervisor commands, child messages and child exits until aborted.

        Until the first supervisor message, the poll blocks indefinitely. After
        that:
        - heartbeats are sent every HEARTBEAT_INTERVAL seconds;
        - the host aborts if the supervisor is silent for
          HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS seconds;
        - a helper subprocess logs the process tree every LOG_PSTREE_INTERVAL
          seconds.
        """
        log_pstree_info_at = time.time() + LOG_PSTREE_INTERVAL
        # None until the supervisor is first heard from; afterwards, the
        # deadline by which its next message must arrive.
        supervisor_expiry = None
        heartbeat_at = 0
        while True:
            # Block until connected; afterwards wake in time for the next heartbeat.
            timeout = (
                -1
                if supervisor_expiry is None
                else int(max(1000 * (heartbeat_at - time.time()) + 1, 0))
            )
            proc_comm_processed = False
            for s, _ in self.poller.poll(timeout=timeout):
                if isinstance(s, int):
                    # A raw int is a pidfd registered by on_subprocess_exit; it
                    # becomes readable when its subprocess exits (see pidfd_open).
                    self._fd_exit(s)
                elif s is self.supervisor_comm:
                    if supervisor_expiry is None:
                        logger.info("Connected to supervisor")
                        # first heartbeat is set to between 0 to HEARTBEAT_INTERVAL
                        # to spread out the heartbeats from hosts that all start
                        # at the same time.
                        heartbeat_at = time.time() + HEARTBEAT_INTERVAL * random()

                    # Any message from the supervisor, even an empty one, extends its deadline.
                    supervisor_expiry = (
                        time.time() + HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS
                    )
                    # pyre-ignore[29]
                    msg = self.supervisor_comm.recv()
                    if msg:
                        # NOTE(review): commands are unpickled and dispatched by
                        # method name, so the supervisor endpoint must be trusted.
                        cmd, *args = pickle_loads(msg)
                        getattr(self, cmd)(*args)
                elif s is self.proc_comm:
                    proc_comm_processed = True
                    proc_id_bytes, msg = self.proc_comm.recv_multipart()
                    process = self.process_table.get(proc_id_bytes)
                    # it is possible for the process to have already exited before
                    # we get its messages, so process_table will be empty
                    if process is not None:
                        process._notify_connected()
                    if len(msg):
                        proc_id = int.from_bytes(proc_id_bytes, byteorder="little")
                        self.supervisor_comm.send(
                            pickle_dumps(("_response", proc_id, msg))
                        )
            # Exit notices are forwarded only on iterations with no child
            # traffic, presumably so a child's last messages reach the
            # supervisor before its "_exited" notice.
            if not proc_comm_processed and self.exits:
                for exit_msg in self.exits:
                    self.supervisor_comm.send(exit_msg)
                self.exits.clear()

            if supervisor_expiry is not None:
                t = time.time()
                if t > heartbeat_at:
                    heartbeat_at = t + HEARTBEAT_INTERVAL
                    self.heartbeat()
                if t > supervisor_expiry:
                    self.abort(
                        f"No messages from supervisor for {HEARTBEAT_INTERVAL*HEARTBEAT_LIVENESS} seconds, aborting."
                    )
                if t > log_pstree_info_at:
                    log_pstree = subprocess.Popen(
                        [
                            os.getenv("FB_XAR_INVOKED_NAME", default=sys.executable),
                            "-m",
                            "monarch_supervisor.log_pstree",
                            str(os.getpid()),
                        ]
                    )
                    # Reap the helper once it finishes.
                    self.on_subprocess_exit(log_pstree, log_pstree.wait)
                    log_pstree_info_at = t + LOG_PSTREE_INTERVAL
366
+
367
+
368
def main(addr: str) -> None:
    """Run a host manager connected to the supervisor at ``addr`` until it stops.

    SIGINT or SIGTERM shuts the host down and exits with status 1. Any other
    way out of the event loop also shuts it down.
    """
    host_manager: Host = Host(addr)

    def _on_signal(signum: int, frame: object) -> None:
        host_manager.shutdown()
        sys.exit(1)

    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _on_signal)

    try:
        host_manager.run_event_loop_forever()
    finally:
        host_manager.shutdown()
381
+
382
+
383
if __name__ == "__main__":
    # Invoked as: python -m monarch_supervisor.host <supervisor-address>
    [addr] = sys.argv[1:]
    initialize_logging(f"{gethostname()} pid {os.getpid()} host-manager")
    main(addr)
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+
9
+ import io
10
+ import json
11
+ import logging
12
+ import os
13
+ import signal
14
+ import socket
15
+ import subprocess
16
+ import sys
17
+ import time
18
+ import traceback
19
+ from typing import Callable, Optional, Sequence, Tuple
20
+
21
+ from . import Context, Host
22
+
23
+ from .host import main
24
+ from .logging import gethostname, initialize_logging
25
+ from .python_executable import PYTHON_EXECUTABLE
26
+
27
# Default port leveraging one from the reserved range for torchelastic;
# can be overridden through the SUPERVISOR_PORT environment variable.
PORT: str = os.environ.get("SUPERVISOR_PORT", "29401")

logger: logging.Logger = logging.getLogger(__name__)


# Error codes. NOTE(review): NON_RETRYABLE_FAILURE is not referenced in this
# module; presumably it is used by callers.
NON_RETRYABLE_FAILURE: int = 100
JOB_RESTART_SCOPE_ESCALATION: int = 101  # "errorCode" written by _write_reply_file
# Names of the environment variables read by get_hostnames().
TW_USER_METADATA_HOSTNAMES_LIST_KEY: str = "TW_USER_METADATA_HOSTNAMES_LIST_KEY"
TW_USER_METADATA_FILE_PATH: str = "TW_USER_METADATA_FILE_PATH"
37
+
38
+
39
def _write_reply_file(msg: str, reply_file: Optional[str] = None) -> None:
    """Write a MAST task-failure reply file that requests a job-scope restart.

    Args:
        msg: Human-readable text stored under ``"message"``.
        reply_file: Destination path. Defaults to
            ``$MAST_HPC_TASK_FAILURE_REPLY_FILE``.

    Raises:
        KeyError: if a required ``MAST_HPC_*`` environment variable is unset.
        ValueError: if ``MAST_HPC_JOB_ATTEMPT_INDEX`` is not an integer.
    """
    if reply_file is None:
        reply_file = os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]
    job_attempt = int(os.environ["MAST_HPC_JOB_ATTEMPT_INDEX"])
    logger.info(
        f"Supervisor writing a reply file with JOB_RESTART_SCOPE_ESCALATION to {reply_file} (attempt {job_attempt})."
    )
    # Integer floor division keeps full precision. Dividing by the floats
    # 1e9/1e3 first converts the nanosecond count (> 2**53) to a float, which
    # can make the result off by one unit.
    timestamp_ns = time.time_ns()
    error_data = {
        "message": msg,
        "errorCode": JOB_RESTART_SCOPE_ESCALATION,
        "timestamp": timestamp_ns // 1_000_000_000,
        "timestamp_us": timestamp_ns // 1_000,
    }
    with open(reply_file, "w") as f:
        json.dump(error_data, f)
55
+
56
+
57
def mast(supervise: Callable[[Context, Sequence[Host]], None]) -> None:
    """
    Entrypoint for starting the supervisor when running on MAST.

    Every host in the task group calls ``mast(supervise)``, where ``supervise``
    is the supervisor policy function for the job. The host whose name matches
    the first entry of ``get_hostnames()`` acts as the supervisor. Every other
    host runs a host manager that connects to it at
    ``tcp://<first host>:PORT``.

    On the supervisor host this function:
      1. writes a "deadman's switch" reply file, so the whole job restarts if
         the supervisor host disappears without raising an error itself;
      2. starts a local host manager subprocess;
      3. creates a supervisor ``Context`` on ``PORT``, requests one host per
         hostname in the task group, and calls ``supervise(ctx, hosts)``;
      4. on an exception, overwrites the reply file with the error and its
         traceback, sends SIGINT to the local host manager, and re-raises;
      5. otherwise waits (up to 10 seconds) for the local host manager. On a
         zero exit code it deletes the reply file. On a non-zero code it
         leaves the file in place and exits with that code.
    """

    hostnames = get_hostnames()
    N = len(hostnames)
    my_host_name = (os.environ.get("HOSTNAME") or socket.getfqdn()).removesuffix(
        ".facebook.com"
    )
    # Get first host in the task group
    is_supervisor = my_host_name == hostnames[0]
    initialize_logging(
        "supervisor" if is_supervisor else f"{gethostname()} host-manager"
    )

    supervisor_addr = f"tcp://{socket.getfqdn(hostnames[0])}:{PORT}"
    logger.info(
        "hostname %s, supervisor host is %s, supervisor=%s",
        my_host_name,
        hostnames[0],
        is_supervisor,
    )

    if is_supervisor:
        _write_reply_file(
            "Supervisor deadman's switch. "
            "This reply file is written when the supervisor starts and deleted right before a successful exit. "
            "It is used to cause the whole job to restart if for some reason the "
            "supervisor host is unscheduled without it throwing an exception itself."
        )
        # local host manager on supervisor machine
        host_process = subprocess.Popen(
            [PYTHON_EXECUTABLE, "-m", "monarch_supervisor.host", supervisor_addr]
        )
        try:
            ctx = Context(port=int(PORT))
            hosts: Tuple[Host, ...] = ctx.request_hosts(n=N)
            supervise(ctx, hosts)
            ctx.shutdown()
            logger.info("Supervisor shutdown complete.")
        except BaseException as exc:
            # Replace the deadman's-switch message with the actual failure
            # (type, message and traceback) before propagating it.
            tb_text = "".join(traceback.format_tb(exc.__traceback__))
            _write_reply_file(f"{type(exc).__name__}: {exc}\n{tb_text}")
            host_process.send_signal(signal.SIGINT)
            raise
        return_code = host_process.wait(timeout=10)
        if return_code != 0:
            # Host manager may have been instructed to write a reply file, so
            # we do not write a reply file here which would clobber it.
            logger.warning(
                f"Host manager on supervisor returned non-zero code: {return_code}."
            )
            sys.exit(return_code)
        else:
            # successful exit, so we remove the deadman's switch reply file we wrote earlier.
            reply_file = os.environ["MAST_HPC_TASK_FAILURE_REPLY_FILE"]
            os.unlink(reply_file)
    else:
        # host manager on non-supervisor machine
        main(supervisor_addr)
131
+
132
+
133
def get_hostnames() -> Sequence[str]:
    """
    Return the hostnames of the current task group, in the order given.

    If both the TW user-metadata file path and the hostnames-list key are set
    in the environment, the comma-separated list is read from
    ``userAttributes[<key>]`` in that JSON file. Otherwise it is taken from
    ``$MAST_HPC_TASK_GROUP_HOSTNAMES``.
    """
    metadata_path = os.environ.get(TW_USER_METADATA_FILE_PATH)
    list_key = os.environ.get(TW_USER_METADATA_HOSTNAMES_LIST_KEY)
    if not (metadata_path and list_key):
        return os.environ["MAST_HPC_TASK_GROUP_HOSTNAMES"].split(",")
    with open(metadata_path, "r") as f:
        metadata = json.load(f)
    return metadata["userAttributes"][list_key].split(",")
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # pyre-strict
8
+ import logging
9
+ import subprocess
10
+ import sys
11
+ from typing import Optional
12
+
13
+ from monarch_supervisor.logging import gethostname, initialize_logging
14
+
15
# Root PID, read from argv by the __main__ block; only declared at import time.
pid: str
logger: logging.Logger = logging.getLogger(name=__name__)
17
+
18
+
19
def extract_pss(pid: str) -> Optional[str]:
    """Return the ``Pss:`` value of ``/proc/<pid>/smaps_rollup``, or None.

    The value and unit fields of the line are joined by a space (typically
    ``"<n> kB"``). This is best-effort: None is returned when the line is
    missing or the file cannot be opened or decoded. For example, the process
    may have exited since ``pstree`` listed it, or the kernel may not provide
    ``smaps_rollup``.
    """
    try:
        with open(f"/proc/{pid}/smaps_rollup", "r") as f:
            # Stream the file instead of materializing it with readlines().
            for line in f:
                if line.startswith("Pss:"):
                    return " ".join(line.split()[1:3])
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: bad pid string
        # (e.g. embedded NUL) or a decode error.
        pass
    return None
28
+
29
+
30
def log_pstree_output(pid: int) -> None:
    """Log the process tree rooted at ``pid``, each line followed by its PSS.

    Runs ``pstree -Tap <pid>`` (threads hidden, arguments and PIDs shown). Each
    non-blank line is logged with the result of ``extract_pss`` for the PID on
    that line, which is ``None`` when unavailable.

    Raises:
        subprocess.CalledProcessError: if ``pstree`` exits non-zero.
        FileNotFoundError: if ``pstree`` is not installed.
    """
    pstree_output = subprocess.check_output(["pstree", "-Tap", str(pid)]).decode(
        "utf-8"
    )
    logger.info("Process Info")
    for line in pstree_output.split("\n"):
        if not line.strip():
            continue
        # Expected shape is "<tree art>name,PID args...", so the PID is the
        # first token after the first comma. Lines without one are logged
        # as-is instead of raising IndexError and aborting the whole dump.
        _, comma, after_comma = line.partition(",")
        tokens = after_comma.split()
        mem = extract_pss(tokens[0]) if comma and tokens else None
        logger.info(f"{line} {mem}")
43
+
44
+
45
if __name__ == "__main__":
    # Invoked as: python -m monarch_supervisor.log_pstree <root-pid>
    [pid] = sys.argv[1:]
    initialize_logging(f"{gethostname()} host-manager")
    log_pstree_output(int(pid))