torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
monarch/profiler.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import itertools
import os
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Dict, NamedTuple, Optional, Tuple

import torch
from monarch.common.remote import remote
from monarch.remote_class import ControllerRemoteClass, WorkerRemoteClass


class Schedule(NamedTuple):
    wait: int
    warmup: int
    active: int
    repeat: int = 0
    skip_first: int = 0


class profile:
    """
    The class wraps `torch.profiler.profile()` to allow invoking the profiler remotely.
    There are two main differences:
    1) `on_trace_ready` can only be a string, indicating the folder where the traces
       will be saved.
    2) `schedule` must be of type `monarch.profiler.Schedule`.
    """

    PATH_KEY = "on_trace_ready"
    _counter = itertools.count()

    def __init__(self, *args, **kwargs) -> None:
        assert isinstance(kwargs.get(self.PATH_KEY, None), str), (
            f"{self.PATH_KEY} must be passed and must be a string to represent the "
            "path to save the profiler."
        )
        schedule = kwargs.get("schedule", None)
        assert (
            isinstance(schedule, Schedule) or schedule is None
        ), "schedule can only be monarch.profiler.Schedule or None."
        self.id = next(self._counter)
        _profiler_controller_init(self.id, *args, **kwargs)

    def __enter__(self) -> "profile":
        _profiler_controller_enter(self.id)
        return self

    def __exit__(self, *args, **kwargs) -> None:
        _profiler_controller_exit(self.id)

    def step(self) -> None:
        _profiler_controller_step(self.id)


@dataclass
class _Profiler:
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]
    profiler: Optional[torch.profiler.profile] = None


_profilers: Dict[int, _Profiler] = {}


def _profiler_init(ident, *args, **kwargs) -> None:
    global _profilers
    assert (
        ident not in _profilers
    ), f"Initializing an already existing profiler, {ident=}"
    _profilers[ident] = _Profiler(args, kwargs)
    # It's unclear why we cannot create the profiler here. Even though
    # the thread is the same, profiler complains thread id mismatch.


def _profiler_enter(ident, *args, **kwargs) -> None:
    def on_trace_ready(prof, dir_path):
        dir_path = Path(dir_path).absolute()
        os.makedirs(dir_path, exist_ok=True)
        # This is not a synchronized call, so it is okay to call without
        # device mesh.
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        prof.export_chrome_trace(f"{dir_path}/trace_{rank}.json")

    profiler = _profilers[ident]
    profiler.kwargs[profile.PATH_KEY] = partial(
        on_trace_ready, dir_path=profiler.kwargs[profile.PATH_KEY]
    )
    schedule = profiler.kwargs.get("schedule", None)
    if schedule is not None:
        profiler.kwargs["schedule"] = torch.profiler.schedule(**schedule._asdict())
    profiler.profiler = torch.profiler.profile(*profiler.args, **profiler.kwargs)

    profiler.profiler.__enter__()


def _profiler_exit(ident, *args, **kwargs) -> None:
    profiler = _profilers[ident].profiler
    assert profiler is not None
    profiler.__exit__(None, None, None)
    _profilers.pop(ident)


def _profiler_step(ident, *args, **kwargs) -> None:
    profiler = _profilers[ident].profiler
    assert profiler is not None
    profiler.step()


_profiler_controller_init = remote(
    "monarch.profiler._profiler_init", propagate="inspect"
)

_profiler_controller_enter = remote(
    "monarch.profiler._profiler_enter", propagate="inspect"
)

_profiler_controller_exit = remote(
    "monarch.profiler._profiler_exit", propagate="inspect"
)

_profiler_controller_step = remote(
    "monarch.profiler._profiler_step", propagate="inspect"
)


class record_function(ControllerRemoteClass):
    """
    The class wraps `torch.profiler.record_function()` to allow invoking the
    record_function remotely.
    """

    def __init__(self, name: str, args: Optional[str] = None) -> None:
        super().__init__("monarch.profiler.WorkerRecordFunction", name, args)

    @ControllerRemoteClass.remote_method
    def __enter__(self) -> "record_function":
        return self

    @ControllerRemoteClass.remote_method
    def __exit__(self, *args, **kwargs) -> None:
        return


class WorkerRecordFunction(WorkerRemoteClass):
    def __init__(self, *args, **kwargs) -> None:
        self._record_function = torch.profiler.record_function(*args, **kwargs)

    def __enter__(self) -> None:
        self._record_function.__enter__()

    def __exit__(self, *args, **kwargs) -> None:
        self._record_function.__exit__(*args, **kwargs)
