torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
monarch/random.py
ADDED
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from typing import NamedTuple, Tuple
+
+import torch
+from monarch.common.remote import remote
+from monarch.common.tensor import Tensor
+
+
+class State(NamedTuple):
+    cpu: Tensor
+    cuda: Tensor
+
+
+@remote(
+    propagate=lambda: (
+        torch.empty(5056, dtype=torch.uint8),
+        torch.empty(16, dtype=torch.uint8),
+    )
+)
+def _get_state() -> Tuple[torch.Tensor, torch.Tensor]:
+    return (torch.get_rng_state(), torch.cuda.get_rng_state())
+
+
+@remote(propagate=lambda state: None)
+def set_state(state: Tuple[Tensor, Tensor]):
+    cpu, device = state
+    torch.set_rng_state(cpu)
+    torch.cuda.set_rng_state(device)
+
+
+@remote(propagate=lambda _: None)
+def _manual_seed(seed: torch.Tensor):
+    torch.manual_seed(seed.item())
+
+
+@remote(propagate=lambda: None)
+def make_deterministic():
+    torch.use_deterministic_algorithms(True)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    # env var for deterministic CuBLAS
+    # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+
+def get_state() -> State:
+    return State(*_get_state())
+
+
+def new_state(seed: Tensor) -> State:
+    orig = get_state()
+    _manual_seed(seed)
+    mine = get_state()
+    set_state(orig)
+    return mine
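Each `@remote` function in this file executes on the workers; the `propagate` lambda describes the result the controller should simulate locally (the 5056- and 16-byte buffers mirror the sizes of the CPU and CUDA RNG state tensors). A minimal usage sketch, assuming a monarch device mesh is already active so these calls dispatch to workers (the mesh setup itself is outside this file):

    import torch
    import monarch.random as mrandom

    mrandom.make_deterministic()      # deterministic kernels + CuBLAS workspace config

    seed = torch.tensor(1234)         # the seed travels to the workers as a tensor
    forked = mrandom.new_state(seed)  # capture a seeded state without disturbing the current one

    saved = mrandom.get_state()       # State(cpu=..., cuda=...)
    mrandom.set_state(forked)
    # ... run the reproducible section ...
    mrandom.set_state(saved)          # restore the original RNG state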
monarch/rdma.py
ADDED
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import ctypes
+
+import traceback
+
+from dataclasses import dataclass
+from traceback import extract_tb, StackSummary
+from typing import cast, Dict, Optional, Tuple
+
+import torch
+
+from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
+
+from monarch.actor_mesh import (
+    _ActorMeshRefImpl,
+    Actor,
+    ActorMeshRef,
+    endpoint,
+    MonarchContext,
+)
+
+
+@dataclass
+class LocalRDMARecord:
+    data: torch.Tensor
+
+
+_local_buffers: Dict[int, "LocalRDMARecord"] = {}
+
+
+def _get_bytes(storage: torch.Tensor, offset: int, size: int) -> bytearray:
+    """Extracts a bytearray from a 1D, 1-byte-per-item tensor."""
+    if offset + size > storage.numel():
+        raise ValueError(f"Read out of range: {offset + size} > {storage.size()}")
+    addr = storage.data_ptr()
+    if storage.device.type != "cpu":
+        result = bytearray(size)
+        result_tensor = torch.frombuffer(
+            result,
+            dtype=torch.uint8,
+        )
+        source_tensor = storage[offset:]
+        result_tensor.copy_(source_tensor)
+    else:
+        ctypes_array = (ctypes.c_byte * size).from_address(addr)
+        result = bytearray(ctypes_array)
+    return result
+
+
+class RDMAManager(Actor):
+    @staticmethod
+    def on_proc(proc_id: str) -> "RDMAManager":
+        ctx = MonarchContext.get()
+        return cast(
+            RDMAManager,
+            ActorMeshRef(
+                RDMAManager,
+                _ActorMeshRefImpl.from_actor_id(
+                    ctx.mailbox,
+                    ActorId.from_string(f"{proc_id}.rdma_manager[0]"),
+                ),
+                ctx.mailbox,
+            ),
+        )
+
+    @endpoint
+    async def drop(self, addr: int) -> None:
+        if addr in _local_buffers:
+            del _local_buffers[addr]
+
+    @endpoint
+    async def fetch(self, addr: int, offset: int, nbytes: int) -> bytearray:
+        if addr not in _local_buffers:
+            raise ValueError(f"Unknown buffer {addr}")
+        storage = _local_buffers[addr].data
+        return _get_bytes(storage, offset, nbytes)
+
+    @endpoint
+    async def put(self, addr: int, offset: int, bytes: bytearray) -> None:
+        if addr not in _local_buffers:
+            raise ValueError(f"Unknown buffer {addr}")
+        storage = _local_buffers[addr].data
+        storage[offset : offset + len(bytes)] = torch.frombuffer(
+            bytes, dtype=storage.dtype
+        )
+
+
+def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None:
+    if t.ndim != 1:
+        raise ValueError(f"Tensor must be 1D, got {t.ndim}D")
+    if t.dtype != torch.uint8:
+        raise ValueError(f"Tensor must be uint8, got {t.dtype}")
+    if not t.is_contiguous():
+        raise ValueError("Tensor must be contiguous")
+
+
+class RDMABuffer:
+    def __init__(self, data: torch.Tensor) -> None:
+        """
+        RDMABuffer only supports 1D contiguous tensors that are 1 byte per item.
+
+        To create a 1-byte, 1D view, use t.view(torch.uint8).flatten().
+
+        TODO: Create TensorBuffer, which will be the main user API, supporting
+        non-contiguous, multi-byte-per-element tensors.
+        """
+        _assert_tensor_is_1d_contiguous_uint8(data)
+        assert data.storage_offset() == 0
+        storage = data.untyped_storage()
+        self.addr: int = storage.data_ptr()
+        self.begin = 0
+        self.end: int = storage.size()
+        self.proc_id: str = MonarchContext.get().proc_id
+        self.local_data: object = None
+        _local_buffers[self.addr] = LocalRDMARecord(data)
+
+    def drop(self) -> None:
+        if self.proc_id is None:
+            del _local_buffers[self.addr]
+            return
+        rdma_actor = RDMAManager.on_proc(self.proc_id)
+        # pyre-ignore[16]: Undefined attribute [16]: `Endpoint` has no attribute `cast`.
+        rdma_actor.drop.cast(self.addr)
+
+    def __getstate__(self) -> Tuple[int, int, int, Optional[str]]:
+        proc_id = self.proc_id
+        # A locally created RDMABuffer being sent remotely:
+        # record its proc_id so we know how to establish connections to it.
+        if proc_id is None:
+            proc_id = MonarchContext.get().proc_id
+        return (self.addr, self.begin, self.end, proc_id)
+
+    def __setstate__(self, state: Tuple[int, int, int, str]) -> None:
+        self.local_data = None
+        self.addr, self.begin, self.end, self.proc_id = state
+
+    async def read_into(self, dst: torch.Tensor, offset: int = 0) -> None:
+        """
+        Read data from the RDMABuffer into a destination tensor.
+
+        The destination tensor must be contiguous and 1 byte per item.
+        """
+        _assert_tensor_is_1d_contiguous_uint8(dst)
+        bytes = await RDMAManager.on_proc(self.proc_id).fetch.call_one(
+            self.addr, offset, dst.numel()
+        )
+        dst.copy_(torch.frombuffer(bytes, dtype=torch.uint8))
+
+    async def write(self, src: torch.Tensor, offset: int = 0) -> None:
+        """
+        Write data from a source tensor into the RDMABuffer.
+
+        The source tensor must be contiguous and 1 byte per item.
+        """
+        _assert_tensor_is_1d_contiguous_uint8(src)
+        bytes = _get_bytes(
+            src,
+            cast(int, src.storage_offset()),
+            src.numel(),
+        )
+        await RDMAManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes)
+
+
+class ActorMeshRefCallFailedException(Exception):
+    """
+    A deterministic problem with the user's code: for example, an OOM resulting
+    from trying to allocate too much GPU memory, or violating some invariant
+    enforced by the various APIs.
+    """
+
+    def __init__(
+        self,
+        exception: Exception,
+        message: str = "A remote service call has failed asynchronously.",
+    ) -> None:
+        self.exception = exception
+        self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__)
+        self.message = message
+
+    def __str__(self) -> str:
+        exe = str(self.exception)
+        actor_mesh_ref_tb = "".join(traceback.format_list(self.actor_mesh_ref_frames))
+        return (
+            f"{self.message}\n"
+            f"Traceback of where the service call failed (most recent call last):\n{actor_mesh_ref_tb}{type(self.exception).__name__}: {exe}"
        )
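An RDMABuffer registers its backing tensor in the owning process's `_local_buffers` table; remote peers then reach it through the per-proc `rdma_manager` actor's `fetch`/`put` endpoints. A hypothetical round-trip sketch, assuming it runs inside a monarch actor endpoint where `MonarchContext` is available (everything except the RDMABuffer API itself is invented for illustration):

    import torch
    from monarch.rdma import RDMABuffer

    async def exchange() -> None:
        src = torch.arange(16, dtype=torch.float32)
        # Buffers must be 1D contiguous uint8 views, per the constructor docstring.
        buf = RDMABuffer(src.view(torch.uint8).flatten())

        dst = torch.empty(16, dtype=torch.float32)
        dst_view = dst.view(torch.uint8).flatten()
        await buf.read_into(dst_view)   # pulls the bytes from the owning proc
        assert torch.equal(src, dst)

        buf.drop()                      # release the registration on the owner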
monarch/remote_class.py
ADDED
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import importlib
+import itertools
+from typing import Any, Dict
+
+from monarch.common import device_mesh
+from monarch.common.remote import remote
+
+
+class ControllerRemoteClass:
+    """
+    This class simplifies the creation and management of remote classes. It serves as
+    the controller side of a remote class architecture. Classes that are intended to be
+    controlled remotely should inherit from this class. The constructor of the inheriting
+    class must invoke `super().__init__()` with the path to the remote class that will be
+    used on the worker nodes. Methods that are intended for remote execution must be
+    decorated with `ControllerRemoteClass.remote_method`.
+
+    Note: This class is designed for use by the controller developer only and should
+    not be directly used in model code.
+
+    Example usage:
+
+    class ControllerMyClass(ControllerRemoteClass):
+        def __init__(self, *args, **kwargs) -> None:
+            super().__init__("my_package.my_class", *args, **kwargs)
+
+        @ControllerRemoteClass.remote_method
+        def some_method(self, *args, **kwargs) -> None:
+            # This method is intended for remote execution and does nothing locally.
+            pass
+    """
+
+    _counter = itertools.count()
+
+    def __init__(self, cls_path: str, *args, **kwargs) -> None:
+        self.ident = next(ControllerRemoteClass._counter)
+        self.cls_path = cls_path
+        self.mesh = device_mesh._active
+        _controller_remote_class_method(
+            cls_path, "remote_init", self.ident, *args, **kwargs
+        )
+
+    def __del__(self) -> None:
+        mesh = getattr(self, "mesh", None)
+        if mesh is not None and not mesh.client._shutdown:
+            with self.mesh.activate():
+                _controller_remote_class_method(
+                    self.cls_path,
+                    "remote_del",
+                    self.ident,
+                )
+
+    @staticmethod
+    def remote_method(fn):
+        def wrapper(self, *args, **kwargs) -> None:
+            _controller_remote_class_method(
+                self.cls_path, "remote_method", self.ident, fn.__name__, *args, **kwargs
+            )
+
+        return wrapper
+
+
+# Add the logic as a separate private function instead of adding it to
+# ResolvableFunctionFromPath. This keeps users from calling it directly.
+_controller_remote_class_method = remote(
+    "monarch.remote_class._remote_class_method", propagate="inspect"
+)
+
+
+def _remote_class_method(cls_path: str, method_name: str, *args, **kwargs) -> None:
+    modulename, classname = cls_path.rsplit(".", 1)
+    module = importlib.import_module(modulename)
+    cls = getattr(module, classname)
+    method = getattr(cls, method_name)
+    method(*args, **kwargs)
+
+
+class WorkerRemoteClass:
+    """
+    This class is designed to be used alongside ``ControllerRemoteClass`` and represents
+    the worker side of a remote class architecture. Instances of this class should just
+    mimic standard Python classes, with the notable exception that all methods must
+    return None -- the current RemoteClass architecture does not support methods that
+    return values.
+
+    The `ident` attribute is used for tracking object instances created via `remote_init`.
+    This tracking is necessary because the remote function would otherwise lose the
+    reference to the object.
+    """
+
+    _objects: Dict[int, Any] = {}
+
+    @classmethod
+    def remote_init(cls, ident: int, *args, **kwargs) -> None:
+        WorkerRemoteClass._objects[ident] = cls(*args, **kwargs)
+
+    @classmethod
+    def remote_del(cls, ident) -> None:
+        WorkerRemoteClass._objects.pop(ident)
+
+    @classmethod
+    def remote_method(cls, ident: int, method_name, *args, **kwargs) -> None:
+        instance = WorkerRemoteClass._objects[ident]
+        assert (
+            cls == instance.__class__
+        ), f"Mismatched class type {cls} {instance.__class__}"
+        getattr(instance, method_name)(*args, **kwargs)
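The controller-side methods never run locally: `remote_method` replaces each body with a message send keyed by `ident`, and the worker-side classmethods dispatch on that ident to the real instance. A hypothetical sketch of the pairing, following the docstring's pattern (the `Counter` class and its module path are invented for illustration):

    # my_package/counter.py -- imported on the workers
    from monarch.remote_class import WorkerRemoteClass

    class Counter(WorkerRemoteClass):
        def __init__(self, start: int = 0) -> None:
            self.value = start

        def bump(self, by: int = 1) -> None:   # must return None
            self.value += by

    # controller side -- constructed under an active device mesh
    from monarch.remote_class import ControllerRemoteClass

    class ControllerCounter(ControllerRemoteClass):
        def __init__(self, start: int = 0) -> None:
            super().__init__("my_package.counter.Counter", start)

        @ControllerRemoteClass.remote_method
        def bump(self, by: int = 1) -> None:
            pass  # body never executes locally; the call is shipped to the workers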
monarch/rust_backend_mesh.py
ADDED
@@ -0,0 +1,280 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import logging
+import time
+from logging import Logger
+from typing import Any, Callable, Optional, Protocol
+
+from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+    ClientActor,
+    SystemSnapshotFilter,
+)
+
+from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+    ActorId,
+    init_proc,
+    Proc,
+)
+from monarch.common.client import Client
+from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus
+from monarch.common.invocation import DeviceException, RemoteException
+from monarch.common.mast import MastJob
+from monarch.common.shape import NDSlice
+from monarch.controller.rust_backend.controller import RustController
+
+TORCHX_MAST_TASK_GROUP_NAME = "script"
+
+logger: Logger = logging.getLogger(__name__)
+
+# A world tuple contains a worker world name and a controller actor id.
+# The pair forms a functional world that can be used to create a device mesh.
+MeshWorld = tuple[str, ActorId]
+
+# Taken from //monarch/controller/src/bootstrap.rs
+WORLD_WORKER_LABEL = "world.monarch.meta.com/worker"
+WORLD_CONTROLLER_LABEL = "world.monarch.meta.com/controllerActorId"
+WORLD_CONTROLLER_IP = "world.monarch.meta.com/ip_addr"
+
+
+class IBootstrap(Protocol):
+    def get_mesh_worlds(self) -> list[MeshWorld]:
+        """Returns the list of mesh worlds."""
+        ...
+
+    def kill_mesh(self, mesh_world: MeshWorld) -> None:
+        """Kills a mesh in a bootstrap instance."""
+        ...
+
+    def spawn_mesh(self, mesh_world: MeshWorld) -> None:
+        """Spawns a mesh in a bootstrap instance."""
+        ...
+
+
+class IPoolDeviceMeshProvider(Protocol):
+    def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh:
+        raise NotImplementedError()
+
+
+class PoolDeviceMeshProvider:
+    """
+    Given a client actor, the device mesh provider discovers and keeps track of
+    the world status and provides a device mesh given a healthy world.
+    """
+
+    def __init__(
+        self,
+        hosts: int,
+        gpus: int,
+        proc: Proc,
+    ) -> None:
+        self._hosts = hosts
+        self._gpus = gpus
+        self._mesh_map: dict[MeshWorld, DeviceMesh | None] = {}
+        self._proc = proc
+        # The root client is not used to create device meshes.
+        # It is only used to pull the world status.
+        self._root_client: ClientActor = ClientActor(
+            proc=self._proc,
+            actor_name="root_client",  # The client name really doesn't matter
+        )
+
+    def new_mesh(self, timeout_in_sec: Optional[int] = None) -> DeviceMesh:
+        """
+        Creates a new device mesh based on the current world status.
+        If no healthy world is found, the call will block until a healthy world
+        is found or timeout_in_sec is reached. timeout_in_sec being None
+        indicates no timeout.
+        """
+
+        logger.info("Trying to allocate a new mesh in its desired world...")
+
+        def _create_exit(
+            client: Client,
+        ) -> Callable[[Optional[RemoteException | DeviceException | Exception]], None]:
+            def _exit(
+                error: Optional[RemoteException | DeviceException | Exception] = None,
+            ) -> None:
+                client.shutdown(True, error)
+
+            return _exit
+
+        def _is_world_healthy(world_status: dict[str, str], target_world: str) -> bool:
+            return (
+                target_world in world_status
+                and DeviceMeshStatus(world_status[target_world])
+                == DeviceMeshStatus.LIVE
+            )
+
+        now = time.time()
+        while timeout_in_sec is None or time.time() - now < timeout_in_sec:
+            # Pull the fresh world status
+            self._refresh_worlds()
+            world_status = self._root_client.world_status()
+            self._remove_evicted_worlds(world_status)
+
+            # Find the next available world
+            for mesh_world, mesh in self._mesh_map.items():
+                if mesh is not None:
+                    # A mesh has already been allocated to this world, skip
+                    continue
+
+                worker_world, controller_id = mesh_world
+                controller_world = controller_id.world_name
+
+                if (not _is_world_healthy(world_status, worker_world)) or (
+                    not _is_world_healthy(world_status, controller_world)
+                ):
+                    # Either the controller world or the worker world is not ready
+                    continue
+
+                # Create a new device mesh
+                backend_ctrl = RustController(
+                    proc=self._proc,
+                    client_actor=ClientActor.new_with_parent(
+                        self._proc, self._root_client.actor_id
+                    ),
+                    controller_id=controller_id,
+                    worker_world_name=worker_world,
+                )
+                client = Client(backend_ctrl, self._hosts * self._gpus, self._gpus)
+
+                # TODO: we need to consider hosts and gpus constraints as well
+                dm = DeviceMesh(
+                    client,
+                    NDSlice(
+                        offset=0,
+                        sizes=[self._hosts, self._gpus],
+                        strides=[self._gpus, 1],
+                    ),
+                    ("host", "gpu"),
+                    worker_world,
+                )
+                dm.exit = _create_exit(client)
+                self._mesh_map[mesh_world] = dm
+
+                logger.info("Mesh successfully allocated in world: %s", worker_world)
+
+                return dm
+
+            # TODO(T216841374): Change to healthy world push based checks
+            sleep_sec = 0.05
+            logger.debug(f"No healthy world found, sleeping for {sleep_sec}s...")
+            time.sleep(sleep_sec)
+
+        raise TimeoutError(f"Could not find a healthy world in {timeout_in_sec}s!")
+
+    def _refresh_worlds(self) -> None:
+        system_snapshot = self._root_client.world_state(
+            filter=SystemSnapshotFilter(world_labels={WORLD_WORKER_LABEL: "1"})
+        )
+        for world_id, world_snapshot in system_snapshot.items():
+            if WORLD_CONTROLLER_LABEL not in world_snapshot.labels:
+                continue
+            controller_actor_id = ActorId.from_string(
+                world_snapshot.labels[WORLD_CONTROLLER_LABEL]
+            )
+            world_tuple = (world_id, controller_actor_id)
+            if world_tuple not in self._mesh_map:
+                logger.debug(f"Discovered new worker world {world_id}")
+                self._mesh_map[world_tuple] = None
+
+    def _remove_evicted_worlds(self, world_status: dict[str, str]) -> None:
+        """
+        Go through the mesh map and remove the worlds that have already been
+        evicted by the system.
+        """
+        mesh_worlds_to_remove = []
+        for mesh_world, _ in self._mesh_map.items():
+            worker_world, controller_id = mesh_world
+            controller_world = controller_id.world_name
+
+            if (
+                world_status.get(worker_world) is None
+                or world_status.get(controller_world) is None
+            ):
+                logger.debug(f"Removing evicted world {mesh_world}")
+                mesh_worlds_to_remove.append(mesh_world)
+
+        for mesh_world in mesh_worlds_to_remove:
+            self._mesh_map.pop(mesh_world)
+
+
+def rust_mast_mesh(
+    job_name: str, system_port: int = 29500, **kwargs: Any
+) -> DeviceMesh:
+    job = MastJob(job_name, TORCHX_MAST_TASK_GROUP_NAME)
+    if not job.is_running():
+        job.wait_for_running(10 * 60)
+    hostnames = job.get_hostnames()
+    system_addr = f"metatls!{hostnames[0]}.facebook.com:{system_port}"
+    return rust_backend_mesh(
+        system_addr,
+        **kwargs,
+    )
+
+
+def rust_backend_mesh(
+    system_addr: str,
+    hosts: int,
+    gpus: int,
+) -> DeviceMesh:
+    dms = rust_backend_meshes(
+        system_addr,
+        hosts,
+        gpus,
+        requested_meshes=1,
+    )
+    assert len(dms) == 1
+    return dms[0]
+
+
+def rust_backend_meshes(
+    system_addr: str,
+    hosts: int,
+    gpus: int,
+    requested_meshes: int = 1,
+) -> list[DeviceMesh]:
+    """
+    Given a system address system_addr, discover the registered worlds and
+    create a device mesh per world with hosts and gpus. The call will block
+    until requested_meshes are discovered and created, or a 1200s timeout is
+    reached.
+    Args:
+        system_addr: the system address to connect to.
+        hosts: number of hosts to create the device mesh with.
+        gpus: number of gpus to create the device mesh with.
+        requested_meshes: the minimum number of meshes to create.
+    """
+    mesh_provider = rust_backend_mesh_provider(system_addr, hosts, gpus)
+    dms: list[DeviceMesh] = []
+
+    # Given a client actor and a list of world names, wait for all the worlds to be ready.
+    max_timeout_in_sec = 1200
+    start_time = time.time()
+    while True:
+        if time.time() - start_time > max_timeout_in_sec:
+            raise TimeoutError(
+                f"Timeout ({max_timeout_in_sec} sec) waiting for all worlds to be ready."
+            )
+        mesh = mesh_provider.new_mesh()
+        dms.append(mesh)
+        if len(dms) == requested_meshes:
+            return dms
+
+
+def rust_backend_mesh_provider(
+    system_addr: str,
+    hosts: int,
+    gpus: int,
+    client_proc_id: str = "client[0]",
+    # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type.
+) -> PoolDeviceMeshProvider:
+    proc: Proc = init_proc(
+        proc_id=client_proc_id,
+        bootstrap_addr=system_addr,
+        timeout=5,
+        supervision_update_interval=5,
+    )
+    return PoolDeviceMeshProvider(hosts, gpus, proc)
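A hypothetical sketch of the two entry points this file exposes. The system address and mesh sizes are invented for illustration; the functions and the `activate()` context manager are the ones defined in this package:

    from monarch.rust_backend_mesh import (
        rust_backend_mesh,
        rust_backend_mesh_provider,
    )

    # One-shot: block until a single healthy 2-host x 8-gpu world is available.
    mesh = rust_backend_mesh("metatls!devhost.example.com:29500", hosts=2, gpus=8)
    with mesh.activate():
        ...  # issue tensor operations against the mesh

    # Pooled: keep a provider around and draw fresh meshes as worlds come up.
    provider = rust_backend_mesh_provider(
        "metatls!devhost.example.com:29500", hosts=2, gpus=8
    )
    standby = provider.new_mesh(timeout_in_sec=60)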