torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,1044 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import io
import logging
import math
import os
import pickle
import signal
import sys
import time
from abc import ABC, abstractmethod
from collections import deque
from enum import Enum
from functools import cache
from logging import Logger
from pathlib import Path
from threading import Thread
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)

import torch

import zmq
import zmq.asyncio


logger: Logger = logging.getLogger(__name__)

T = TypeVar("T")

# multiplier (how many heartbeats do we miss before lost)
HEARTBEAT_LIVENESS = float(os.getenv("TORCH_SUPERVISOR_HEARTBEAT_LIVENESS", "5.0"))
# frequency (in seconds) how often to send heartbeat
HEARTBEAT_INTERVAL = float(os.getenv("TORCH_SUPERVISOR_HEARTBEAT_INTERVAL", "1.0"))
TTL_REPORT_INTERVAL = float(os.getenv("TORCH_SUPERVISOR_TTL_REPORT_INTERVAL", "60"))
LOG_INTERVAL = float(os.getenv("TORCH_SUPERVISOR_LOG_INTERVAL", "60"))
DEFAULT_LOGGER_FORMAT = (
    "%(levelname).1s%(asctime)s.%(msecs)03d000 %(process)d "
    "%(pathname)s:%(lineno)d] supervisor: %(message)s"
)
DEFAULT_LOGGER_DATEFORMAT = "%m%d %H:%M:%S"
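
# Worked example (illustrative, not part of the original file): with the
# defaults above, a connection's liveness window is
# HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS = 1.0 * 5.0 = 5.0 seconds, i.e. a
# host manager is declared lost after five consecutive missed heartbeats.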

_State = Enum("_State", ["UNATTACHED", "ATTACHED", "LOST"])
_UNATTACHED: _State = _State.UNATTACHED
_ATTACHED: _State = _State.ATTACHED
_LOST: _State = _State.LOST


def pickle_loads(*args, **kwargs) -> Any:
    # Ensure that any tensors are loaded onto the CPU by monkeypatching how
    # Storages are loaded.
    old = torch.storage._load_from_bytes
    try:
        torch.storage._load_from_bytes = lambda b: torch.load(
            io.BytesIO(b), map_location="cpu", weights_only=False
        )
        # @lint-ignore PYTHONPICKLEISBAD
        return pickle.loads(*args, **kwargs)
    finally:
        torch.storage._load_from_bytes = old


def pickle_dumps(*args, **kwargs) -> Any:
    # @lint-ignore PYTHONPICKLEISBAD
    return pickle.dumps(*args, **kwargs)
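
# Illustrative sketch (hypothetical helper, not part of the original file):
# round-tripping a tensor through the helpers above always lands it on the
# CPU, because pickle_loads remaps storages with map_location="cpu".
def _example_cpu_round_trip() -> None:
    buf = pickle_dumps(torch.ones(2, 2))
    t = pickle_loads(buf)
    assert t.device.type == "cpu"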


# Hosts vs Connection objects:
# Connections get created when a host manager registers with the supervisor.
# They represent a live socket between the supervisor and the host manager.
# Hosts get created when the supervisor requests a new Host object.
# They are what the policy script uses as handles to launch jobs.
# The supervisor then brokers a match between an existing Host and an existing Connection,
# fulfilling the Host's request. Because either object could get created first
# (supervisor is slow to create the Host, or host manager is slow to establish a Connection),
# it is easier to keep them as separate concepts than to try to fold them into a single Host object.


class Connection:
    """
    Connections get created when a host manager registers with the supervisor.
    They represent a live socket between the supervisor and the host manager.
    """

    def __init__(self, ctx: "Context", name: bytes, hostname: Optional[str]) -> None:
        self.state: _State = _UNATTACHED
        self.name = name
        self.hostname = hostname
        self.host: "Optional[Host]" = None
        # expiration timestamp after which the host will be considered lost
        self.expiry: float = time.time() + HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS
        if hostname is None:
            self.lost(ctx, "Connection did not start with a hostname")
        else:
            # let the connection know we exist
            ctx._backend.send_multipart([name, b""])

    def heartbeat(self) -> float:
        """
        Records a new heartbeat and updates the expiry timestamp.

        Returns:
            float: the old ttl (in seconds) for record keeping
        """
        now = time.time()
        ttl = self.expiry - now
        self.expiry = now + HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS
        return ttl

    def check_alive_at(self, ctx: "Context", t: float) -> None:
        """Checks whether the host manager is alive; if not, marks the host as lost and sends an abort."""
        if self.state is not _LOST and self.expiry < t:
            # host timeout
            elapsed = t - self.expiry + HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS
            logger.warning(
                "Host %s (%s) has not heartbeated in %s seconds, disconnecting it",
                self.hostname,
                self.name,
                elapsed,
            )
            self.lost(ctx, "Host did not heartbeat")

    def handle_message(self, ctx: "Context", msg: bytes) -> None:
        ctx._heartbeat_ttl(self.heartbeat())
        if self.state is _LOST:
            # got a message from a host that expired, but
            # eventually came back to life.
            # At this point we've marked its processes as dead,
            # so we are going to tell it to abort so that it gets
            # restarted and can become a new connection.
            logger.info("Host %s that was lost reconnected, sending abort", self.name)
            self.send_abort(ctx, "Supervisor thought host timed out")
            return

        if not len(msg):
            # heartbeat msg, respond with our own
            ctx._backend.send_multipart([self.name, b""])
            return

        if self.state is _UNATTACHED:
            logger.warning(
                "Got message from host %s manager before it was attached.", self.name
            )
            self.lost(ctx, "Host manager sent messages before attached.")
            return

        cmd, proc_id, *args = pickle_loads(msg)
        assert self.host is not None
        receiver = self.host if proc_id is None else self.host._proc_table.get(proc_id)
        if receiver is None:
            # messages from a process might arrive after the user
            # no longer has a handle to the Process object,
            # in which case they are ok to just drop
            assert proc_id >= 0 and proc_id < ctx._next_id, "unexpected proc_id"
            logger.warning(
                "Received message %s from process %s after local handle deleted",
                cmd,
                proc_id,
            )
        else:
            getattr(receiver, cmd)(*args)
        receiver = None

    def lost(self, ctx: "Context", with_error: Optional[str]) -> None:
        orig_state = self.state
        if orig_state is _LOST:
            return
        self.state = _LOST
        if orig_state is _ATTACHED:
            assert self.host is not None
            self.host._lost(with_error)
            self.host = None
        self.send_abort(ctx, with_error)

    def send_abort(self, ctx: "Context", with_error: Optional[str]) -> None:
        ctx._backend.send_multipart([self.name, pickle_dumps(("abort", with_error))])


class HostDisconnected(NamedTuple):
    time: float


# TODO: rename to HostHandle to disambiguate with supervisor.host.Host?
class Host:
    """
    Hosts get created when the supervisor requests a new Host object.
    They are what the policy script uses as handles to launch jobs.
    """

    def __init__(self, context: "Context") -> None:
        self._context = context
        self._state: _State = _UNATTACHED
        self._name: Optional[bytes] = None
        self._deferred_sends: List[bytes] = []
        self._proc_table: Dict[int, Process] = {}
        self._hostname: Optional[str] = None
        self._is_lost = False

    def __repr__(self) -> str:
        return f"Host[{self._hostname}]"

    @property
    def hostname(self) -> Optional[str]:
        return self._hostname

    def _lost(self, msg: Optional[str]) -> None:
        orig_state = self._state
        if orig_state is _LOST:
            return
        self._state = _LOST
        if orig_state is _ATTACHED:
            assert self._name is not None
            self._context._name_to_connection[self._name].lost(self._context, msg)
            self._name = None
        self._deferred_sends.clear()
        for p in list(self._proc_table.values()):
            p._lost_host()
        # should be cleared by aborting the processes
        assert len(self._proc_table) == 0
        self._context._produce_message(self, HostDisconnected(time.time()))
        self._is_lost = True

    def _send(self, msg: bytes) -> None:
        if self._state is _ATTACHED:
            self._context._backend.send_multipart([self._name, msg])
        elif self._state is _UNATTACHED:
            self._deferred_sends.append(msg)

    def _launch(self, p: "Process") -> None:
        self._proc_table[p._id] = p
        if self._state is _LOST:
            # launch requested after we lost connection to this host.
            p._lost_host()
            return
        self._send(
            pickle_dumps(
                (
                    "launch",
                    p._id,
                    p.rank,
                    p.processes_per_host,
                    p.world_size,
                    p.popen,
                    p.name,
                    p.simulate,
                    p.logfile,
                )
            )
        )
        self._context._launches += 1

    @property
    def disconnected(self) -> bool:
        return self._is_lost

    def create_process(
        self,
        args: Sequence[str],
        env: Optional[Dict[str, str]] = None,
        cwd: Optional[str] = None,
        name: Optional[str] = None,
        simulate: bool = False,
    ) -> "ProcessList":
        return self._context.create_process_group(
            [self], args=args, env=env, cwd=cwd, name=name, simulate=simulate
        )[0]


class ProcessFailedToStart(Exception):
    pass


class ProcessStarted(NamedTuple):
    pid: int


class ProcessExited(NamedTuple):
    result: Union[int, Exception]


class Process:
    def __init__(
        self,
        context: "Context",
        host: "Host",
        rank: int,
        processes_per_host: int,
        world_size: int,
        popen: Mapping[str, object],
        name: str,
        simulate: bool,
    ) -> None:
        self._id: int = context._next_id
        context._next_id += 1
        self._context = context
        self.host = host
        self.rank = rank
        self.processes_per_host = processes_per_host
        self.world_size = world_size
        self.popen = popen
        self.simulate = simulate
        self.name: str = name.format(rank=str(rank).zfill(len(str(world_size))))
        self.logfile: Optional[str] = (
            None
            if context.log_format is None
            else context.log_format.format(name=self.name)
        )
        self._pid = None
        self._returncode = None
        self._state = "launched"
        self._filter_obj = None

    @property
    def returncode(self) -> Optional[int]:
        return self._returncode

    @property
    def pid(self) -> Optional[int]:
        return self._pid

    def __repr__(self) -> str:
        return f"Process(rank={self.rank}, host={self.host}, pid={self.pid})"

    def _lost_host(self) -> None:
        self._abort(ConnectionAbortedError("Lost connection to process host"))

    def _abort(self, e: Exception) -> None:
        if self._state in ["launched", "running"]:
            self._exit_message(e)
            self._state = "aborted"

    def send(self, msg: object) -> None:
        msg = pickle_dumps(msg)
        self._context._schedule(lambda: self._send(msg))

    def _send(self, msg: bytes) -> None:
        if self._state != "aborted":
            self._context._sends += 1
            self.host._send(pickle_dumps(("send", self._id, msg)))

    def signal(self, signal: int = signal.SIGTERM, group: bool = True) -> None:
        self._context._schedule(lambda: self._signal(signal, group))

    def _signal(self, signal: int, group: bool) -> None:
        if self._state != "aborted":
            self.host._send(pickle_dumps(("signal", self._id, signal, group)))

    def _exited(self, returncode: int) -> None:
        self._state = "exited"
        self._returncode = returncode
        self._exit_message(returncode)
        self._context._exits += 1

    def _exit_message(self, returncode: Union[int, Exception]) -> None:
        self.host._proc_table.pop(self._id)
        self._context._produce_message(self, ProcessExited(returncode))

    def _started(self, pid: Union[str, int]) -> None:
        if isinstance(pid, int):
            self._state = "running"
            self._context._produce_message(self, ProcessStarted(pid))
            self._pid = pid
            self._context._starts += 1
        else:
            self._abort(ProcessFailedToStart(pid))

    def _response(self, msg: bytes) -> None:
        unpickled: NamedTuple = pickle_loads(msg)
        self._context._produce_message(self, unpickled)

    def __del__(self) -> None:
        self._context._proc_deletes += 1


def _get_hostname_if_exists(msg: bytes) -> Optional[str]:
    """
    Gets the hostname from a zmq message if it exists, for logging on the Connection.
    """
    if not len(msg):
        return None
    try:
        cmd, _, hostname = pickle_loads(msg)
        if cmd != "_hostname" or not isinstance(hostname, str):
            return None
        return hostname
    except Exception:
        return None


class Status(NamedTuple):
    launches: int
    starts: int
    exits: int
    sends: int
    responses: int
    process_deletes: int
    unassigned_hosts: int
    unassigned_connections: int
    poll_percentage: float
    active_percentage: float
    heartbeats: int
    heartbeat_average_ttl: float
    heartbeat_min_ttl: float
    connection_histogram: Dict[str, int]
    avg_event_loop_time: float
    max_event_loop_time: float


class Letter(NamedTuple):
    sender: Union[Host, Process, None]
    message: Any


class HostConnected(NamedTuple):
    hostname: str
434
|
+
class FilteredMessageQueue(ABC):
|
435
|
+
def __init__(self) -> None:
|
436
|
+
self._client_queue: deque[Letter] = deque()
|
437
|
+
self._filter: Optional["Filter"] = None
|
438
|
+
|
439
|
+
@abstractmethod
|
440
|
+
def _read_messages(self, timeout: Optional[float]) -> List[Letter]: ...
|
441
|
+
|
442
|
+
def _set_filter_to(self, new) -> None:
|
443
|
+
old = self._filter
|
444
|
+
if new is old: # None None or f f
|
445
|
+
return
|
446
|
+
self._filter = new
|
447
|
+
if old is not None:
|
448
|
+
self._client_queue.rotate(old._cursor)
|
449
|
+
if new is not None: # None f, or f None, or f f'
|
450
|
+
new._cursor = 0
|
451
|
+
|
452
|
+
def _next_message(self, timeout) -> Optional[Letter]:
|
453
|
+
# return the first message that passes self._filter, starting at self._filter._cursor
|
454
|
+
queue = self._client_queue
|
455
|
+
if self._filter is None:
|
456
|
+
if queue:
|
457
|
+
return queue.popleft()
|
458
|
+
messages = self._read_messages(timeout)
|
459
|
+
if not messages:
|
460
|
+
return None
|
461
|
+
head, *rest = messages
|
462
|
+
queue.extend(rest)
|
463
|
+
return head
|
464
|
+
else:
|
465
|
+
filter = self._filter
|
466
|
+
filter_fn = filter._fn
|
467
|
+
for i in range(filter._cursor, len(queue)):
|
468
|
+
if filter_fn(queue[0]):
|
469
|
+
filter._cursor = i
|
470
|
+
return queue.popleft()
|
471
|
+
queue.rotate(-1)
|
472
|
+
if timeout is None:
|
473
|
+
while True:
|
474
|
+
messages = self._read_messages(None)
|
475
|
+
for i, msg in enumerate(messages):
|
476
|
+
if filter_fn(msg):
|
477
|
+
filter._cursor = len(queue)
|
478
|
+
queue.extendleft(messages[-1:i:-1])
|
479
|
+
return msg
|
480
|
+
queue.append(msg)
|
481
|
+
else:
|
482
|
+
t = time.time()
|
483
|
+
expiry = t + timeout
|
484
|
+
while t <= expiry:
|
485
|
+
messages = self._read_messages(expiry - t)
|
486
|
+
if not messages:
|
487
|
+
break
|
488
|
+
for i, msg in enumerate(messages):
|
489
|
+
if filter_fn(msg):
|
490
|
+
filter._cursor = len(queue)
|
491
|
+
queue.extendleft(messages[-1:i:-1])
|
492
|
+
return msg
|
493
|
+
queue.append(msg)
|
494
|
+
t = time.time()
|
495
|
+
filter._cursor = len(queue)
|
496
|
+
return None
|
497
|
+
|
498
|
+
def recv(self, timeout: Optional[float] = None, _filter=None) -> Letter:
|
499
|
+
self._set_filter_to(_filter)
|
500
|
+
msg = self._next_message(timeout)
|
501
|
+
if msg is None:
|
502
|
+
raise TimeoutError()
|
503
|
+
return msg
|
504
|
+
|
505
|
+
def recvloop(self, timeout=None):
|
506
|
+
while True:
|
507
|
+
yield self.recv(timeout)
|
508
|
+
|
509
|
+
def recvready(self, timeout: Optional[float] = 0, _filter=None) -> List[Letter]:
|
510
|
+
self._set_filter_to(_filter)
|
511
|
+
result = []
|
512
|
+
append = result.append
|
513
|
+
next_message = self._next_message
|
514
|
+
msg = next_message(timeout)
|
515
|
+
while msg is not None:
|
516
|
+
append(msg)
|
517
|
+
msg = next_message(0)
|
518
|
+
return result
|
519
|
+
|
520
|
+
def messagefilter(self, fn):
|
521
|
+
if isinstance(fn, (tuple, type)):
|
522
|
+
|
523
|
+
def wrapped(msg):
|
524
|
+
return isinstance(msg.message, fn)
|
525
|
+
|
526
|
+
else:
|
527
|
+
wrapped = fn
|
528
|
+
return Filter(self, wrapped)


class Filter:
    def __init__(
        self, context: FilteredMessageQueue, fn: Callable[[Letter], bool]
    ) -> None:
        self._context = context
        self._fn = fn
        self._cursor = 0

    def recv(self, timeout=None):
        return self._context.recv(timeout, _filter=self)

    def recvloop(self, timeout=None):
        while True:
            yield self.recv(timeout)

    def recvready(self, timeout=0):
        return self._context.recvready(timeout, _filter=self)
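
# Illustrative sketch (hypothetical usage, not part of the original file):
# messagefilter() lets a policy script wait for one kind of message while
# leaving everything else queued for later recv() calls.
def _example_wait_for_start(ctx: "Context", timeout: float = 10.0) -> int:
    started = ctx.messagefilter(ProcessStarted)  # match by message type
    letter = started.recv(timeout=timeout)       # raises TimeoutError on expiry
    return letter.message.pid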


def TTL(timeout: Optional[float]) -> Callable[[], float]:
    if timeout is None:
        return lambda: math.inf
    expiry = time.time() + timeout
    return lambda: max(expiry - time.time(), 0)
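
# Illustrative sketch (hypothetical helper, not part of the original file):
# TTL turns a deadline into a callable countdown, which is how the event loop
# bounds how long it spends draining scheduled requests.
def _example_drain(requests: "deque[Callable[[], None]]") -> None:
    ttl = TTL(0.5)  # half-second budget
    while requests and ttl() > 0:
        requests.popleft()()  # run the next scheduled request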


class ProcessList(tuple):
    def send(self, msg: Any) -> None:
        if not self:
            return
        ctx = self[0]._context
        msg = pickle_dumps(msg)
        ctx._schedule(lambda: self._send(msg))

    def _send(self, msg: bytes) -> None:
        for p in self:
            p._send(msg)

    def __getitem__(self, index):
        result = super().__getitem__(index)
        # If the index is a slice, convert the result to a ProcessList
        if isinstance(index, slice):
            return ProcessList(result)
        return result


class Context(FilteredMessageQueue):
    def __init__(
        self,
        port: Optional[int] = None,
        log_format: Optional[str] = None,
        log_interval: float = LOG_INTERVAL,
    ) -> None:
        super().__init__()
        if log_format is not None:
            path = log_format.format(name="supervisor")
            logger.warning(f"Redirect logging to {path}")
            Path(path).parent.mkdir(exist_ok=True, parents=True)
            with open(path, "w") as f:
                os.dup2(f.fileno(), sys.stdout.fileno())
                os.dup2(f.fileno(), sys.stderr.fileno())

        self._log_interval: float = log_interval
        self._context: zmq.Context = zmq.Context(1)

        # to talk to python clients in this process
        self._requests: deque[Callable[[], None]] = deque()
        self._delivered_messages: deque[List[Letter]] = deque()
        self._delivered_messages_entry: List[Letter] = []
        self._requests_ready: zmq.Socket = self._socket(zmq.PAIR)

        self._requests_ready.bind("inproc://doorbell")
        self._doorbell: zmq.Socket = self._socket(zmq.PAIR)
        self._doorbell.connect("inproc://doorbell")
        self._doorbell_poller = zmq.Poller()
        self._doorbell_poller.register(self._doorbell, zmq.POLLIN)

        self._backend: zmq.Socket = self._socket(zmq.ROUTER)
        self._backend.setsockopt(zmq.IPV6, True)
        if port is None:
            # Specify a min and max port range; the default min/max triggers a
            # codepath in zmq that is vulnerable to races between ephemeral port
            # acquisition and last_endpoint being available.
            self.port = self._backend.bind_to_random_port("tcp://*", 49153, 65536)
        else:
            self._backend.bind(f"tcp://*:{port}")
            self.port = port

        self._poller = zmq.Poller()
        self._poller.register(self._backend, zmq.POLLIN)
        self._poller.register(self._requests_ready, zmq.POLLIN)

        self._unassigned_hosts: deque[Host] = deque()
        self._unassigned_connections: deque[Connection] = deque()
        self._name_to_connection: Dict[bytes, Connection] = {}
        self._last_heartbeat_check: float = time.time()
        self._last_logstatus: float = self._last_heartbeat_check
        self._next_id = 0
        self._exits = 0
        self._sends = 0
        self._responses = 0
        self._launches = 0
        self._starts = 0
        self._proc_deletes = 0
        self._reset_heartbeat_stats()

        self._exit_event_loop = False
        self._pg_name = 0
        self.log_format = log_format
        self.log_status = lambda status: None

        self._thread = Thread(target=self._event_loop, daemon=True)
        self._thread.start()

    def _socket(self, kind: int) -> zmq.Socket:
        socket = self._context.socket(kind)
        socket.setsockopt(zmq.SNDHWM, 0)
        socket.setsockopt(zmq.RCVHWM, 0)
        return socket

    def _attach(self) -> None:
        while self._unassigned_connections and self._unassigned_hosts:
            c = self._unassigned_connections[0]
            h = self._unassigned_hosts[0]
            if c.state is _LOST:
                self._unassigned_connections.popleft()
            elif h._state is _LOST:
                self._unassigned_hosts.popleft()
            else:
                self._unassigned_connections.popleft()
                self._unassigned_hosts.popleft()
                c.host = h
                h._name = c.name
                assert c.hostname is not None
                h._context._produce_message(h, HostConnected(c.hostname))
                h._hostname = c.hostname
                h._state = c.state = _ATTACHED
                for msg in h._deferred_sends:
                    self._backend.send_multipart([h._name, msg])
                h._deferred_sends.clear()

    def _event_loop(self) -> None:
        _time_poll = 0
        _time_process = 0
        _time_loop = 0
        _max_time_loop = -1
        _num_loops = 0
        while True:
            time_begin = time.time()
            poll_result = self._poller.poll(timeout=int(HEARTBEAT_INTERVAL * 1000))
            time_poll = time.time()

            for sock, _ in poll_result:
                # known host managers
                if sock is self._backend:
                    f, msg = self._backend.recv_multipart()
                    if f not in self._name_to_connection:
                        hostname = _get_hostname_if_exists(msg)
                        connection = self._name_to_connection[f] = Connection(
                            self, f, hostname
                        )
                        self._unassigned_connections.append(connection)
                        self._attach()
                    else:
                        self._name_to_connection[f].handle_message(self, msg)
                elif sock is self._requests_ready:
                    ttl = TTL(HEARTBEAT_INTERVAL / 2)
                    while self._requests and ttl() > 0:
                        # pyre-ignore[29]
                        self._requests_ready.recv()
                        fn = self._requests.popleft()
                        fn()
                        del fn  # otherwise we hold a handle until
                        # the next time we run a command
            if self._exit_event_loop:
                return
            t = time.time()
            if t - time_begin > HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS:
                logger.warning(
                    f"Main poll took too long! ({t - time_begin} > {HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS} seconds). Host managers will think we are dead."
                )
            elapsed = t - self._last_heartbeat_check
            should_check_heartbeat = elapsed > HEARTBEAT_INTERVAL * HEARTBEAT_LIVENESS
            if should_check_heartbeat:
                self._last_heartbeat_check = t
                # priority queue would be log(N)
                for connection in self._name_to_connection.values():
                    connection.check_alive_at(self, t)

            # Marking futures ready should always happen at the end of processing events above
            # to unblock anything processing the futures, before we start waiting for more events.
            if self._delivered_messages_entry:
                self._delivered_messages.append(self._delivered_messages_entry)
                self._delivered_messages_entry = []
                self._requests_ready.send(b"")
            time_end = time.time()
            _time_poll += time_poll - time_begin
            _time_process += time_end - time_poll
            _time_loop += time_end - time_begin
            _max_time_loop = max(_max_time_loop, time_end - time_begin)
            _num_loops += 1

            elapsed = t - self._last_logstatus
            if elapsed > self._log_interval:
                self._last_logstatus = t
                self._logstatus(
                    _time_poll / elapsed,
                    _time_process / elapsed,
                    _time_loop / _num_loops,
                    _max_time_loop,
                )
                _time_poll = 0
                _time_process = 0
                _time_loop = 0
                _max_time_loop = -1
                _num_loops = 0

    def _logstatus(
        self,
        poll_fraction: float,
        active_fraction: float,
        avg_event_loop_time: float,
        max_event_loop_time: float,
    ) -> None:
        connection_histogram = {}
        for connection in self._name_to_connection.values():
            state = connection.state.name
            connection_histogram[state] = connection_histogram.setdefault(state, 0) + 1

        status = Status(
            self._launches,
            self._starts,
            self._exits,
            self._sends,
            self._responses,
            self._proc_deletes,
            len(self._unassigned_hosts),
            len(self._unassigned_connections),
            poll_fraction * 100,
            active_fraction * 100,
            self._heartbeats,
            self._heartbeat_ttl_sum / self._heartbeats if self._heartbeats else 0,
            self._heartbeat_min_ttl,
            connection_histogram,
            avg_event_loop_time,
            max_event_loop_time,
        )
        self._reset_heartbeat_stats()

        logger.info(
            "supervisor status: %s process launches, %s starts, %s exits, %s message sends, %s message responses,"
            " %s process __del__, %s hosts waiting for connections, %s connections waiting for handles,"
            " time is %.2f%% polling and %.2f%% active, heartbeats %s, heartbeat_avg_ttl %.4f,"
            " heartbeat_min_ttl %.4f, connections %s, avg_event_loop_time %.4f seconds, max_event_loop_time %.4f seconds",
            *status,
        )
        self.log_status(status)

    def _heartbeat_ttl(self, ttl: float) -> None:
        # Updates heartbeat stats with the most recent ttl
        self._heartbeats += 1
        self._heartbeat_ttl_sum += ttl
        self._heartbeat_min_ttl = min(self._heartbeat_min_ttl, ttl)

    def _reset_heartbeat_stats(self) -> None:
        self._heartbeats = 0
        self._heartbeat_ttl_sum = 0
        self._heartbeat_min_ttl = sys.maxsize

    def _schedule(self, fn: Callable[[], None]) -> None:
        self._requests.append(fn)
        self._doorbell.send(b"")

    def request_hosts(self, n: int) -> "Tuple[Host, ...]":
        """
        Request n hosts from the scheduler to run processes on.
        The future is fulfilled when the reservation is made, but
        potentially before all the hosts check in with this API.

        Note: implementations that use existing slurm-like schedulers
        will immediately fulfill the future, because the reservation was
        already made.
        """
        hosts = tuple(Host(self) for i in range(n))
        self._schedule(lambda: self._request_hosts(hosts))
        return hosts

    def _request_host(self, h: Host) -> None:
        self._unassigned_hosts.append(h)
        self._attach()

    def _request_hosts(self, hosts: Sequence[Host]) -> None:
        for h in hosts:
            self._request_host(h)

    def return_hosts(self, hosts: Sequence[Host], error: Optional[str] = None) -> None:
        """
        Processes on the returned hosts will be killed,
        and future process launches on the hosts will fail.
        """
        self._schedule(lambda: self._return_hosts(hosts, error))

    def _return_hosts(self, hosts: Sequence[Host], error: Optional[str]) -> None:
        for h in hosts:
            h._lost(error)

    def replace_hosts(self, hosts: Sequence[Host]) -> "Tuple[Host, ...]":
        """
        Request that these hosts be replaced with new hosts.
        Processes on the hosts will be killed, and future process
        launches will be launched on the new hosts.
        """
        # if the host is disconnected, return it to the pool of unused hosts
        # and hope that the scheduler has replaced the job.
        # if the host is still connected, then send the host a message,
        # cancel its processes, and abort with an error to get the
        # scheduler to reassign the host
        hosts = list(hosts)
        self.return_hosts(hosts, "supervisor requested replacement")
        return self.request_hosts(len(hosts))

    def _shutdown(self) -> None:
        self._exit_event_loop = True
        for connection in self._name_to_connection.values():
            connection.lost(self, None)

    def shutdown(self) -> None:
        self._schedule(self._shutdown)
        self._thread.join()
        self._backend.close()
        self._requests_ready.close()
        self._doorbell.close()
        self._context.term()

    # TODO: other arguments like environment, etc.
    def create_process_group(
        self,
        hosts: Sequence[Host],
        args: Union["_FunctionCall", Sequence[str]],
        processes_per_host: int = 1,
        env: Optional[Dict[str, str]] = None,
        cwd: Optional[str] = None,
        name: Optional[str] = None,
        simulate: bool = False,
    ) -> ProcessList:
        world_size = processes_per_host * len(hosts)
        if name is None:
            name = f"pg{self._pg_name}"
            self._pg_name += 1
        logger.info(
            "Starting process group %r with %d processes (%s hosts * %s processes per host)",
            name,
            world_size,
            len(hosts),
            processes_per_host,
        )
        popen = {"args": args, "env": env, "cwd": cwd}
        procs = ProcessList(
            Process(
                self,
                h,
                i * processes_per_host + j,
                processes_per_host,
                world_size,
                popen,
                name,
                simulate,
            )
            for i, h in enumerate(hosts)
            for j in range(processes_per_host)
        )
        self._schedule(lambda: self._launch_processes(procs))
        return procs

    def _launch_processes(self, procs: Sequence[Process]) -> None:
        for p in procs:
            p.host._launch(p)

    def _produce_message(
        self, sender: Union[Host, Process], message: NamedTuple
    ) -> None:
        self._delivered_messages_entry.append(Letter(sender, message))

    def _read_messages(self, timeout: Optional[float]) -> List[Letter]:
        if timeout is not None and not self._doorbell_poller.poll(
            timeout=int(1000 * timeout)
        ):
            return []
        # pyre-ignore[29]
        self._doorbell.recv()
        return self._delivered_messages.popleft()
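
# Illustrative sketch (hypothetical policy script, not part of the original
# file): the supervisor-side API in one place. "worker.py" is a placeholder.
def _example_policy() -> None:
    ctx = Context()
    hosts = ctx.request_hosts(2)  # reserve two hosts
    procs = ctx.create_process_group(
        hosts, args=["python", "worker.py"], processes_per_host=4
    )  # 8 processes, ranks 0..7
    for letter in ctx.recvloop():  # each Letter is (sender, message)
        if isinstance(letter.message, ProcessExited):
            break
    ctx.shutdown()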


@cache
def get_message_queue(
    supervisor_ident: Optional[int] = None, supervisor_pipe: Optional[str] = None
) -> "LocalMessageQueue":
    """
    Processes launched on the hosts can use this function to connect
    to the messaging queue of the supervisor.

    Messages sent from here can be received by the supervisor using
    `proc.recv()`, and messages from `proc.send()` will appear in this queue.
    """
    if supervisor_ident is None:
        supervisor_ident = int(os.environ["SUPERVISOR_IDENT"])
    if supervisor_pipe is None:
        supervisor_pipe = os.environ["SUPERVISOR_PIPE"]

    return LocalMessageQueue(supervisor_ident, supervisor_pipe)


class LocalMessageQueue(FilteredMessageQueue):
    """
    Used by processes launched on the host to communicate with the supervisor.
    Also used as the pipe between the main worker process and the pipe process with worker pipes.
    """

    def __init__(self, supervisor_ident: int, supervisor_pipe: str) -> None:
        super().__init__()
        self._ctx = zmq.Context(1)
        self._sock = self._socket(zmq.DEALER)
        proc_id = supervisor_ident.to_bytes(8, byteorder="little")
        self._sock.setsockopt(zmq.IDENTITY, proc_id)
        self._sock.connect(supervisor_pipe)
        self._sock.send(b"")
        self._poller = zmq.Poller()
        self._poller.register(self._sock, zmq.POLLIN)
        self._async_socket: Optional[zmq.asyncio.Socket] = None

    def _socket(self, kind) -> zmq.Socket:
        sock = self._ctx.socket(kind)
        sock.setsockopt(zmq.SNDHWM, 0)
        sock.setsockopt(zmq.RCVHWM, 0)
        return sock

    def _read_messages(self, timeout: Optional[float]) -> List[Letter]:
        if timeout is not None and not self._poller.poll(timeout=int(1000 * timeout)):
            return []
        return [Letter(None, self._sock.recv_pyobj())]

    async def recv_async(self) -> Letter:
        if self._async_socket is None:
            self._async_socket = zmq.asyncio.Socket.from_socket(self._sock)
        return Letter(None, await self._async_socket.recv_pyobj())

    def send(self, message: Any) -> None:
        self._sock.send_pyobj(message)

    def close(self) -> None:
        self._sock.close()
        self._ctx.term()
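
# Illustrative sketch (hypothetical worker entry point, not part of the
# original file): the worker-side half of the queue described above.
def _example_worker() -> None:
    q = get_message_queue()  # reads SUPERVISOR_IDENT / SUPERVISOR_PIPE
    q.send(("ready", os.getpid()))  # arrives at the supervisor via proc.recv()
    letter = q.recv(timeout=30)  # blocks for a message from proc.send()
    print("supervisor said:", letter.message)
    q.close()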


class _FunctionCall(NamedTuple):
    target: str
    args: Tuple[str]
    kwargs: Dict[str, str]


def FunctionCall(target: str, *args, **kwargs) -> _FunctionCall:
    if target.startswith("__main__."):
        file = sys.modules["__main__"].__file__
        sys.modules["__entry__"] = sys.modules["__main__"]
        target = f'{file}:{target.split(".", 1)[1]}'
    return _FunctionCall(target, args, kwargs)
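
# Illustrative sketch (hypothetical names, not part of the original file):
# a _FunctionCall can be passed as `args` to create_process_group instead of a
# command line; "trainers.train" and its arguments are placeholders.
_example_call = FunctionCall("trainers.train", "--fast", epochs="3")
# Targets in __main__ are rewritten above to "<file>:<name>" so that workers
# can import the policy script itself under the module name __entry__.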


# [Threading Model]
# The supervisor policy script runs in the main thread,
# and there is a separate _event_loop thread launched by the
# Context object for managing messages from host managers.
# Context, Host, and Process objects get created on the
# main thread, and their public APIs contain read-only
# parameters. Private members should only be read/
# written from the event loop for these objects.
# context._schedule provides a way to schedule
# a function to run on the event loop from the main thread.
# The only observable changes from the main thread go through
# future objects.
# The _event_loop maintains a list of
# futures to be marked finished (_finished_futures_entry) which will
# be sent to the main thread at the end of one event loop iteration to
# actually mutate the futures to be completed, and run callbacks.
# _finished_futures, _request_ready, and Future instances are the only
# objects the main thread should mutate after they are created.

# [zeromq Background]
# We use zeromq to make high-throughput messaging possible despite
# using Python for the event loop. There are a few differences from traditional
# tcp sockets that are important:
# * A traditional 'socket' is a connection that can be read or written
#   and is connected to exactly one other socket. In zeromq, sockets
#   can be connected to multiple other sockets. For instance,
#   context._backend is connected to _all_ host managers.
# * To connect sockets traditionally, one side listens on a listener socket
#   for a new connection with 'bind'/'listen', and the other side 'connect's its socket.
#   This creates another socket on the 'bind' side (three total sockets). Bind and listen
#   must happen before 'connect' or the connection will be refused. In zeromq, any socket can
#   bind or connect. A socket that binds can be connected to many others if multiple other sockets connect to it.
#   A socket can also connect itself to multiple other sockets by calling connect multiple times
#   (we do not use this here).
#   Connect can come before bind; zeromq will retry if the bind-ing process is not there yet.
# * When sockets are connected to multiple others, we have to define what it means to send
#   or receive. This is configured when creating a socket. zmq.PAIR asserts there will only
#   be a single other socket, and behaves like a traditional socket. zmq.DEALER sends data
#   by round-robining (dealing) each message to the sockets it is connected to. A receive will
#   get a message from any of the incoming sockets. zmq.ROUTER receives a message from one of its
#   connections and _prefixes_ it with the identity of the socket that sent it (recv_multipart is
#   used to get a list of message parts and find this identity). When sending a message with zmq.ROUTER,
#   the first part must be an identity, and it will send (route) the message to the connection with
#   that identity.
# The zeromq guide has more information, but this implementation intentionally only uses the above
# features to make it easier to use a different message broker if needed.
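
# Illustrative sketch (not part of the original file): the ROUTER/DEALER
# identity mechanics described above, reduced to a loopback example.
def _example_router_dealer() -> None:
    ctx = zmq.Context.instance()
    router = ctx.socket(zmq.ROUTER)
    port = router.bind_to_random_port("tcp://127.0.0.1")
    dealer = ctx.socket(zmq.DEALER)
    dealer.setsockopt(zmq.IDENTITY, b"host-0")
    dealer.connect(f"tcp://127.0.0.1:{port}")

    dealer.send(b"hello")  # DEALER sends a bare message
    ident, payload = router.recv_multipart()  # ROUTER prefixes the sender identity
    assert (ident, payload) == (b"host-0", b"hello")
    router.send_multipart([ident, b"world"])  # the first part routes the reply
    assert dealer.recv() == b"world"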
|