monarch/python_local_mesh.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import os
import subprocess
from time import sleep
from typing import Optional, TYPE_CHECKING

import monarch_supervisor
from monarch.common._device_utils import _local_device_count
from monarch.common.fake import fake_call
from monarch.common.invocation import DeviceException, RemoteException
from monarch.world_mesh import world_mesh
from monarch_supervisor import Context, HostConnected
from monarch_supervisor.python_executable import PYTHON_EXECUTABLE

if TYPE_CHECKING:
    from monarch.common.device_mesh import DeviceMesh


class PythonLocalContext:
    def __init__(self, N: int):
        # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later
        fake_call(lambda: 0)

        self.ctx = ctx = Context()
        ctx.request_hosts(N)

        # we want ctx to start its listener threads
        # before creating the hosts because
        # initialization will happen faster in this case
        sleep(0)
        supervisor_addr = f"tcp://127.0.0.1:{ctx.port}"

        env = {
            **os.environ,
            "TORCH_SUPERVISOR_HEARTBEAT_INTERVAL": str(
                monarch_supervisor.HEARTBEAT_INTERVAL
            ),
            # This is needed to avoid a hard failure in ncclx when we do not
            # have backend topology info (eg. on RE).
            "NCCL_IGNORE_TOPO_LOAD_FAILURE": "true",
        }

        # start_new_session=True, because we want the host managers to be able to kill
        # any worker processes before they exit, even if the supervisor crashes, or we ctrl-c
        # it in testing.
        self.host_managers = [
            subprocess.Popen(
                [
                    PYTHON_EXECUTABLE,
                    "-m",
                    "monarch_supervisor.host",
                    supervisor_addr,
                ],
                env=env,
                start_new_session=True,
            )
            for _ in range(N)
        ]
        connections = ctx.messagefilter(HostConnected)
        self.hosts = [connections.recv(timeout=30).sender for _ in range(N)]

    def shutdown(self):
        self.ctx.shutdown()
        for host_manager in self.host_managers:
            host_manager.wait(timeout=10)


def python_local_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> "DeviceMesh":
    """
    Creates a local device mesh with the given number of hosts and gpus per host.
    Easy way to use PythonLocalContext.

    Args:
        gpus (Optional[int]): number of gpus per host.
            Default: the number of GPUs this machine has.

        hosts (int): number of hosts, primarily used for simulating multiple machines locally.
            Default: 1

    Example::
        local_mesh = python_local_mesh(gpus=2)
        with local_mesh.activate():
            x = torch.rand(3, 4)
            local_tensor = fetch_shard(x).result()

        # Cleanly shut down the local mesh and exit.
        local_mesh.exit()
    """
    ctx = PythonLocalContext(hosts)
    if gpus is None:
        gpus = _local_device_count()
    dm = world_mesh(ctx.ctx, ctx.hosts, gpus)

    def exit(
        error: Optional[RemoteException | DeviceException | Exception] = None,
    ) -> None:
        dm.client.shutdown(True, error)
        ctx.shutdown()

    dm.exit = exit
    return dm
monarch/random.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import NamedTuple, Tuple

import torch
from monarch.common.remote import remote
from monarch.common.tensor import Tensor


class State(NamedTuple):
    cpu: Tensor
    cuda: Tensor


@remote(
    propagate=lambda: (
        torch.empty(5056, dtype=torch.uint8),
        torch.empty(16, dtype=torch.uint8),
    )
)
def _get_state() -> Tuple[torch.Tensor, torch.Tensor]:
    return (torch.get_rng_state(), torch.cuda.get_rng_state())


@remote(propagate=lambda state: None)
def set_state(state: Tuple[Tensor, Tensor]):
    cpu, device = state
    torch.set_rng_state(cpu)
    torch.cuda.set_rng_state(device)


@remote(propagate=lambda _: None)
def _manual_seed(seed: torch.Tensor):
    torch.manual_seed(seed.item())


@remote(propagate=lambda: None)
def make_deterministic():
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # env var for deterministic CuBLAS
    # https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def get_state() -> State:
    return State(*_get_state())


def new_state(seed: Tensor) -> State:
    orig = get_state()
    _manual_seed(seed)
    mine = get_state()
    set_state(orig)
    return mine
monarch/rdma.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ctypes

