torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/rust_local_mesh.py
@@ -0,0 +1,1402 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import contextlib
import importlib.resources
import logging
import os
import random
import re
import select
import socket
import string
import subprocess
import sys
import tempfile
import threading
import time
import uuid
from enum import Enum
from pathlib import Path
from types import TracebackType
from typing import (
    Callable,
    Collection,
    Dict,
    Generator,
    List,
    NamedTuple,
    Optional,
    TextIO,
    Tuple,
    Type,
    TypeVar,
)

from monarch._rust_bindings.controller.bootstrap import (
    ControllerCommand,
    ControllerServerRequest,
    ControllerServerResponse,
    RunCommand,
)

from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
    ActorId,
)

from monarch._rust_bindings.monarch_tensor_worker.bootstrap import (
    WorkerServerRequest,
    WorkerServerResponse,
)

from monarch.common.device_mesh import DeviceMesh
from monarch.common.fake import fake_call
from monarch.common.invocation import DeviceException, RemoteException
from monarch.rust_backend_mesh import (
    IBootstrap,
    MeshWorld,
    PoolDeviceMeshProvider,
    rust_backend_mesh_provider,
    rust_backend_meshes,
)

logger: logging.Logger = logging.getLogger(__name__)
_MONARCH_TENSOR_WORKER_MAIN = "monarch.tensor_worker_main"

try:
    from __manifest__ import fbmake  # noqa

    IN_PAR = True
except ImportError:
    IN_PAR = False


class SocketType(Enum):
    """Enum representing socket types."""

    TCP = "tcp"
    UNIX = "unix"


class LoggingLocation(Enum):
    """Enum representing where to flush stderr and stdout."""

    DEFAULT = "default"
    FILE = "file"


class SupervisionParams(NamedTuple):
    # If the system actor does not receive a supervision update within this time,
    # it will treat the proc as dead.
    update_timeout_in_sec: int
    # How often the controller queries supervision status from the system actor.
    query_interval_in_sec: int
    # How often a proc actor sends supervision updates to the system actor.
    update_interval_in_sec: int


class ControllerParams(NamedTuple):
    # How often the controller polls for operations that have not completed within
    # the timeout duration, indicating that they may be stuck.
    worker_progress_check_interval_in_sec: int

    # How long to wait for an operation before letting the client know that it may be stuck.
    operation_timeout_in_sec: int

    # The number of operations to invoke before proactively checking worker progress.
    # If a large number of operations are invoked all at once, it is expected to take
    # a while for all of them to complete, so we inject progress requests at a higher
    # frequency to check that we are making progress.
    operations_per_worker_progress_request: int

    # Whether the controller should propagate a failure to the client if the workers
    # become stuck.
    fail_on_worker_timeout: bool
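
Bootstrap (defined later in this file) fills in conservative defaults when these tuples are not supplied. A minimal sketch, mirroring the default values from Bootstrap.__init__:

    supervision = SupervisionParams(
        update_timeout_in_sec=20,   # declare a proc dead after 20s of silence
        query_interval_in_sec=2,    # controller polls the system actor every 2s
        update_interval_in_sec=2,   # procs report in every 2s
    )
    controller = ControllerParams(
        worker_progress_check_interval_in_sec=10,
        operation_timeout_in_sec=120,
        operations_per_worker_progress_request=100,
        fail_on_worker_timeout=False,
    )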


_PROC_ENV = {
    "HYPERACTOR_MANAGED_SUBPROCESS": str(1),
}


def get_controller_main() -> tuple[Path, dict[str, str]]:
    with (
        importlib.resources.path("monarch", "monarch_controller") as controller_main,
    ):
        if not controller_main.exists():
            if IN_PAR:
                raise ImportError(
                    "Monarch env not found, please define a custom 'monarch_env' or "
                    "add '//monarch/python/monarch:default_env-library' to your binary dependencies "
                    "in TARGETS"
                )
            else:
                raise ImportError(
                    "Monarch env not found, please re-run ./scripts/install.sh in fbcode/monarch"
                )
        env: dict[str, str] = {}

        # Hack to make the exploded-wheel workflow work in the face of broken
        # build-time RPATHs...
        #
        # If we're running under a conda env...
        if not IN_PAR:
            conda_prefix = os.environ.get("CONDA_PREFIX")
            if conda_prefix is not None and sys.executable.startswith(
                conda_prefix + "/"
            ):
                # ... and Monarch is coming from "outside" the env, via `PYTHONPATH`s ...
                spec = importlib.util.find_spec("monarch")
                assert spec is not None
                origin = spec.origin
                assert origin is not None
                monarch_root = str(Path(origin).parent.parent)
                if (
                    not monarch_root.startswith(conda_prefix + "/")
                    and monarch_root in sys.path
                ):
                    import torch

                    # ... then assume we're running via an exploded .whl, which means
                    # we need to manually set library paths to find the necessary
                    # native libs from the conda env.
                    env["LD_LIBRARY_PATH"] = ":".join(
                        [
                            os.path.join(os.path.dirname(torch.__file__), "lib"),
                            os.path.join(conda_prefix, "lib"),
                        ]
                    )

        return controller_main, env


def _create_logging_locations(
    logging_dir: str, name: str, logging_location: LoggingLocation
) -> tuple[TextIO | None, TextIO | None]:
    if logging_location == LoggingLocation.FILE:
        stdout_file: TextIO = open(os.path.join(logging_dir, f"{name}.stdout"), "a+")
        stderr_file: TextIO = open(os.path.join(logging_dir, f"{name}.stderr"), "a+")
        return stdout_file, stderr_file
    elif logging_location == LoggingLocation.DEFAULT:
        return None, None
    else:
        raise ValueError(f"Unknown logging location: {logging_location}")


def _get_labels(flag_name: str, labels: Dict[str, str]) -> List[str]:
    params = []
    for k, v in labels.items():
        assert k not in params, f"Duplicate label: {k}"
        assert "=" not in k, f"Key cannot contain '=': {k}"
        params.append(f"--{flag_name}")
        params.append(f"{k}={v}")
    return params
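
For reference, a quick sketch of how _get_labels expands a label map into repeated CLI flags (the label values here are hypothetical):

    labels = {"team": "train", "job": "demo"}
    assert _get_labels("extra-proc-labels", labels) == [
        "--extra-proc-labels", "team=train",
        "--extra-proc-labels", "job=demo",
    ]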


def _start_worker_cmd(
    *,
    world_uuid: str,
    worker_rank: int,
    gpus_per_host: int,
    num_worker_procs: int,
    args: list[str],
    env: dict[str, str] | None = None,
    stdout: TextIO | None = None,
    stderr: TextIO | None = None,
    stdin: TextIO | int | None = subprocess.DEVNULL,
    pass_fds: Collection[int] = (),
) -> subprocess.Popen[bytes]:
    worker_cmd, worker_env = _get_worker_exec_info()
    local_rank = worker_rank % gpus_per_host
    process_env = {
        **(_PROC_ENV | worker_env),
        "CUDA_VISIBLE_DEVICES": str(local_rank),
        "NCCL_HOSTID": f"{world_uuid}_host_{worker_rank // gpus_per_host}",
        # This is needed to avoid a hard failure in ncclx when we do not
        # have backend topology info (e.g. on RE).
        "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
        "LOCAL_RANK": str(local_rank),
        "RANK": str(worker_rank),
        "WORLD_SIZE": str(num_worker_procs),
        "LOCAL_WORLD_SIZE": str(gpus_per_host),
        **os.environ,
    }
    cmd = []
    cmd.extend(worker_cmd)
    cmd.extend(args)
    if env is not None:
        process_env.update(env)
    return subprocess.Popen(
        cmd,
        env=process_env,
        stdout=stdout,
        stderr=stderr,
        stdin=stdin,
        pass_fds=pass_fds,
    )
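
Each worker is pinned to a single GPU, with "hosts" simulated by integer division of the rank. A worked instance of the arithmetic above (values are illustrative):

    # gpus_per_host = 4, worker_rank = 5
    local_rank = 5 % 4   # 1 -> CUDA_VISIBLE_DEVICES="1", LOCAL_RANK="1"
    host_index = 5 // 4  # 1 -> NCCL_HOSTID="<world_uuid>_host_1"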


ServerT = TypeVar("ServerT")


class ServerInstance:
    TIMEOUT = 10.0

    def __init__(
        self,
        *,
        server: "ServerBase[ServerT]",
    ) -> None:
        self._server = server
        self._terminated: float = 0

        # TODO
        assert self._server._proc is not None
        self.pid: int = self._server._proc.pid

    def __enter__(self) -> "ServerInstance":
        return self

    def terminate(self) -> None:
        # Start the timeout clock now.
        self._terminated = time.time()

    def kill(self) -> None:
        pass

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        timeout = max(0, self._terminated + self.TIMEOUT - time.time())
        try:
            self._server._finish(timeout=timeout)
        except Exception as exc:
            if exc_type is None:
                raise
            else:
                logger.warning(f"failed waiting for instance to finish: {exc}")


class ServerBase(contextlib.AbstractContextManager[ServerT, None]):
    def __init__(
        self,
        *,
        name: str,
        response_cls: Type[WorkerServerResponse | ControllerServerResponse],
        request_cls: Type[WorkerServerRequest | ControllerServerRequest],
    ) -> None:
        self._name = name
        self._response_cls: Type[WorkerServerResponse | ControllerServerResponse] = (
            response_cls
        )
        self._request_cls: Type[WorkerServerRequest | ControllerServerRequest] = (
            request_cls
        )

        self._aborted = False
        self._shutdown_started = False
        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()
        self._proc: subprocess.Popen[bytes] | None = None
        self._pipe: Tuple[TextIO, TextIO] | None = None
        self._lock: threading.Lock | None = None

    def _send(self, msg: WorkerServerRequest | ControllerServerRequest) -> None:
        logger.debug(f"{self._name}: sending server request: {msg}")
        assert not self._aborted
        assert self._lock is not None
        if not self._lock.acquire(blocking=False):
            raise Exception("server in use")
        assert self._pipe is not None
        self._pipe[1].write(msg.to_json() + "\n")
        assert self._pipe is not None
        self._pipe[1].flush()

    def _recv(
        self, timeout: float | None = None
    ) -> WorkerServerResponse | ControllerServerResponse:
        assert not self._aborted
        assert self._lock is not None
        assert self._lock.locked()
        assert self._pipe is not None
        ready, _, _ = select.select([self._pipe[0]], [], [], timeout)
        if not ready:
            assert self._proc is not None
            assert timeout is not None
            raise subprocess.TimeoutExpired(self._proc.args, timeout)
        output = ready[0].readline()
        logger.info(f"{self._name}: Got response: {output}")
        response = self._response_cls.from_json(output)
        assert self._lock is not None
        self._lock.release()
        logger.debug(f"{self._name}: received response: {response}")
        return response

    def _launch_server(
        self,
        read_fd: int,
        write_fd: int,
    ) -> subprocess.Popen[bytes]:
        raise NotImplementedError()

    def __enter__(self) -> ServerT:
        assert self._proc is None, "already running"
        logger.debug(f"{self._name}: launching worker server")
        self._lock = threading.Lock()
        send = os.pipe2(0)
        recv = os.pipe2(0)
        self._proc = self._contexts.enter_context(
            self._launch_server(
                read_fd=send[0],
                write_fd=recv[1],
            ),
        )
        self._pipe = (
            self._contexts.enter_context(os.fdopen(recv[0], "r")),
            self._contexts.enter_context(os.fdopen(send[1], "w")),
        )
        os.close(send[0])
        os.close(recv[1])
        # pyre-ignore: Incompatible return type [7]
        return self

    def initiate_shutdown(self) -> None:
        if not self._shutdown_started and not self._aborted:
            assert self._lock is not None
            assert not self._lock.locked()
            self._shutdown_started = True
            self._send(self._request_cls.Exit())
            assert self._pipe is not None
            self._pipe[1].close()

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if exc_type is not None or self._aborted:
            assert self._proc is not None
            self._proc.kill()
        else:
            # attempt a clean shutdown
            self.initiate_shutdown()
            assert self._proc is not None
            assert self._proc.wait(timeout=5) == 0
        self._contexts.__exit__(exc_type, exc_val, exc_tb)

    def _finish(self, timeout: float | None = None) -> None:
        try:
            response = self._recv(timeout=timeout)
            assert isinstance(response, self._response_cls.Finished), str(response)
            # pyre-ignore: Undefined attribute [16]
            assert response.error is None, response.error
        except:
            self._aborted = True
            raise

    def _launch_instance(
        self,
        *,
        msg: WorkerServerRequest | ControllerServerRequest,
    ) -> ServerInstance:
        self._send(msg)
        return ServerInstance(server=self)
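
ServerBase drives each long-lived server process over a pair of anonymous pipes: one JSON request per line out, one JSON response per line back, with the child inheriting its two fds via pass_fds. A self-contained sketch of that newline-delimited framing (a toy echo child, not Monarch's actual message types):

    import json
    import os
    import subprocess
    import sys

    send = os.pipe2(0)  # parent writes requests -> child reads
    recv = os.pipe2(0)  # child writes responses -> parent reads
    child = subprocess.Popen(
        [sys.executable, "-c",
         "import json,os;"
         f"req=os.fdopen({send[0]});resp=os.fdopen({recv[1]},'w');"
         "line=req.readline();"
         "resp.write(json.dumps({'echo': json.loads(line)})+'\\n');resp.flush()"],
        pass_fds=(send[0], recv[1]),
    )
    os.close(send[0])  # the parent drops the child's ends,
    os.close(recv[1])  # just like ServerBase.__enter__ does
    writer = os.fdopen(send[1], "w")
    reader = os.fdopen(recv[0], "r")
    writer.write(json.dumps({"cmd": "ping"}) + "\n")
    writer.flush()
    print(reader.readline())  # {"echo": {"cmd": "ping"}}
    child.wait()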


class ISystemFactory:
    def launch(
        self,
        *,
        bootstrap_addr: str,
        supervision_params: SupervisionParams,
    ) -> ServerInstance | subprocess.Popen[bytes]:
        raise NotImplementedError()


class IControllerFactory:
    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        controller_id: ActorId,
        num_worker_procs: int,
        gpus_per_host: int,
        supervision_params: SupervisionParams,
        controller_params: ControllerParams,
        labels: Dict[str, str],
    ) -> subprocess.Popen[bytes] | ServerInstance:
        raise NotImplementedError()


class ControllerFactoryBase:
    def __init__(
        self,
        *,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        self.logging_location = logging_location
        self.logging_dir = logging_dir

        self.controller_main: Path
        self.controller_env: dict[str, str]
        self.controller_main, self.controller_env = get_controller_main()


class SystemFactory(ControllerFactoryBase, ISystemFactory):
    def launch(
        self,
        *,
        bootstrap_addr: str,
        supervision_params: SupervisionParams,
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            "system",
            self.logging_location,
        )
        return subprocess.Popen(
            [
                self.controller_main,
                "system",
                "--system-addr",
                bootstrap_addr,
                "--supervision-update-timeout-in-sec",
                str(supervision_params.update_timeout_in_sec),
            ],
            stdout=stdout_location,
            stderr=stderr_location,
            stdin=subprocess.DEVNULL,
            env=_PROC_ENV | self.controller_env,
        )


class ControllerFactory(ControllerFactoryBase, IControllerFactory):
    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        controller_id: ActorId,
        num_worker_procs: int,
        gpus_per_host: int,
        supervision_params: SupervisionParams,
        controller_params: ControllerParams,
        labels: Dict[str, str],
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            controller_id.world_name,
            self.logging_location,
        )
        command = [
            self.controller_main,
            "controller",
            "--worker-world",
            worker_world,
            "--system-addr",
            bootstrap_addr,
            "--controller-actor-id",
            str(controller_id),
            "--world-size",
            str(num_worker_procs),
            "--num-procs-per-host",
            str(gpus_per_host),
            "--supervision-query-interval-in-sec",
            str(supervision_params.query_interval_in_sec),
            "--supervision-update-interval-in-sec",
            str(supervision_params.update_interval_in_sec),
            "--worker-progress-check-interval-in-sec",
            str(controller_params.worker_progress_check_interval_in_sec),
            "--operation-timeout-in-sec",
            str(controller_params.operation_timeout_in_sec),
            "--operations-per-worker-progress-request",
            str(controller_params.operations_per_worker_progress_request),
        ]

        if controller_params.fail_on_worker_timeout:
            command.append("--fail-on-worker-timeout")

        return subprocess.Popen(
            command + _get_labels("extra-proc-labels", labels),
            stdout=stdout_location,
            stderr=stderr_location,
            stdin=subprocess.DEVNULL,
            env=_PROC_ENV | self.controller_env,
        )
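
Put together with the default parameters shown earlier, the spawned controller argv looks roughly like the sketch below (the address, world names, and world size are illustrative, not real output):

    [
        "monarch_controller", "controller",
        "--worker-world", "mesh_0_worker",
        "--system-addr", "tcp![::1]:29512",
        "--controller-actor-id", "mesh_0_controller[0].controller[0]",
        "--world-size", "8",
        "--num-procs-per-host", "8",
        "--supervision-query-interval-in-sec", "2",
        "--supervision-update-interval-in-sec", "2",
        "--worker-progress-check-interval-in-sec", "10",
        "--operation-timeout-in-sec", "120",
        "--operations-per-worker-progress-request", "100",
    ]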


class ControllerServerBase(ServerBase[ServerT]):
    def __init__(
        self,
        *,
        uuid: str,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        super().__init__(
            name=uuid,
            response_cls=ControllerServerResponse,
            request_cls=ControllerServerRequest,
        )
        self.uuid = uuid
        self.logging_location = logging_location
        self.logging_dir = logging_dir

        self.controller_main: Path
        self.controller_env: dict[str, str]
        self.controller_main, self.controller_env = get_controller_main()

    def _launch_server(
        self,
        read_fd: int,
        write_fd: int,
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            self.uuid,
            self.logging_location,
        )
        return subprocess.Popen(
            [
                self.controller_main,
                "serve",
                str(read_fd),
                str(write_fd),
            ],
            stdout=stdout_location,
            pass_fds=(read_fd, write_fd),
            stderr=stderr_location,
            stdin=subprocess.DEVNULL,
            env=_PROC_ENV | self.controller_env | dict(os.environ),
        )


class SystemServer(ControllerServerBase["SystemServer"], ISystemFactory):
    def launch(
        self,
        *,
        bootstrap_addr: str,
        supervision_params: SupervisionParams,
    ) -> ServerInstance:
        return self._launch_instance(
            msg=ControllerServerRequest.Run(
                RunCommand.System(
                    system_addr=bootstrap_addr,
                    supervision_update_timeout_in_sec=supervision_params.update_timeout_in_sec,
                    world_eviction_timeout_in_sec=10,
                ),
            ),
        )


class ControllerServer(ControllerServerBase["ControllerServer"], IControllerFactory):
    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        controller_id: ActorId,
        num_worker_procs: int,
        gpus_per_host: int,
        supervision_params: SupervisionParams,
        controller_params: ControllerParams,
        labels: Dict[str, str],
    ) -> ServerInstance:
        return self._launch_instance(
            msg=ControllerServerRequest.Run(
                RunCommand.Controller(
                    ControllerCommand(
                        worker_world=worker_world,
                        system_addr=bootstrap_addr,
                        controller_actor_id=str(controller_id),
                        world_size=num_worker_procs,
                        num_procs_per_host=gpus_per_host,
                        worker_name="worker",
                        program=None,
                        supervision_query_interval_in_sec=supervision_params.query_interval_in_sec,
                        supervision_update_interval_in_sec=supervision_params.update_interval_in_sec,
                        worker_progress_check_interval_in_sec=controller_params.worker_progress_check_interval_in_sec,
                        operation_timeout_in_sec=controller_params.operation_timeout_in_sec,
                        operations_per_worker_progress_request=controller_params.operations_per_worker_progress_request,
                        fail_on_worker_timeout=controller_params.fail_on_worker_timeout,
                        is_cpu_worker=False,
                        extra_proc_labels=list(labels.items()),
                    ),
                ),
            ),
        )


class IWorkerFactory:
    def launch(
        self,
        *,
        worker_world: str,
        worker_rank: int,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> ServerInstance | subprocess.Popen[bytes]:
        raise NotImplementedError()


class WorkerFactory(IWorkerFactory):
    def __init__(
        self,
        *,
        num_worker_procs: int,
        gpus_per_host: int,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        self.num_worker_procs = num_worker_procs
        self.gpus_per_host = gpus_per_host
        self.logging_location = logging_location
        self.logging_dir = logging_dir

    def launch(
        self,
        *,
        worker_world: str,
        worker_rank: int,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            f"{worker_world}_{worker_rank}",
            self.logging_location,
        )
        return _start_worker_cmd(
            world_uuid=worker_world,
            worker_rank=worker_rank,
            gpus_per_host=self.gpus_per_host,
            num_worker_procs=self.num_worker_procs,
            args=[
                "worker",
                "--world-id",
                worker_world,
                "--proc-id",
                f"{worker_world}[{worker_rank}]",
                "--bootstrap-addr",
                bootstrap_addr,
            ]
            + _get_labels("extra-proc-labels", labels),
            stdout=stdout_location,
            stderr=stderr_location,
        )
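
Outside a PAR, _get_worker_exec_info (defined near the end of this file) resolves the worker executable to the current interpreter, so the worker argv for, say, rank 3 looks roughly like this sketch (world name and address are illustrative):

    [
        sys.executable, "-m", "monarch.tensor_worker_main",
        "worker",
        "--world-id", "mesh_0_worker",
        "--proc-id", "mesh_0_worker[3]",
        "--bootstrap-addr", "tcp![::1]:29512",
    ]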


class WorkerServer(ServerBase["WorkerServer"]):
    def __init__(
        self,
        *,
        uuid: str,
        num_worker_procs: int,
        gpus_per_host: int,
        world_rank: int,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        super().__init__(
            name=uuid,
            response_cls=WorkerServerResponse,
            request_cls=WorkerServerRequest,
        )
        self.uuid = uuid
        self.num_worker_procs = num_worker_procs
        self.gpus_per_host = gpus_per_host
        self.world_rank = world_rank
        self.logging_location = logging_location
        self.logging_dir = logging_dir

    def _launch_server(
        self,
        read_fd: int,
        write_fd: int,
    ) -> subprocess.Popen[bytes]:
        stdout_location, stderr_location = _create_logging_locations(
            self.logging_dir,
            f"{self.uuid}_{self.world_rank}",
            self.logging_location,
        )
        return _start_worker_cmd(
            world_uuid=self.uuid,
            worker_rank=self.world_rank,
            gpus_per_host=self.gpus_per_host,
            num_worker_procs=self.num_worker_procs,
            args=["worker-server", str(read_fd), str(write_fd)],
            pass_fds=(read_fd, write_fd),
            stdin=subprocess.PIPE,
            stdout=stdout_location,
            stderr=stderr_location,
        )

    def launch(
        self,
        *,
        worker_world: str,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> ServerInstance:
        return self._launch_instance(
            msg=WorkerServerRequest.Run(
                world_id=worker_world,
                proc_id=f"{worker_world}[{self.world_rank}]",
                bootstrap_addr=bootstrap_addr,
                labels=list(labels.items()),
            )
        )


class WorkerServers(IWorkerFactory):
    def __init__(
        self,
        *,
        workers: dict[int, WorkerServer],
    ) -> None:
        self._workers = workers
        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()

    @staticmethod
    def create(
        uuid: str,
        num_worker_procs: int,
        gpus_per_host: int,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> "WorkerServers":
        return WorkerServers(
            workers={
                world_rank: WorkerServer(
                    uuid=uuid,
                    num_worker_procs=num_worker_procs,
                    gpus_per_host=gpus_per_host,
                    world_rank=world_rank,
                    logging_location=logging_location,
                    logging_dir=logging_dir,
                )
                for world_rank in range(num_worker_procs)
            },
        )

    def initiate_shutdown(self) -> None:
        for worker in self._workers.values():
            worker.initiate_shutdown()

    def __enter__(self) -> "WorkerServers":
        for worker in self._workers.values():
            self._contexts.enter_context(worker)
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.initiate_shutdown()
        self._contexts.__exit__(exc_type, exc_val, exc_tb)

    def launch(
        self,
        *,
        worker_world: str,
        worker_rank: int,
        bootstrap_addr: str,
        labels: Dict[str, str],
    ) -> ServerInstance | subprocess.Popen[bytes]:
        return self._workers[worker_rank].launch(
            worker_world=worker_world,
            bootstrap_addr=bootstrap_addr,
            labels=labels,
        )


class ProcessCache:
    def __init__(
        self,
        *,
        logging_location: LoggingLocation,
        logging_dir: str,
    ) -> None:
        self.logging_location: LoggingLocation = logging_location
        self.logging_dir: str = logging_dir

        self._system_cache: SystemServer | None = None
        self._controller_cache: ControllerServer | None = None
        self._worker_cache: dict[Tuple[int, int], WorkerServers] = {}
        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()

    def __enter__(self) -> "ProcessCache":
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        if self._system_cache is not None:
            self._system_cache.initiate_shutdown()
        if self._controller_cache is not None:
            self._controller_cache.initiate_shutdown()
        for workers in self._worker_cache.values():
            workers.initiate_shutdown()
        self._contexts.__exit__(exc_type, exc_val, exc_tb)

    def get_system_server(self) -> SystemServer:
        if self._system_cache is None:
            system = SystemServer(
                uuid="cached_system",
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
            self._system_cache = self._contexts.enter_context(system)
        assert self._system_cache is not None
        return self._system_cache

    def get_controller_server(self) -> ControllerServer:
        if self._controller_cache is None:
            controller = ControllerServer(
                uuid="cached_controller",
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
            self._controller_cache = self._contexts.enter_context(controller)
        assert self._controller_cache is not None
        return self._controller_cache

    def get_worker_servers(
        self,
        *,
        num_worker_procs: int,
        gpus_per_host: int,
    ) -> WorkerServers:
        key = (num_worker_procs, gpus_per_host)
        workers = self._worker_cache.get(key)
        if workers is None:
            workers = WorkerServers.create(
                uuid=f"cached_workers_{num_worker_procs}_{gpus_per_host}",
                num_worker_procs=num_worker_procs,
                gpus_per_host=gpus_per_host,
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
            self._worker_cache[key] = self._contexts.enter_context(workers)
        return workers
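
A sketch of how a test suite might keep one set of server processes alive across many mesh setups; the usage is inferred from the API above, and the shape values are illustrative:

    with ProcessCache(
        logging_location=LoggingLocation.FILE,
        logging_dir=tempfile.mkdtemp(prefix="cache_"),
    ) as cache:
        servers = cache.get_worker_servers(num_worker_procs=2, gpus_per_host=2)
        # The same WorkerServers object is reused on the next request with this shape.
        assert cache.get_worker_servers(num_worker_procs=2, gpus_per_host=2) is servers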


class Bootstrap:
    def __init__(
        self,
        *,
        meshes: int,
        hosts_per_mesh: int,
        gpus_per_host: int,
        worker_factory: IWorkerFactory | None = None,
        controller_factory: IControllerFactory | None = None,
        system_factory: ISystemFactory | None = None,
        socket_type: SocketType,
        logging_location: LoggingLocation,
        supervision_params: SupervisionParams | None,
        controller_params: ControllerParams | None,
        auto_epoch: bool,
        controller_labels: Dict[str, str] | None = None,
        worker_labels: Dict[str, str] | None = None,
    ) -> None:
        if supervision_params is None:
            supervision_params = SupervisionParams(
                update_timeout_in_sec=20,
                query_interval_in_sec=2,
                update_interval_in_sec=2,
            )
        self.supervision_params: SupervisionParams = supervision_params

        if controller_params is None:
            controller_params = ControllerParams(
                worker_progress_check_interval_in_sec=10,
                operation_timeout_in_sec=120,
                operations_per_worker_progress_request=100,
                fail_on_worker_timeout=False,
            )
        self.controller_params: ControllerParams = controller_params

        self.epoch: int | None = 0 if auto_epoch else None

        # hyperactor_telemetry will take the execution id and use it across all processes
        execution_id = "rust_local_" + uuid.uuid4().hex
        os.environ["HYPERACTOR_EXECUTION_ID"] = execution_id

        # Create a temporary directory for logging.
        self.logging_dir: str = (
            tempfile.mkdtemp(prefix="rust_local_mesh_")
            if logging_location == LoggingLocation.FILE
            else "N/A"
        )
        logger.info(
            f"Creating Rust local mesh with {meshes} meshes X {hosts_per_mesh} hosts X {gpus_per_host} gpus.\n"
            f"Logging directory: \033[92;1m{self.logging_dir}\033[0m\n"
            f"Execution id: {execution_id}"
        )
        self.logging_location: LoggingLocation = logging_location

        if controller_factory is None:
            controller_factory = ControllerFactory(
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
        self.controller_factory: IControllerFactory = controller_factory

        if system_factory is None:
            system_factory = SystemFactory(
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
        self.system_factory: ISystemFactory = system_factory

        if worker_factory is None:
            worker_factory = WorkerFactory(
                num_worker_procs=hosts_per_mesh * gpus_per_host,
                gpus_per_host=gpus_per_host,
                logging_location=self.logging_location,
                logging_dir=self.logging_dir,
            )
        self.worker_factory: IWorkerFactory = worker_factory

        # Do a fake call to instantiate the ThreadPoolExecutor so we don't block the GIL later.
        fake_call(lambda: 0)

        self.bootstrap_addr: str
        if socket_type == SocketType.TCP:
            with socket.socket() as sock:
                sock.bind(("", 0))
                port = sock.getsockname()[1]
                self.bootstrap_addr = f"tcp![::1]:{port}"
        elif socket_type == SocketType.UNIX:
            # Provide a random unix socket address.
            self.bootstrap_addr: str = f"unix!@{''.join(random.choice(string.ascii_lowercase) for _ in range(14))}-system"
        else:
            raise ValueError(f"Unknown socket type: {socket_type}")

        env = os.environ.copy()
        env["HYPERACTOR_MANAGED_SUBPROCESS"] = "1"
        self.env: dict[str, str] = env

        # Launch a single system globally.
        self.processes: list[subprocess.Popen[bytes] | ServerInstance] = []
        self.processes.append(self._launch_system())

        self.has_shutdown: bool = False
        self.gpus_per_host: int = gpus_per_host
        self.num_worker_procs: int = hosts_per_mesh * gpus_per_host
        self.controller_ids: list[ActorId] = []
        self.mesh_worlds: dict[
            MeshWorld, list[subprocess.Popen[bytes] | ServerInstance]
        ] = {}

        # Create the meshes, each of which contains a single controller and
        # multiple workers. All of them connect to the same system.
        pids: dict[str, list[int]] = {}
        for i in range(meshes):
            mesh_name: str = f"mesh_{i}"
            controller_world: str = f"{mesh_name}_controller"
            worker_world: str = f"{mesh_name}_worker"
            controller_id: ActorId = ActorId(
                world_name=controller_world,
                rank=0,
                actor_name="controller",
            )
            self.mesh_worlds[(worker_world, controller_id)] = []
            self.controller_ids.append(controller_id)

            processes: list[subprocess.Popen[bytes] | ServerInstance] = (
                self.launch_mesh(
                    controller_id,
                    worker_world,
                    controller_labels=controller_labels,
                    worker_labels=worker_labels,
                )
            )

            self.processes.extend(processes)
            pids[mesh_name] = [p.pid for p in processes]

        log_message = (
            f"All processes started successfully:\n system: {self.processes[0].pid}\n"
        )
        for mesh, procs in pids.items():
            log_message += f"{mesh}: controller: {procs[0]}, "
            worker_messages = []
            for i in range(1, len(procs)):
                worker_messages.append(f"{i-1}: {procs[i]}")
            log_message += "workers: " + ", ".join(worker_messages)
            log_message += "\n"

        self._contexts: contextlib.ExitStack[None] = contextlib.ExitStack()

        logger.info(log_message)

    def _launch_system(
        self,
    ) -> ServerInstance | subprocess.Popen[bytes]:
        logger.info("launching system")
        try:
            return self.system_factory.launch(
                bootstrap_addr=self.bootstrap_addr,
                supervision_params=self.supervision_params,
            )
        except Exception as e:
            logger.error(f"Failed to start system process: {e}")
            raise e

    def _launch_controller(
        self,
        controller_id: ActorId,
        worker_world: str,
        epoch: str | None = None,
        labels: Dict[str, str] | None = None,
    ) -> subprocess.Popen[bytes] | ServerInstance:
        logger.info("launching controller")
        try:
            return self.controller_factory.launch(
                bootstrap_addr=self.bootstrap_addr,
                worker_world=worker_world
                if epoch is None
                else f"{worker_world}_{epoch}",
                controller_id=ActorId.from_string(
                    (
                        f"{controller_id.world_name + '_' + epoch if epoch else controller_id.world_name}"
                        f"[{controller_id.rank}]."
                        f"{controller_id.actor_name}[{controller_id.pid}]"
                    )
                ),
                num_worker_procs=self.num_worker_procs,
                gpus_per_host=self.gpus_per_host,
                supervision_params=self.supervision_params,
                controller_params=self.controller_params,
                labels={} if labels is None else labels,
            )
        except Exception as e:
            logger.error(f"Failed to start controller process: {e}")
            raise e

    def _launch_worker(
        self,
        worker_world: str,
        worker_rank: int,
        epoch: str | None = None,
        labels: Dict[str, str] | None = None,
    ) -> subprocess.Popen[bytes] | ServerInstance:
        logger.info("launching worker")
        try:
            return self.worker_factory.launch(
                worker_world=worker_world
                if epoch is None
                else f"{worker_world}_{epoch}",
                worker_rank=worker_rank,
                bootstrap_addr=self.bootstrap_addr,
                labels={} if labels is None else labels,
            )
        except Exception as e:
            logger.error(f"Failed to start worker process {worker_rank}: {e}")
            raise e

    def get_mesh_worlds(self) -> list[MeshWorld]:
        return list(self.mesh_worlds.keys())

    def kill_mesh(self, mesh_world: MeshWorld) -> None:
        logger.info(f"Killing mesh {mesh_world}")
        procs = self.mesh_worlds[mesh_world]
        procs[-1].kill()

    def spawn_mesh(self, mesh_world: MeshWorld) -> None:
        self.launch_mesh(mesh_world[1], mesh_world[0])

    def launch_mesh(
        self,
        controller_id: ActorId,
        worker_world: str,
        controller_labels: Dict[str, str] | None = None,
        worker_labels: Dict[str, str] | None = None,
    ) -> list[subprocess.Popen[bytes] | ServerInstance]:
        """
        Create a single controller and multiple workers for a mesh.
        The first process returned is the controller; the remaining
        ones are workers.
        """
        logger.info(
            f"Launching mesh {worker_world} with controller {controller_id} epoch {self.epoch}"
        )
        epoch: str | None = None
        if self.epoch is not None:
            epoch = f"epoch_{self.epoch}"
            self.epoch += 1

        processes: list[subprocess.Popen[bytes] | ServerInstance] = []
        controller_process = self._launch_controller(
            controller_id,
            worker_world,
            epoch,
            controller_labels,
        )
        processes.append(controller_process)

        for i in range(self.num_worker_procs):
            worker_process = self._launch_worker(worker_world, i, epoch, worker_labels)
            processes.append(worker_process)
        self.mesh_worlds[(worker_world, controller_id)] = processes
        return processes

    def __enter__(self) -> "Bootstrap":
        for process in self.processes:
            self._contexts.enter_context(process)
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        for process in self.processes:
            process.terminate()
        self._contexts.__exit__(exc_type, exc_val, exc_tb)
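
When auto_epoch is enabled, every launch_mesh call stamps the worker and controller world names with a fresh epoch suffix, so a respawned mesh never collides with the worlds of the one it replaces. A small illustration of the naming, following the f-strings in launch_mesh and _launch_controller above (the trailing actor pid is illustrative):

    # auto_epoch=True
    # First launch:  worker world "mesh_0_worker_epoch_0",
    #                controller   "mesh_0_controller_epoch_0[0].controller[0]"
    # After respawn: worker world "mesh_0_worker_epoch_1",
    #                controller   "mesh_0_controller_epoch_1[0].controller[0]"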


def _local_device_count() -> int:
    dev_path = Path("/dev")
    pattern = re.compile(r"nvidia\d+$")
    nvidia_devices = [dev for dev in dev_path.iterdir() if pattern.match(dev.name)]
    return len(nvidia_devices)


def _get_worker_exec_info() -> tuple[list[str], dict[str, str]]:
    if IN_PAR:
        cmd = [sys.argv[0]]
        env = {
            "PAR_MAIN_OVERRIDE": _MONARCH_TENSOR_WORKER_MAIN,
        }
    else:
        cmd = [sys.executable, "-m", _MONARCH_TENSOR_WORKER_MAIN]
        env = {}

    env["MONARCH_TENSOR_WORKER_MAIN"] = _MONARCH_TENSOR_WORKER_MAIN
    env["MONARCH_TENSOR_WORKER_EXE"] = cmd[0]
    return cmd, env


@contextlib.contextmanager
def local_mesh(
    *,
    hosts: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
) -> Generator[DeviceMesh, None, None]:
    """
    Creates a single local device mesh with the given number of hosts and GPUs per host.

    Args:
        hosts : number of hosts, primarily used for simulating multiple machines locally.
            Default: 1
        gpus_per_host : number of gpus per host.
            Default: the number of GPUs this machine has.
        socket_type : socket type to use for communication between processes.
            Default: TCP.

    Example::
        with local_mesh() as mesh, mesh.activate():
            x = torch.rand(3, 4)
            local_tensor = fetch_shard(x).result()
    """
    with local_meshes(
        meshes=1,
        hosts_per_mesh=hosts,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    ) as dms:
        assert len(dms) == 1
        yield dms[0]
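
The docstring example expanded into a self-contained sketch; the import paths are assumed from this wheel's layout (monarch/fetch.py and monarch/rust_local_mesh.py), and the 1x2 shape is illustrative:

    import torch
    from monarch.fetch import fetch_shard
    from monarch.rust_local_mesh import local_mesh

    # Assumes a machine with at least 2 GPUs.
    with local_mesh(hosts=1, gpus_per_host=2) as mesh:
        with mesh.activate():
            x = torch.rand(3, 4)                   # created on the mesh's workers
            local_tensor = fetch_shard(x).result()  # fetch one shard to the client
            print(local_tensor)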


@contextlib.contextmanager
def local_meshes(
    *,
    meshes: int = 1,
    hosts_per_mesh: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
) -> Generator[list[DeviceMesh], None, None]:
    """
    Creates multiple local device meshes.

    Args:
        meshes : number of global meshes to create.
            Default: 1
        hosts_per_mesh : number of hosts per mesh, primarily used for simulating multiple machines locally.
            Default: 1
        gpus_per_host : number of gpus per host.
            Default: the number of GPUs this machine has.
        socket_type : socket type to use for communication between processes.
            Default: TCP.
    """
    (dms, bootstrap) = local_meshes_and_bootstraps(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    )
    with bootstrap:
        maybe_error = None
        try:
            yield dms
        except Exception as e:
            maybe_error = e
            raise
        finally:
            for dm in dms:
                dm.exit(maybe_error)


def local_meshes_and_bootstraps(
    *,
    meshes: int = 1,
    hosts_per_mesh: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    auto_epoch: bool = False,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
) -> tuple[list[DeviceMesh], Bootstrap]:
    """
    Same as local_meshes, but also returns the bootstrap object. This is
    useful in tests where we want to manipulate the bootstrap object.
    """

    if gpus_per_host is None:
        gpus_per_host = _local_device_count()
    assert (
        0 < gpus_per_host <= 8
    ), "Number of GPUs must be greater than 0 and at most 8."
    bootstrap: Bootstrap = Bootstrap(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        auto_epoch=auto_epoch,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    )

    def create_exit(
        dm: DeviceMesh, bootstrap: Bootstrap
    ) -> Callable[[Optional[RemoteException | DeviceException | Exception]], None]:
        def exit(
            error: Optional[RemoteException | DeviceException | Exception] = None,
        ) -> None:
            # We only support a single client proc.
            if not bootstrap.has_shutdown:
                dm.client.shutdown(True, error)
                bootstrap.has_shutdown = True

            # We do not need to shut down the bootstrap and clean up its
            # processes, as they will be cleaned up with the parent.

        return exit

    dms = rust_backend_meshes(
        system_addr=bootstrap.bootstrap_addr,
        hosts=hosts_per_mesh,
        gpus=gpus_per_host,
        requested_meshes=meshes,
    )

    for dm in dms:
        dm.exit = create_exit(dm, bootstrap)

    return (dms, bootstrap)
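
A sketch of the test pattern this function exists for, using only the Bootstrap methods defined above (mesh shapes are illustrative, and the actual assertions are left abstract):

    dms, bootstrap = local_meshes_and_bootstraps(meshes=2, auto_epoch=True)
    with bootstrap:
        try:
            healthy, doomed = dms
            # Inject a failure: kill the last process launched for mesh 1;
            # supervision should notice within update_timeout_in_sec.
            bootstrap.kill_mesh(bootstrap.get_mesh_worlds()[1])
            # ... run work on `healthy` and assert the failure is surfaced ...
        finally:
            for dm in dms:
                dm.exit()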


def local_mesh_provider(
    *,
    meshes: int = 1,
    hosts_per_mesh: int = 1,
    gpus_per_host: int | None = None,
    socket_type: SocketType = SocketType.TCP,
    logging_location: LoggingLocation = LoggingLocation.FILE,
    supervision_params: SupervisionParams | None = None,
    controller_params: ControllerParams | None = None,
    auto_epoch: bool = False,
    controller_labels: Dict[str, str] | None = None,
    worker_labels: Dict[str, str] | None = None,
    worker_factory: IWorkerFactory | None = None,
    controller_factory: IControllerFactory | None = None,
    system_factory: ISystemFactory | None = None,
    # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type.
) -> tuple[PoolDeviceMeshProvider, Bootstrap]:
    if gpus_per_host is None:
        gpus_per_host = _local_device_count()
    assert (
        0 < gpus_per_host <= 8
    ), "Number of GPUs must be greater than 0 and at most 8."
    bootstrap: Bootstrap = Bootstrap(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=socket_type,
        logging_location=logging_location,
        supervision_params=supervision_params,
        controller_params=controller_params,
        auto_epoch=auto_epoch,
        controller_labels=controller_labels,
        worker_labels=worker_labels,
        worker_factory=worker_factory,
        controller_factory=controller_factory,
        system_factory=system_factory,
    )

    provider = rust_backend_mesh_provider(
        system_addr=bootstrap.bootstrap_addr,
        hosts=hosts_per_mesh,
        gpus=gpus_per_host,
    )
    return (provider, bootstrap)
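
A closing sketch of how the provider variant is consumed. How meshes are acquired from PoolDeviceMeshProvider is defined in monarch.rust_backend_mesh, so that step is deliberately left abstract here:

    provider, bootstrap = local_mesh_provider(meshes=2, hosts_per_mesh=1)
    with bootstrap:
        ...  # acquire DeviceMeshes from `provider` as their worlds come up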