torchmonarch-nightly 2025.6.27__cp311-cp311-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/rust_local_mesh.py
@@ -0,0 +1,1402 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import contextlib
import importlib.resources
import importlib.util
import logging
import os
import random
import re
import select
import socket
import string
import subprocess
import sys
import tempfile
import threading
import time
import uuid
from enum import Enum
from pathlib import Path
from types import TracebackType
from typing import (
    Callable,
    Collection,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    TextIO,
    Tuple,
    Type,
    TypeVar,
)

from monarch._rust_bindings.controller.bootstrap import (
    ControllerCommand,
    ControllerServerRequest,
    ControllerServerResponse,
    RunCommand,
)

from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
    ActorId,
)

from monarch._rust_bindings.monarch_tensor_worker.bootstrap import (
    WorkerServerRequest,
    WorkerServerResponse,
)

from monarch.common.device_mesh import DeviceMesh
from monarch.common.fake import fake_call
from monarch.common.invocation import DeviceException, RemoteException
from monarch.rust_backend_mesh import (
    IBootstrap,
    MeshWorld,
    PoolDeviceMeshProvider,
    rust_backend_mesh_provider,
    rust_backend_meshes,
)

logger: logging.Logger = logging.getLogger(__name__)
_MONARCH_TENSOR_WORKER_MAIN = "monarch.tensor_worker_main"

try:
    from __manifest__ import fbmake  # noqa

    IN_PAR = True
except ImportError:
    IN_PAR = False


class SocketType(Enum):
    """Enum representing socket types."""

    TCP = "tcp"
    UNIX = "unix"


class LoggingLocation(Enum):
    """Enum representing where to flush stderr and stdout."""

    DEFAULT = "default"
    FILE = "file"


class SupervisionParams(NamedTuple):
    # If the system actor does not receive a supervision update within this time,
    # it will treat the proc as dead.
    update_timeout_in_sec: int
    # How often the controller queries supervision status from the system actor.
    query_interval_in_sec: int
    # How often a proc actor sends supervision updates to the system actor.
    update_interval_in_sec: int


class ControllerParams(NamedTuple):
    # How often the controller polls for operations that have not completed
    # within the timeout duration, indicating that they may be stuck.
    worker_progress_check_interval_in_sec: int

    # How long we wait for an operation before letting the client know that it may be stuck.
    operation_timeout_in_sec: int

    # The number of operations invoked before we proactively check worker progress.
    # If a large number of operations are invoked all at once, it is expected to
    # take a while for all of them to complete, so we inject progress requests at
    # a higher frequency to check whether we are making progress.
    operations_per_worker_progress_request: int

    # Whether the controller should propagate a failure to the client if the workers become stuck.
    fail_on_worker_timeout: bool


_PROC_ENV = {
    "HYPERACTOR_MANAGED_SUBPROCESS": str(1),
}


def get_controller_main() -> tuple[Path, dict[str, str]]:
    with (
        importlib.resources.path("monarch", "monarch_controller") as controller_main,
    ):
        if not controller_main.exists():
            if IN_PAR:
                raise ImportError(
                    "Monarch env not found, please define a custom 'monarch_env' or "
                    "add '//monarch/python/monarch:default_env-library' to your binary dependencies "
                    "in TARGETS"
                )
            else:
                raise ImportError(
                    "Monarch env not found, please re-run ./scripts/install.sh in fbcode/monarch"
                )
        env: dict[str, str] = {}

        # Hack to make the exploded-wheel workflow work in the face of broken
        # build-time RPATHs...
        #
        # If we're running under a conda env...
        if not IN_PAR:
            conda_prefix = os.environ.get("CONDA_PREFIX")
            if conda_prefix is not None and sys.executable.startswith(
                conda_prefix + "/"
            ):
                # ... and Monarch is coming from "outside" the env, via `PYTHONPATH`s ...
                spec = importlib.util.find_spec("monarch")
                assert spec is not None
                origin = spec.origin
                assert origin is not None
                monarch_root = str(Path(origin).parent.parent)
                if (
                    not monarch_root.startswith(conda_prefix + "/")
                    and monarch_root in sys.path
                ):
                    import torch

                    # ... then assume we're running via an exploded .whl, which means
                    # we need to manually set library paths to find the necessary
                    # native libs from the conda env.
                    env["LD_LIBRARY_PATH"] = ":".join(
                        [
                            os.path.join(os.path.dirname(torch.__file__), "lib"),
                            os.path.join(conda_prefix, "lib"),
                        ]
                    )

    return controller_main, env


def _create_logging_locations(
    logging_dir: str, name: str, logging_location: LoggingLocation
) -> tuple[TextIO | None, TextIO | None]:
    if logging_location == LoggingLocation.FILE:
        stdout_file: TextIO = open(os.path.join(logging_dir, f"{name}.stdout"), "a+")
        stderr_file: TextIO = open(os.path.join(logging_dir, f"{name}.stderr"), "a+")
        return stdout_file, stderr_file
    elif logging_location == LoggingLocation.DEFAULT:
        return None, None
    else:
        raise ValueError(f"Unknown logging location: {logging_location}")


def _get_labels(flag_name: str, labels: Dict[str, str]) -> List[str]:
    params = []
    for k, v in labels.items():
        assert k not in params, f"Duplicate label: {k}"
        assert "=" not in k, f"Key cannot contain '=': {k}"
        params.append(f"--{flag_name}")
        params.append(f"{k}={v}")
    return params
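

# Illustrative sketch (not part of the original file): `_get_labels` expands a
# label dict into one repeated `--<flag-name> key=value` flag pair per entry.
# The hypothetical helper below just demonstrates that behavior.
def _demo_get_labels() -> None:
    flags = _get_labels("extra-proc-labels", {"team": "oncall", "tier": "test"})
    # Each label becomes its own flag/value pair, preserving dict order.
    assert flags == [
        "--extra-proc-labels",
        "team=oncall",
        "--extra-proc-labels",
        "tier=test",
    ]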


def _start_worker_cmd(
    *,
    world_uuid: str,
    worker_rank: int,
    gpus_per_host: int,
    num_worker_procs: int,
    args: list[str],
    env: dict[str, str] | None = None,
    stdout: TextIO | None = None,
    stderr: TextIO | None = None,
    stdin: TextIO | int | None = subprocess.DEVNULL,
    pass_fds: Collection[int] = (),
) -> subprocess.Popen[bytes]:
    worker_cmd, worker_env = _get_worker_exec_info()
    local_rank = worker_rank % gpus_per_host
    process_env = {
        **(_PROC_ENV | worker_env),
        "CUDA_VISIBLE_DEVICES": str(local_rank),
        "NCCL_HOSTID": f"{world_uuid}_host_{worker_rank // gpus_per_host}",
        # This is needed to avoid a hard failure in ncclx when we do not
        # have backend topology info (e.g. on RE).
        "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
        "LOCAL_RANK": str(local_rank),
        "RANK": str(worker_rank),
        "WORLD_SIZE": str(num_worker_procs),
        "LOCAL_WORLD_SIZE": str(gpus_per_host),
        **os.environ,
    }
    cmd = []
    cmd.extend(worker_cmd)
    cmd.extend(args)
    if env is not None:
        process_env.update(env)
    return subprocess.Popen(
        cmd,
        env=process_env,
        stdout=stdout,
        stderr=stderr,
        stdin=stdin,
        pass_fds=pass_fds,
    )
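

# Illustrative sketch (not part of the original file): for world "mesh_0_worker"
# with gpus_per_host=2, worker_rank=3 maps to local_rank=1, so _start_worker_cmd
# computes roughly:
#   CUDA_VISIBLE_DEVICES=1  LOCAL_RANK=1  RANK=3
#   LOCAL_WORLD_SIZE=2      WORLD_SIZE=<num_worker_procs>
#   NCCL_HOSTID=mesh_0_worker_host_1
# Note that `**os.environ` is spread last, so the caller's environment overrides
# these computed defaults.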


ServerT = TypeVar("ServerT")


class ServerInstance:
    TIMEOUT = 10.0

    def __init__(
        self,
        *,
        server: "ServerBase[ServerT]",
    ) -> None:
        self._server = server
        self._terminated: float = 0

        # TODO
        assert self._server._proc is not None
        self.pid: int = self._server._proc.pid

    def __enter__(self) -> "ServerInstance":
        return self

    def terminate(self) -> None:
        # Start the timeout clock now.
        self._terminated = time.time()

    def kill(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        timeout = max(0, self._terminated + self.TIMEOUT - time.time())
        try:
            self._server._finish(timeout=timeout)
        except Exception as exc:
            if exc_type is None:
                raise
            else:
                logger.warning(f"failed waiting for instance to finish: {exc}")


class ServerBase(contextlib.AbstractContextManager[ServerT, None]):
    def __init__(
        self,
        *,
        name: str,
        response_cls: Type[WorkerServerResponse | ControllerServerResponse],
        request_cls: Type[WorkerServerRequest | ControllerServerRequest],
    ) -> None:
        self._name = name
        self._response_cls: Type[WorkerServerResponse | ControllerServerResponse] = (
            response_cls
        )
        self._request_cls: Type[WorkerServerRequest | ControllerServerRequest] = (
            request_cls
        )

        self._aborted = False
        self._shutdown_started = False
        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()
        self._proc: subprocess.Popen[bytes] | None = None
        self._pipe: Tuple[TextIO, TextIO] | None = None
        self._lock: threading.Lock | None = None

    def _send(self, msg: WorkerServerRequest | ControllerServerRequest) -> None:
        logger.debug(f"{self._name}: sending server request: {msg}")
        assert not self._aborted
        assert self._lock is not None
        if not self._lock.acquire(blocking=False):
            raise Exception("server in use")
        assert self._pipe is not None
        self._pipe[1].write(msg.to_json() + "\n")
        self._pipe[1].flush()

    def _recv(
        self, timeout: float | None = None
    ) -> WorkerServerResponse | ControllerServerResponse:
        assert not self._aborted
        assert self._lock is not None
        assert self._lock.locked()
        assert self._pipe is not None
        ready, _, _ = select.select([self._pipe[0]], [], [], timeout)
        if not ready:
            assert self._proc is not None
            assert timeout is not None
            raise subprocess.TimeoutExpired(self._proc.args, timeout)
        output = ready[0].readline()
        logger.info(f"{self._name}: Got response: {output}")
        response = self._response_cls.from_json(output)
        self._lock.release()
        logger.debug(f"{self._name}: received response: {response}")
        return response

    def _launch_server(
        self,
        read_fd: int,
        write_fd: int,
    ) -> subprocess.Popen[bytes]:
        raise NotImplementedError()

    def __enter__(self) -> ServerT:
        assert self._proc is None, "already running"
        logger.debug(f"{self._name}: launching server")
        self._lock = threading.Lock()
        send = os.pipe2(0)
        recv = os.pipe2(0)
        self._proc = self._contexts.enter_context(
            self._launch_server(
                read_fd=send[0],
                write_fd=recv[1],
            ),
        )
        self._pipe = (
            self._contexts.enter_context(os.fdopen(recv[0], "r")),
            self._contexts.enter_context(os.fdopen(send[1], "w")),
        )
        os.close(send[0])
        os.close(recv[1])
        # pyre-ignore: Incompatible return type [7]
        return self

    def initiate_shutdown(self) -> None:
        if not self._shutdown_started and not self._aborted:
            assert self._lock is not None
            assert not self._lock.locked()
            self._shutdown_started = True
            self._send(self._request_cls.Exit())
            assert self._pipe is not None
            self._pipe[1].close()

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if exc_type is not None or self._aborted:
            assert self._proc is not None
            self._proc.kill()
        else:
            # attempt a clean shutdown
            self.initiate_shutdown()
            assert self._proc is not None
            assert self._proc.wait(timeout=5) == 0
        self._contexts.__exit__(exc_type, exc_val, exc_tb)

    def _finish(self, timeout: float | None = None) -> None:
        try:
            response = self._recv(timeout=timeout)
            assert isinstance(response, self._response_cls.Finished), str(response)
            # pyre-ignore: Undefined attribute [16]
            assert response.error is None, response.error
        except:
            self._aborted = True
            raise

    def _launch_instance(
        self,
        *,
        msg: WorkerServerRequest | ControllerServerRequest,
    ) -> ServerInstance:
        self._send(msg)
        return ServerInstance(server=self)
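

# Illustrative sketch (not part of the original file): the ServerBase wire
# protocol is newline-delimited JSON over a pipe pair. A subclass drives one
# request/response round-trip roughly like this (`SomeRequest`/`SomeResponse`
# stand in for the concrete WorkerServer*/ControllerServer* types):
#
#   with server:                                     # fork process, open pipes
#       inst = server._launch_instance(msg=SomeRequest.Run(...))
#       with inst:                                   # on exit, _finish() waits
#           ...                                      # for SomeResponse.Finished
#   # leaving the outer `with` sends SomeRequest.Exit and reaps the process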


class ISystemFactory:
    def launch(
        self,
        *,
        bootstrap_addr: str,
        supervision_params: SupervisionParams,
    ) -> ServerInstance | subprocess.Popen[bytes]:
        raise NotImplementedError()


class IControllerFactory:
    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        controller_id: ActorId,
        num_worker_procs: int,
        gpus_per_host: int,
        supervision_params: SupervisionParams,
        controller_params: ControllerParams,
        labels: Dict[str, str],
    ) -> subprocess.Popen[bytes] | ServerInstance:
        raise NotImplementedError()


class ControllerFactoryBase:
    def __init__(
        self,
        *,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        self.logging_location = logging_location
        self.logging_dir = logging_dir

        self.controller_main: Path
        self.controller_env: dict[str, str]
        self.controller_main, self.controller_env = get_controller_main()


class SystemFactory(ControllerFactoryBase, ISystemFactory):
    def launch(
        self,
        *,
        bootstrap_addr: str,
        supervision_params: SupervisionParams,
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            "system",
            self.logging_location,
        )
        return subprocess.Popen(
            [
                self.controller_main,
                "system",
                "--system-addr",
                bootstrap_addr,
                "--supervision-update-timeout-in-sec",
                str(supervision_params.update_timeout_in_sec),
            ],
            stdout=stdout_location,
            stderr=stderr_location,
            stdin=subprocess.DEVNULL,
            env=_PROC_ENV | self.controller_env,
        )


class ControllerFactory(ControllerFactoryBase, IControllerFactory):
    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        controller_id: ActorId,
        num_worker_procs: int,
        gpus_per_host: int,
        supervision_params: SupervisionParams,
        controller_params: ControllerParams,
        labels: Dict[str, str],
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            controller_id.world_name,
            self.logging_location,
        )
        command = [
            self.controller_main,
            "controller",
            "--worker-world",
            worker_world,
            "--system-addr",
            bootstrap_addr,
            "--controller-actor-id",
            str(controller_id),
            "--world-size",
            str(num_worker_procs),
            "--num-procs-per-host",
            str(gpus_per_host),
            "--supervision-query-interval-in-sec",
            str(supervision_params.query_interval_in_sec),
            "--supervision-update-interval-in-sec",
            str(supervision_params.update_interval_in_sec),
            "--worker-progress-check-interval-in-sec",
            str(controller_params.worker_progress_check_interval_in_sec),
            "--operation-timeout-in-sec",
            str(controller_params.operation_timeout_in_sec),
            "--operations-per-worker-progress-request",
            str(controller_params.operations_per_worker_progress_request),
        ]

        if controller_params.fail_on_worker_timeout:
            command.append("--fail-on-worker-timeout")

        return subprocess.Popen(
            command + _get_labels("extra-proc-labels", labels),
            stdout=stdout_location,
            stderr=stderr_location,
            stdin=subprocess.DEVNULL,
            env=_PROC_ENV | self.controller_env,
        )
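

# Illustrative sketch (not part of the original file): with the default
# SupervisionParams/ControllerParams used by Bootstrap below, the command
# assembled above comes out roughly as:
#   monarch_controller controller \
#       --worker-world mesh_0_worker --system-addr 'tcp![::1]:<port>' \
#       --controller-actor-id 'mesh_0_controller[0].controller[0]' \
#       --world-size 8 --num-procs-per-host 8 \
#       --supervision-query-interval-in-sec 2 \
#       --supervision-update-interval-in-sec 2 \
#       --worker-progress-check-interval-in-sec 10 \
#       --operation-timeout-in-sec 120 \
#       --operations-per-worker-progress-request 100
# plus one `--extra-proc-labels key=value` pair per label.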


class ControllerServerBase(ServerBase[ServerT]):
    def __init__(
        self,
        *,
        uuid: str,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        super().__init__(
            name=uuid,
            response_cls=ControllerServerResponse,
            request_cls=ControllerServerRequest,
        )
        self.uuid = uuid
        self.logging_location = logging_location
        self.logging_dir = logging_dir

        self.controller_main: Path
        self.controller_env: dict[str, str]
        self.controller_main, self.controller_env = get_controller_main()

    def _launch_server(
        self,
        read_fd: int,
        write_fd: int,
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            self.uuid,
            self.logging_location,
        )
        return subprocess.Popen(
            [
                self.controller_main,
                "serve",
                str(read_fd),
                str(write_fd),
            ],
            stdout=stdout_location,
            pass_fds=(read_fd, write_fd),
            stderr=stderr_location,
            stdin=subprocess.DEVNULL,
            env=_PROC_ENV | self.controller_env | dict(os.environ),
        )


class SystemServer(ControllerServerBase["SystemServer"], ISystemFactory):
    def launch(
        self,
        *,
        bootstrap_addr: str,
        supervision_params: SupervisionParams,
    ) -> ServerInstance:
        return self._launch_instance(
            msg=ControllerServerRequest.Run(
                RunCommand.System(
                    system_addr=bootstrap_addr,
                    supervision_update_timeout_in_sec=supervision_params.update_timeout_in_sec,
                    world_eviction_timeout_in_sec=10,
                ),
            ),
        )


class ControllerServer(ControllerServerBase["ControllerServer"], IControllerFactory):
    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        controller_id: ActorId,
        num_worker_procs: int,
        gpus_per_host: int,
        supervision_params: SupervisionParams,
        controller_params: ControllerParams,
        labels: Dict[str, str],
    ) -> ServerInstance:
        return self._launch_instance(
            msg=ControllerServerRequest.Run(
                RunCommand.Controller(
                    ControllerCommand(
                        worker_world=worker_world,
                        system_addr=bootstrap_addr,
                        controller_actor_id=str(controller_id),
                        world_size=num_worker_procs,
                        num_procs_per_host=gpus_per_host,
                        worker_name="worker",
                        program=None,
                        supervision_query_interval_in_sec=supervision_params.query_interval_in_sec,
                        supervision_update_interval_in_sec=supervision_params.update_interval_in_sec,
                        worker_progress_check_interval_in_sec=controller_params.worker_progress_check_interval_in_sec,
                        operation_timeout_in_sec=controller_params.operation_timeout_in_sec,
                        operations_per_worker_progress_request=controller_params.operations_per_worker_progress_request,
                        fail_on_worker_timeout=controller_params.fail_on_worker_timeout,
                        is_cpu_worker=False,
                        extra_proc_labels=list(labels.items()),
                    ),
                ),
            ),
        )


class IWorkerFactory:
    def launch(
        self,
        *,
        worker_world: str,
        worker_rank: int,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> ServerInstance | subprocess.Popen[bytes]:
        raise NotImplementedError()


class WorkerFactory(IWorkerFactory):
    def __init__(
        self,
        *,
        num_worker_procs: int,
        gpus_per_host: int,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        self.num_worker_procs = num_worker_procs
        self.gpus_per_host = gpus_per_host
        self.logging_location = logging_location
        self.logging_dir = logging_dir

    def launch(
        self,
        *,
        worker_world: str,
        worker_rank: int,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            f"{worker_world}_{worker_rank}",
            self.logging_location,
        )
        return _start_worker_cmd(
            world_uuid=worker_world,
            worker_rank=worker_rank,
            gpus_per_host=self.gpus_per_host,
            num_worker_procs=self.num_worker_procs,
            args=[
                "worker",
                "--world-id",
                worker_world,
                "--proc-id",
                f"{worker_world}[{worker_rank}]",
                "--bootstrap-addr",
                bootstrap_addr,
            ]
            + _get_labels("extra-proc-labels", labels),
            stdout=stdout_location,
            stderr=stderr_location,
        )


class WorkerServer(ServerBase["WorkerServer"]):
    def __init__(
        self,
        *,
        uuid: str,
        num_worker_procs: int,
        gpus_per_host: int,
        world_rank: int,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        super().__init__(
            name=uuid,
            response_cls=WorkerServerResponse,
            request_cls=WorkerServerRequest,
        )
        self.uuid = uuid
        self.num_worker_procs = num_worker_procs
        self.gpus_per_host = gpus_per_host
        self.world_rank = world_rank
        self.logging_location = logging_location
        self.logging_dir = logging_dir

    def _launch_server(
        self,
        read_fd: int,
        write_fd: int,
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            f"{self.uuid}_{self.world_rank}",
            self.logging_location,
        )
        return _start_worker_cmd(
            world_uuid=self.uuid,
            worker_rank=self.world_rank,
            gpus_per_host=self.gpus_per_host,
            num_worker_procs=self.num_worker_procs,
            args=["worker-server", str(read_fd), str(write_fd)],
            pass_fds=(read_fd, write_fd),
            stdin=subprocess.PIPE,
            stdout=stdout_location,
            stderr=stderr_location,
        )

    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> ServerInstance:
        return self._launch_instance(
            msg=WorkerServerRequest.Run(
                world_id=worker_world,
                proc_id=f"{worker_world}[{self.world_rank}]",
                bootstrap_addr=bootstrap_addr,
                labels=list(labels.items()),
            )
        )
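

# Illustrative sketch (not part of the original file): WorkerFactory forks one
# fresh OS process per launch() call, while WorkerServer keeps a single
# `worker-server` subprocess alive and re-targets it at a world per Run request,
# which is what makes the ProcessCache below worthwhile for tests:
#
#   server = WorkerServer(uuid="w0", num_worker_procs=1, gpus_per_host=1,
#                         world_rank=0, logging_location=LoggingLocation.DEFAULT,
#                         logging_dir="N/A")
#   with server:
#       with server.launch(worker_world="mesh_0_worker",
#                          bootstrap_addr="tcp![::1]:12345", labels={}):
#           ...  # the process participates in mesh_0_worker
#       # after Finished, the same process can serve another world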


class WorkerServers(IWorkerFactory):
    def __init__(
        self,
        *,
        workers: dict[int, WorkerServer],
    ) -> None:
        self._workers = workers
        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()

    @staticmethod
    def create(
        uuid: str,
        num_worker_procs: int,
        gpus_per_host: int,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> "WorkerServers":
        return WorkerServers(
            workers={
                world_rank: WorkerServer(
                    uuid=uuid,
                    num_worker_procs=num_worker_procs,
                    gpus_per_host=gpus_per_host,
                    world_rank=world_rank,
                    logging_location=logging_location,
                    logging_dir=logging_dir,
                )
                for world_rank in range(num_worker_procs)
            },
        )

    def initiate_shutdown(self) -> None:
        for worker in self._workers.values():
            worker.initiate_shutdown()

    def __enter__(self) -> "WorkerServers":
        for worker in self._workers.values():
            self._contexts.enter_context(worker)
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.initiate_shutdown()
        self._contexts.__exit__(exc_type, exc_val, exc_tb)

    def launch(
        self,
        *,
        worker_world: str,
        worker_rank: int,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> ServerInstance | subprocess.Popen[bytes]:
        return self._workers[worker_rank].launch(
            worker_world=worker_world,
            bootstrap_addr=bootstrap_addr,
            labels=labels,
        )


class ProcessCache:
    def __init__(
        self,
        *,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        self.logging_location: LoggingLocation = logging_location
        self.logging_dir: str = logging_dir

        self._system_cache: SystemServer | None = None
        self._controller_cache: ControllerServer | None = None
        self._worker_cache: dict[Tuple[int, int], WorkerServers] = {}
        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()

    def __enter__(self) -> "ProcessCache":
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self._system_cache is not None:
            self._system_cache.initiate_shutdown()
        if self._controller_cache is not None:
            self._controller_cache.initiate_shutdown()
        for workers in self._worker_cache.values():
            workers.initiate_shutdown()
        self._contexts.__exit__(exc_type, exc_val, exc_tb)

    def get_system_server(self) -> SystemServer:
        if self._system_cache is None:
            system = SystemServer(
                uuid="cached_system",
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
            self._system_cache = self._contexts.enter_context(system)
        assert self._system_cache is not None
        return self._system_cache

    def get_controller_server(self) -> ControllerServer:
        if self._controller_cache is None:
            controller = ControllerServer(
                uuid="cached_controller",
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
            self._controller_cache = self._contexts.enter_context(controller)
        assert self._controller_cache is not None
        return self._controller_cache

    def get_worker_servers(
        self,
        *,
        num_worker_procs: int,
        gpus_per_host: int,
    ) -> WorkerServers:
        key = (num_worker_procs, gpus_per_host)
        workers = self._worker_cache.get(key)
        if workers is None:
            workers = WorkerServers.create(
                uuid=f"cached_workers_{num_worker_procs}_{gpus_per_host}",
                num_worker_procs=num_worker_procs,
                gpus_per_host=gpus_per_host,
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
            self._worker_cache[key] = self._contexts.enter_context(workers)
        return workers
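

# Illustrative sketch (not part of the original file): a test suite can hold one
# ProcessCache open across many cases so repeated mesh setups reuse the same
# system/controller/worker server processes instead of re-forking them:
#
#   with ProcessCache(logging_location=LoggingLocation.DEFAULT,
#                     logging_dir="N/A") as cache:
#       bootstrap = Bootstrap(
#           ...,
#           system_factory=cache.get_system_server(),
#           controller_factory=cache.get_controller_server(),
#           worker_factory=cache.get_worker_servers(
#               num_worker_procs=2, gpus_per_host=2
#           ),
#       )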


class Bootstrap:
    def __init__(
        self,
        *,
        meshes: int,
        hosts_per_mesh: int,
        gpus_per_host: int,
        worker_factory: IWorkerFactory | None = None,
        controller_factory: IControllerFactory | None = None,
        system_factory: ISystemFactory | None = None,
        socket_type: SocketType,
        logging_location: LoggingLocation,
        supervision_params: SupervisionParams | None,
        controller_params: ControllerParams | None,
        auto_epoch: bool,
        controller_labels: Dict[str, str] | None = None,
        worker_labels: Dict[str, str] | None = None,
    ) -> None:
        if supervision_params is None:
            supervision_params = SupervisionParams(
                update_timeout_in_sec=20,
                query_interval_in_sec=2,
                update_interval_in_sec=2,
            )
        self.supervision_params: SupervisionParams = supervision_params

        if controller_params is None:
            controller_params = ControllerParams(
                worker_progress_check_interval_in_sec=10,
                operation_timeout_in_sec=120,
                operations_per_worker_progress_request=100,
                fail_on_worker_timeout=False,
            )
        self.controller_params: ControllerParams = controller_params

        self.epoch: int | None = 0 if auto_epoch else None

        # hyperactor_telemetry will take the execution id and use it across all processes.
        execution_id = "rust_local_" + uuid.uuid4().hex
        os.environ["HYPERACTOR_EXECUTION_ID"] = execution_id

        # Create a temporary directory for logging.
        self.logging_dir: str = (
            tempfile.mkdtemp(prefix="rust_local_mesh_")
            if logging_location == LoggingLocation.FILE
            else "N/A"
        )
        logger.info(
            f"Creating Rust local mesh with {meshes} meshes X {hosts_per_mesh} hosts X {gpus_per_host} gpus.\n"
            f"Logging directory: \033[92;1m{self.logging_dir}\033[0m\n"
            f"Execution id: {execution_id}"
        )
        self.logging_location: LoggingLocation = logging_location

        if controller_factory is None:
            controller_factory = ControllerFactory(
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
        self.controller_factory: IControllerFactory = controller_factory

        if system_factory is None:
            system_factory = SystemFactory(
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
        self.system_factory: ISystemFactory = system_factory

        if worker_factory is None:
            worker_factory = WorkerFactory(
                num_worker_procs=hosts_per_mesh * gpus_per_host,
                gpus_per_host=gpus_per_host,
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
        self.worker_factory: IWorkerFactory = worker_factory

        # Do a fake call to instantiate the ThreadPoolExecutor so we don't block the GIL later.
        fake_call(lambda: 0)

        self.bootstrap_addr: str
        if socket_type == SocketType.TCP:
            with socket.socket() as sock:
                sock.bind(("", 0))
                port = sock.getsockname()[1]
                self.bootstrap_addr = f"tcp![::1]:{port}"
        elif socket_type == SocketType.UNIX:
            # provide a random unix socket address
            self.bootstrap_addr: str = f"unix!@{''.join(random.choice(string.ascii_lowercase) for _ in range(14))}-system"
        else:
            raise ValueError(f"Unknown socket type: {socket_type}")

        env = os.environ.copy()
        env["HYPERACTOR_MANAGED_SUBPROCESS"] = "1"
        self.env: dict[str, str] = env

        # Launch a single system globally.
        self.processes: list[subprocess.Popen[bytes] | ServerInstance] = []
        self.processes.append(self._launch_system())

        self.has_shutdown: bool = False
        self.gpus_per_host: int = gpus_per_host
        self.num_worker_procs: int = hosts_per_mesh * gpus_per_host
        self.controller_ids: list[ActorId] = []
        self.mesh_worlds: dict[
            MeshWorld, list[subprocess.Popen[bytes] | ServerInstance]
        ] = {}

        # Create meshes, each of which contains a single controller and multiple workers.
        # All of them will connect to the same system.
        pids: dict[str, list[int]] = {}
        for i in range(meshes):
            mesh_name: str = f"mesh_{i}"
            controller_world: str = f"{mesh_name}_controller"
            worker_world: str = f"{mesh_name}_worker"
            controller_id: ActorId = ActorId(
                world_name=controller_world,
                rank=0,
                actor_name="controller",
            )
            self.mesh_worlds[(worker_world, controller_id)] = []
            self.controller_ids.append(controller_id)

            processes: list[subprocess.Popen[bytes] | ServerInstance] = (
                self.launch_mesh(
                    controller_id,
                    worker_world,
                    controller_labels=controller_labels,
                    worker_labels=worker_labels,
                )
            )

            self.processes.extend(processes)
            pids[mesh_name] = [p.pid for p in processes]

        log_message = (
            f"All processes started successfully:\n system: {self.processes[0].pid}\n"
        )
        for mesh, procs in pids.items():
            log_message += f"{mesh}: controller: {procs[0]}, "
            worker_messages = []
            for i in range(1, len(procs)):
                worker_messages.append(f"{i - 1}: {procs[i]}")
            log_message += "workers: " + ", ".join(worker_messages)
            log_message += "\n"

        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()

        logger.info(log_message)

    def _launch_system(
        self,
    ) -> ServerInstance | subprocess.Popen[bytes]:
        logger.info("launching system")
        try:
            return self.system_factory.launch(
                bootstrap_addr=self.bootstrap_addr,
                supervision_params=self.supervision_params,
            )
        except Exception as e:
            logger.error(f"Failed to start system process: {e}")
            raise e

    def _launch_controller(
        self,
        controller_id: ActorId,
        worker_world: str,
        epoch: str | None = None,
        labels: Dict[str, str] | None = None,
    ) -> subprocess.Popen[bytes] | ServerInstance:
        logger.info("launching controller")
        try:
            return self.controller_factory.launch(
                bootstrap_addr=self.bootstrap_addr,
                worker_world=worker_world
                if epoch is None
                else f"{worker_world}_{epoch}",
                controller_id=ActorId.from_string(
                    (
                        f"{controller_id.world_name + '_' + epoch if epoch else controller_id.world_name}"
                        f"[{controller_id.rank}]."
                        f"{controller_id.actor_name}[{controller_id.pid}]"
                    )
                ),
                num_worker_procs=self.num_worker_procs,
                gpus_per_host=self.gpus_per_host,
                supervision_params=self.supervision_params,
                controller_params=self.controller_params,
                labels={} if labels is None else labels,
            )
        except Exception as e:
            logger.error(f"Failed to start controller process: {e}")
            raise e

    def _launch_worker(
        self,
        worker_world: str,
        worker_rank: int,
        epoch: str | None = None,
        labels: Dict[str, str] | None = None,
    ) -> subprocess.Popen[bytes] | ServerInstance:
        logger.info("launching worker")
        try:
            return self.worker_factory.launch(
                worker_world=worker_world
                if epoch is None
                else f"{worker_world}_{epoch}",
                worker_rank=worker_rank,
                bootstrap_addr=self.bootstrap_addr,
                labels={} if labels is None else labels,
            )
        except Exception as e:
            logger.error(f"Failed to start worker process {worker_rank}: {e}")
            raise e

    def get_mesh_worlds(self) -> list[MeshWorld]:
        return list(self.mesh_worlds.keys())

    def kill_mesh(self, mesh_world: MeshWorld) -> None:
        logger.info(f"Killing mesh {mesh_world}")
        procs = self.mesh_worlds[mesh_world]
        procs[-1].kill()

    def spawn_mesh(self, mesh_world: MeshWorld) -> None:
        self.launch_mesh(mesh_world[1], mesh_world[0])

    def launch_mesh(
        self,
        controller_id: ActorId,
        worker_world: str,
        controller_labels: Dict[str, str] | None = None,
        worker_labels: Dict[str, str] | None = None,
    ) -> list[subprocess.Popen[bytes] | ServerInstance]:
        """
        Create a single controller and multiple workers for a mesh.
        The first process in the returned list is the controller;
        the remaining ones are workers.
        """
        logger.info(
            f"Launching mesh {worker_world} with controller {controller_id} epoch {self.epoch}"
        )
        epoch: str | None = None
        if self.epoch is not None:
            epoch = f"epoch_{self.epoch}"
            self.epoch += 1

        processes: list[subprocess.Popen[bytes] | ServerInstance] = []
        controller_process = self._launch_controller(
            controller_id,
            worker_world,
            epoch,
            controller_labels,
        )
        processes.append(controller_process)

        for i in range(self.num_worker_procs):
            worker_process = self._launch_worker(worker_world, i, epoch, worker_labels)
            processes.append(worker_process)
        self.mesh_worlds[(worker_world, controller_id)] = processes
        return processes

    def __enter__(self) -> "Bootstrap":
        for process in self.processes:
            self._contexts.enter_context(process)
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        for process in self.processes:
            process.terminate()
        self._contexts.__exit__(exc_type, exc_val, exc_tb)
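

# Illustrative sketch (not part of the original file): tests exercise fault
# tolerance by tearing a mesh down and standing a replacement up on the same
# system; with auto_epoch=True, the respawned world gets a fresh `epoch_N`
# suffix so it does not collide with the dead one:
#
#   worlds = bootstrap.get_mesh_worlds()
#   bootstrap.kill_mesh(worlds[0])    # kills the mesh's last-launched proc
#   bootstrap.spawn_mesh(worlds[0])   # relaunches controller + workers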


def _local_device_count() -> int:
    dev_path = Path("/dev")
    pattern = re.compile(r"nvidia\d+$")
    nvidia_devices = [dev for dev in dev_path.iterdir() if pattern.match(dev.name)]
    return len(nvidia_devices)


def _get_worker_exec_info() -> tuple[list[str], dict[str, str]]:
    if IN_PAR:
        cmd = [sys.argv[0]]
        env = {
            "PAR_MAIN_OVERRIDE": _MONARCH_TENSOR_WORKER_MAIN,
        }
    else:
        cmd = [sys.executable, "-m", _MONARCH_TENSOR_WORKER_MAIN]
        env = {}

    env["MONARCH_TENSOR_WORKER_MAIN"] = _MONARCH_TENSOR_WORKER_MAIN
    env["MONARCH_TENSOR_WORKER_EXE"] = cmd[0]
    return cmd, env
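

# Illustrative sketch (not part of the original file): outside a PAR binary
# (e.g. when running from this wheel), _get_worker_exec_info() resolves the
# worker command to the current interpreter:
#   ([sys.executable, "-m", "monarch.tensor_worker_main"],
#    {"MONARCH_TENSOR_WORKER_MAIN": "monarch.tensor_worker_main",
#     "MONARCH_TENSOR_WORKER_EXE": sys.executable})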


@contextlib.contextmanager
def local_mesh(
    *,
    hosts: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
) -> Generator[DeviceMesh, None, None]:
    """
    Creates a single local device mesh with the given number of hosts and GPUs per host.

    Args:
        hosts : number of hosts, primarily used for simulating multiple machines locally.
            Default: 1
        gpus_per_host : number of gpus per host.
            Default: the number of GPUs this machine has.
        socket_type : socket type to use for communication between processes.
            Default: TCP.

    Example::
        with local_mesh().activate():
            x = torch.rand(3, 4)
            local_tensor = fetch_shard(x).result()
    """
    with local_meshes(
        meshes=1,
        hosts_per_mesh=hosts,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    ) as dms:
        assert len(dms) == 1
        yield dms[0]
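

# Illustrative sketch (not part of the original file): a fuller version of the
# docstring example above, assuming `fetch_shard` comes from monarch.fetch (per
# this wheel's layout) and that at least one NVIDIA device is visible:
#
#   import torch
#   from monarch.fetch import fetch_shard
#
#   with local_mesh(hosts=1, gpus_per_host=2) as mesh:
#       with mesh.activate():
#           x = torch.rand(3, 4)
#           local_tensor = fetch_shard(x).result()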


@contextlib.contextmanager
def local_meshes(
    *,
    meshes: int = 1,
    hosts_per_mesh: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
) -> Generator[list[DeviceMesh], None, None]:
    """
    Creates multiple local device meshes.

    Args:
        meshes : number of global meshes to create.
            Default: 1
        hosts_per_mesh : number of hosts per mesh, primarily used for simulating multiple machines locally.
            Default: 1
        gpus_per_host : number of gpus per host.
            Default: the number of GPUs this machine has.
        socket_type : socket type to use for communication between processes.
            Default: TCP.
    """
    (dms, bootstrap) = local_meshes_and_bootstraps(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    )
    with bootstrap:
        maybe_error = None
        try:
            yield dms
        except Exception as e:
            maybe_error = e
            raise
        finally:
            for dm in dms:
                dm.exit(maybe_error)
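

# Illustrative sketch (not part of the original file): driving two meshes from a
# single client, e.g. to place different stages of a job on different worlds:
#
#   with local_meshes(meshes=2, hosts_per_mesh=1, gpus_per_host=1) as (a, b):
#       with a.activate():
#           ...  # work placed on mesh a
#       with b.activate():
#           ...  # work placed on mesh b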


def local_meshes_and_bootstraps(
    *,
    meshes: int = 1,
    hosts_per_mesh: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    auto_epoch: bool = False,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
) -> tuple[list[DeviceMesh], Bootstrap]:
    """
    Same as local_meshes, but also returns the bootstrap object. This is
    useful in tests where we want to manipulate the bootstrap object.
    """

    if gpus_per_host is None:
        gpus_per_host = _local_device_count()
    assert (
        0 < gpus_per_host <= 8
    ), "Number of GPUs must be greater than 0 and at most 8."
    bootstrap: Bootstrap = Bootstrap(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        auto_epoch=auto_epoch,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    )

    def create_exit(
        dm: DeviceMesh, bootstrap: Bootstrap
    ) -> Callable[[Optional[RemoteException | DeviceException | Exception]], None]:
        def exit(
            error: Optional[RemoteException | DeviceException | Exception] = None,
        ) -> None:
            # We only support a single client proc.
            if not bootstrap.has_shutdown:
                dm.client.shutdown(True, error)
                bootstrap.has_shutdown = True

            # We do not need to shut down the bootstrap or clean up its
            # processes, as they will be cleaned up with the parent.

        return exit

    dms = rust_backend_meshes(
        system_addr=bootstrap.bootstrap_addr,
        hosts=hosts_per_mesh,
        gpus=gpus_per_host,
        requested_meshes=meshes,
    )

    for dm in dms:
        dm.exit = create_exit(dm, bootstrap)

    return (dms, bootstrap)
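

# Illustrative sketch (not part of the original file): the bootstrap handle is
# what fault-tolerance tests manipulate directly (cf. tests/test_fault_tolerance.py
# in this wheel):
#
#   dms, bootstrap = local_meshes_and_bootstraps(meshes=2, auto_epoch=True)
#   with bootstrap:
#       victim = bootstrap.get_mesh_worlds()[0]
#       bootstrap.kill_mesh(victim)
#       bootstrap.spawn_mesh(victim)   # same mesh slot, new epoch suffix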


def local_mesh_provider(
    *,
    meshes: int = 1,
    hosts_per_mesh: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    auto_epoch: bool = False,
    controller_labels: Dict[str, str] | None = None,
    worker_labels: Dict[str, str] | None = None,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
    # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type.
) -> tuple[PoolDeviceMeshProvider, Bootstrap]:
    if gpus_per_host is None:
        gpus_per_host = _local_device_count()
    assert (
        0 < gpus_per_host <= 8
    ), "Number of GPUs must be greater than 0 and at most 8."
    bootstrap: Bootstrap = Bootstrap(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        auto_epoch=auto_epoch,
        controller_labels=controller_labels,
        worker_labels=worker_labels,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    )

    provider = rust_backend_mesh_provider(
        system_addr=bootstrap.bootstrap_addr,
        hosts=hosts_per_mesh,
        gpus=gpus_per_host,
    )
    return (provider, bootstrap)
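

# Illustrative sketch (not part of the original file): the provider variant is
# for callers that want meshes to arrive incrementally rather than all at once.
# Assuming PoolDeviceMeshProvider exposes a blocking new_mesh() (see
# monarch/rust_backend_mesh.py), usage looks roughly like:
#
#   provider, bootstrap = local_mesh_provider(meshes=2, gpus_per_host=1)
#   with bootstrap:
#       dm = provider.new_mesh()   # blocks until a mesh world is ready
#       with dm.activate():
#           ...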