from dataclasses import dataclass
from typing import cast, Dict, Optional, Tuple

import torch

from monarch._rust_bindings.monarch_hyperactor.proc import ActorId

from monarch.actor_mesh import (
    _ActorMeshRefImpl,
    Actor,
    ActorMeshRef,
    endpoint,
    MonarchContext,
)


@dataclass
class LocalRDMARecord:
    data: torch.Tensor


_local_buffers: Dict[int, "LocalRDMARecord"] = {}


def _get_bytes(storage: torch.Tensor, offset: int, size: int) -> bytearray:
    """Extracts a bytearray from a 1D, 1byte per item tensor."""
    if offset + size > storage.numel():
        raise ValueError(f"Read out of range: {offset + size} > {storage.size()}")
    addr = storage.data_ptr()
    if storage.device.type != "cpu":
        result = bytearray(size)
        result_tensor = torch.frombuffer(
            result,
            dtype=torch.uint8,
        )
        source_tensor = storage[offset:]
        result_tensor.copy_(source_tensor)
    else:
        ctypes_array = (ctypes.c_byte * size).from_address(addr)
        result = bytearray(ctypes_array)
    return result


class RDMAManager(Actor):
    @staticmethod
    def on_proc(proc_id: str) -> "RDMAManager":
        ctx = MonarchContext.get()
        return cast(
            RDMAManager,
            ActorMeshRef(
                RDMAManager,
                _ActorMeshRefImpl.from_actor_id(
                    ctx.mailbox,
                    ActorId.from_string(f"{proc_id}.rdma_manager[0]"),
                ),
                ctx.mailbox,
            ),
        )

    @endpoint
    async def drop(self, addr: int) -> None:
        if addr in _local_buffers:
            del _local_buffers[addr]

    @endpoint
    async def fetch(self, addr: int, offset: int, nbytes: int) -> bytearray:
        if addr not in _local_buffers:
            raise ValueError(f"Unknown buffer {addr}")
        storage = _local_buffers[addr].data
        return _get_bytes(storage, offset, nbytes)

    @endpoint
    async def put(self, addr: int, offset: int, bytes: bytearray) -> None:
        if addr not in _local_buffers:
            raise ValueError(f"Unknown buffer {addr}")
        storage = _local_buffers[addr].data
        storage[offset : offset + len(bytes)] = torch.frombuffer(
            bytes, dtype=storage.dtype
        )


def _assert_tensor_is_1d_contiguous_uint8(t: torch.Tensor) -> None:
    if t.ndim != 1:
        raise ValueError(f"Tensor must be 1D, got {t.ndim}D")
    if t.dtype != torch.uint8:
        raise ValueError(f"Tensor must be uint8, got {t.dtype}")
    if not t.is_contiguous():
        raise ValueError("Tensor must be contiguous")


class RDMABuffer:
    def __init__(self, data: torch.Tensor) -> None:
        """
        RDMABuffer only supports 1D contiguous tensors that are 1 byte per item.

        To create a 1 byte, 1D view, use t.view(torch.uint8).flatten()

        TODO: Create TensorBuffer, which will be main user API supporting non-contiguous , multi-byte-per-elment tensors
        """
        _assert_tensor_is_1d_contiguous_uint8(data)
        assert data.storage_offset() == 0
        storage = data.untyped_storage()
        self.addr: int = storage.data_ptr()
        self.begin = 0
        self.end: int = storage.size()
        self.proc_id: str = MonarchContext.get().proc_id
        self.local_data: object = None
        _local_buffers[self.addr] = LocalRDMARecord(data)

    def drop(self) -> None:
        if self.proc_id is None:
            del _local_buffers[self.addr]
            return
        rmda_actor = RDMAManager.on_proc(self.proc_id)
        # pyre-ignore[16]: Undefined attribute [16]: `Endpoint` has no attribute `cast`.
        rmda_actor.drop.cast(self.addr)

    def __getstate__(self) -> Tuple[int, int, int, Optional[str]]:
        proc_id = self.proc_id
        # locally created RDMABuffer being set remotely,
        # record its proc_id so we know how to establish connections to it
        if proc_id is None:
            proc_id = MonarchContext.get().proc_id
        return (self.addr, self.begin, self.end, proc_id)

    def __setstate__(self, state: Tuple[int, int, int, str]) -> None:
        self.local_data = None
        self.addr, self.begin, self.end, self.proc_id = state

    async def read_into(self, dst: torch.Tensor, offset: int = 0) -> None:
        """
        Read data from the RDMABuffer into a destination tensor.

        The destination tensor must be contiguous and 1 byte per item.
        """
        _assert_tensor_is_1d_contiguous_uint8(dst)
        bytes = await RDMAManager.on_proc(self.proc_id).fetch.call_one(
            self.addr, offset, dst.numel()
        )
        dst.copy_(torch.frombuffer(bytes, dtype=torch.uint8))

    async def write(self, src: torch.Tensor, offset: int = 0) -> None:
        """
        Write data from a source tensor into the RDMABuffer.

        The source tensor must be contiguous and 1 byte per item.
        """
        _assert_tensor_is_1d_contiguous_uint8(src)
        bytes = _get_bytes(
            src,
            cast(int, src.storage_offset()),
            src.numel(),
        )
        await RDMAManager.on_proc(self.proc_id).put.call_one(self.addr, offset, bytes)
monarch/remote_class.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import importlib
import itertools
from typing import Any, Dict

from monarch.common import device_mesh
from monarch.common.remote import remote


class ControllerRemoteClass:
    """
    This class simplifies the creation and management of remote classes. It serves as
    the controller side of a remote class architecture. Classes that are intended to be
    controlled remotely should inherit from this class. The constructor of the inheriting
    class must invoke `super().__init__()` with the path to the remote class that will be
    used on the worker nodes. Methods that are intended for remote execution must be
    decorated with `ControllerRemoteClass.remote_method`.

    Note: This class is designed for use by the controller developer only and should
    not be directly used in model code.

    Example usage:

    class ControllerMyClass(ControllerRemoteClass):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__("my_package.my_class", *args, **kwargs)

        @ControllerRemoteClass.remote_method
        def some_method(self, *args, **kwargs) -> None:
            # This method is intended for remote execution and does nothing locally.
            pass
    """

    _counter = itertools.count()

    def __init__(self, cls_path: str, *args, **kwargs) -> None:
        self.ident = next(ControllerRemoteClass._counter)
        self.cls_path = cls_path
        self.mesh = device_mesh._active
        _controller_remote_class_method(
            cls_path, "remote_init", self.ident, *args, **kwargs
        )

    def __del__(self) -> None:
        mesh = getattr(self, "mesh", None)
        if mesh is not None and not mesh.client._shutdown:
            with self.mesh.activate():
                _controller_remote_class_method(
                    self.cls_path,
                    "remote_del",
                    self.ident,
                )

    @staticmethod
    def remote_method(fn):
        def wrapper(self, *args, **kwargs) -> None:
            _controller_remote_class_method(
                self.cls_path, "remote_method", self.ident, fn.__name__, *args, **kwargs
            )

        return wrapper


# Add the logic as a separate private function instead of adding ita to
# ResolvableFunctionFromPath. This avoids users to using this directly.
_controller_remote_class_method = remote(
    "monarch.remote_class._remote_class_method", propagate="inspect"
)


def _remote_class_method(cls_path: str, method_name: str, *args, **kwargs) -> None:
    modulename, classname = cls_path.rsplit(".", 1)
    module = importlib.import_module(modulename)
    cls = getattr(module, classname)
    method = getattr(cls, method_name)
    method(*args, **kwargs)


class WorkerRemoteClass:
    """
    This class is designed to be used alongside ``ControllerRemoteClass`` and represents
    the worker-side of a remote class architecture. Instances of this class should just
    mimic standard Python classes, with the notable exception that all methods must
    return None -- the current RemoteClass architecture does not support methods that
    return values.

    The `ident` attribute is used for tracking object instances created via `remote_init`.
    This tracking is necessary because the remote function would otherwise lose the
    reference to the object.
    """

    _objects: Dict[int, Any] = {}

    @classmethod
    def remote_init(cls, ident: int, *args, **kwargs) -> None:
        WorkerRemoteClass._objects[ident] = cls(*args, **kwargs)

    @classmethod
    def remote_del(cls, ident) -> None:
        WorkerRemoteClass._objects.pop(ident)

    @classmethod
    def remote_method(cls, ident: int, method_name, *args, **kwargs) -> None:
        instance = WorkerRemoteClass._objects[ident]
        assert (
            cls == instance.__class__
        ), "Mismatched class type {cls} {instance.__class__}"
        getattr(instance, method_name)(*args, **kwargs)