torchmonarch-nightly 2025.6.4__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
@@ -0,0 +1,121 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from unittest import main, TestCase
|
9
|
+
|
10
|
+
import torch
|
11
|
+
from monarch.gradient._gradient_generator import GradientGenerator
|
12
|
+
from monarch.gradient_generator import gradient_execution_order
|
13
|
+
|
14
|
+
|
15
|
+
class TestGradIter(TestCase):
    """Check that GradientGenerator reproduces ``torch.autograd.grad``."""

    def checkEqual(self, r, r2):
        """Assert two sequences of gradients match element-wise.

        ``None`` entries (inputs that received no gradient) must appear at the
        same positions in both sequences; every other pair is compared with
        ``torch.allclose``.
        """
        self.assertEqual(len(r), len(r2))
        for idx, (got, expected) in enumerate(zip(r, r2)):
            if got is None or expected is None:
                # Fail with a clear assertion instead of letting torch.allclose
                # raise a TypeError when only one side is None.
                self.assertTrue(
                    got is None and expected is None,
                    f"gradient {idx}: only one side is None ({got!r} vs {expected!r})",
                )
                continue
            self.assertTrue(
                torch.allclose(got, expected),
                f"gradient {idx} differs: {got} vs {expected}",
            )

    def test_simple(self):
        t = torch.rand(2, requires_grad=True)
        t2 = torch.rand(2, requires_grad=True)

        # Extra graph node that does not feed the outputs below.
        _ = t + t2
        a, b = torch.std_mean(t + t2)

        r2 = torch.autograd.grad([a, b], [t2, t], retain_graph=True)
        r = list(GradientGenerator([a, b], [t2, t]))
        self.checkEqual(r, r2)

    def test_pipeline_like(self):
        t = torch.rand(3, 3, requires_grad=True)

        w1 = torch.rand(3, 2, requires_grad=True)
        w2 = torch.rand(3, 2, requires_grad=True)
        w3 = torch.rand(3, 2, requires_grad=True)

        # u is requested as an input below but never reaches the loss, so its
        # gradient is None (hence allow_unused=True for the reference).
        u = torch.rand(3, 2, requires_grad=True)
        _ = u * u

        w4 = torch.rand(2, 3, requires_grad=True)
        w5 = torch.rand(2, 3, requires_grad=True)
        w6 = torch.rand(2, 3, requires_grad=True)

        from torch.nn.functional import relu

        a = relu(t @ (w1 @ w4))
        a = relu(a @ (w2 @ w5))
        a = relu(a @ (w3 @ w6))

        std, mean = torch.std_mean(a)
        loss = std + std

        cgrads = torch.autograd.grad(
            [loss], [t, w3, w6, u, w2, w5], allow_unused=True, retain_graph=True
        )
        gi = GradientGenerator([loss], [t, w3, w6, u, w2, w5])
        grads = [*gi]
        self.checkEqual(grads, cgrads)

    def test_tree(self):
        t = torch.rand(3, 3, requires_grad=True)

        t2 = t + t
        t3 = t * t
        t4 = t / t
        t5 = t - t

        t6 = t2 * t3
        t7 = t4 * t5
        t8 = t2 * t4
        t9 = t3 * t5
        t10 = t6 + t7 + t8 + t9

        t11 = t10.sum()

        # Gradients w.r.t. an intermediate (t2) and a leaf (t).
        cgrads = torch.autograd.grad([t11], [t2, t], retain_graph=True)
        gi = GradientGenerator([t11], [t2, t])
        grads = [*gi]
        self.checkEqual(grads, cgrads)

    def test_broadcast(self):
        t = torch.rand(3, 3, requires_grad=True)
        t2 = torch.rand(3, requires_grad=True)
        t3 = t2 / t2

        # t3 (shape (3,)) is broadcast against t (shape (3, 3)).
        r = (t * t3).sum()
        cgrads = torch.autograd.grad([r], [t, t2], retain_graph=True)
        gi = GradientGenerator([r], [t, t2])
        grads = [*gi]
        self.checkEqual(grads, cgrads)

    def test_grad_order(self):
        t = torch.rand(3, 3, requires_grad=True)
        w1 = torch.rand(3, 3, requires_grad=True)
        w2 = torch.rand(3, 3, requires_grad=True)
        w3 = torch.rand(3, 3, requires_grad=True)

        u = torch.rand(3, 2, requires_grad=True)
        _ = u * u
        from torch.nn.functional import relu

        a = relu(t @ w1)
        a = relu(a @ w2)
        a = relu(a @ w3)

        std, mean = torch.std_mean(a)
        loss = std + std

        # Inputs are [w2, w3, w1, a]; the expected order is a, w3, w2, w1,
        # i.e. the reverse of the order they were used in the forward pass.
        order = gradient_execution_order([loss], [w2, w3, w1, a])
        self.assertEqual(order, [3, 1, 0, 2])


if __name__ == "__main__":
    main()
|
tests/test_mock_cuda.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
from unittest import main, TestCase
|
9
|
+
|
10
|
+
import pytest
|
11
|
+
import torch
|
12
|
+
import monarch.common.mock_cuda # usort: skip
|
13
|
+
|
14
|
+
|
15
|
+
def simple_forward_backward(
    device: str,
) -> "tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]":
    """Run one seeded forward/backward pass of a tiny Linear+ReLU model.

    Args:
        device: Device string the model and data are moved to ("cpu", "cuda").

    Returns:
        ``(y, weight_grad, bias_grad)``: the model output and the gradients that
        ``backward()`` accumulated on the ``Linear`` layer's weight and bias.
    """
    torch.manual_seed(123)
    m = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU()).to(device)
    x = torch.rand(10, 3).to(device)
    y = m(x)
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(y, torch.randint(3, (10,)).to(device))
    # Under the hood, enabling/disabling CUDA mocking is done with a thread-local
    # flag. By default, backward() executes ops on a different thread than the one
    # we enabled mocking on, which would lead to an invalid memory access. So we need
    # to disable multithreading for backward.
    with torch.autograd.set_multithreading_enabled(False):
        loss.backward()
    return y, m[0].weight.grad, m[0].bias.grad
|
30
|
+
|
31
|
+
|
32
|
+
# Mock cuda depends on initialization load order.
# For OSS, run this test separately until it can be run in a subprocess.
@pytest.mark.oss_skip
class TestMockCuda(TestCase):
    """Tests for ``monarch.common.mock_cuda.mock_cuda_guard``."""

    def _assert_matches_cpu(self, cpu_outputs, cuda_outputs, *, expect_close: bool):
        """Compare CPU reference tensors with CUDA results pairwise.

        With ``expect_close=True`` every pair must be allclose; with
        ``expect_close=False`` every pair must differ.
        """
        for ref, got in zip(cpu_outputs, cuda_outputs):
            self.assertEqual(torch.allclose(ref, got.cpu()), expect_close)

    def test_output_is_garbage(self):
        with monarch.common.mock_cuda.mock_cuda_guard():
            x = torch.arange(9, device="cuda", dtype=torch.float32).reshape(3, 3)
            y = 2 * torch.eye(3, device="cuda")
            true_output = torch.tensor(
                [[0, 2, 4], [6, 8, 10], [12, 14, 16]], dtype=torch.float32
            )
            # With mocking on, the matmul must not produce the real result.
            self.assertFalse(torch.equal((x @ y).cpu(), true_output))

    def test_simple_forward_backward(self):
        # This test just makes sure that the forward and backward pass work
        # and don't crash.
        simple_forward_backward("cuda")

    def test_turn_mock_on_and_off(self):
        cpu_outputs = simple_forward_backward("cpu")

        # Real CUDA run matches the CPU reference.
        self._assert_matches_cpu(
            cpu_outputs, simple_forward_backward("cuda"), expect_close=True
        )

        # Mocked CUDA run produces different values.
        with monarch.common.mock_cuda.mock_cuda_guard():
            self._assert_matches_cpu(
                cpu_outputs, simple_forward_backward("cuda"), expect_close=False
            )

        # Leaving the guard restores real CUDA results.
        self._assert_matches_cpu(
            cpu_outputs, simple_forward_backward("cuda"), expect_close=True
        )


if __name__ == "__main__":
    main()
|
tests/test_pdb_actor.py
ADDED
@@ -0,0 +1,110 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import sys
|
9
|
+
import traceback
|
10
|
+
from contextlib import contextmanager
|
11
|
+
from typing import Generator
|
12
|
+
|
13
|
+
import pytest
|
14
|
+
|
15
|
+
import torch
|
16
|
+
|
17
|
+
from monarch import DeviceMesh, fetch_shard, remote, rust_local_mesh
|
18
|
+
from monarch._rust_bindings.monarch_extension.client import ( # @manual=//monarch/monarch_extension:monarch_extension
|
19
|
+
ClientActor,
|
20
|
+
DebuggerMessage as ClientDebuggerMessage,
|
21
|
+
)
|
22
|
+
|
23
|
+
from monarch._rust_bindings.monarch_extension.debugger import (
|
24
|
+
DebuggerMessage as PdbDebuggerMessage,
|
25
|
+
get_bytes_from_write_action,
|
26
|
+
)
|
27
|
+
from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
|
28
|
+
from monarch.rust_local_mesh import LoggingLocation, SocketType
|
29
|
+
from monarch_supervisor.logging import fix_exception_lines
|
30
|
+
|
31
|
+
|
32
|
+
def custom_excepthook(exc_type, exc_value, exc_traceback):
    """Print an uncaught exception to stderr after passing it through fix_exception_lines."""
    formatted = traceback.format_exception(exc_type, exc_value, exc_traceback)
    cleaned = fix_exception_lines(formatted)
    print("\n".join(cleaned), file=sys.stderr)


# Route uncaught exceptions in this test process through the hook defined here.
sys.excepthook = custom_excepthook
|
40
|
+
|
41
|
+
|
42
|
+
@contextmanager
def local_mesh(
    hosts: int = 1, gpu_per_host: int = 2, activate: bool = True
) -> Generator[DeviceMesh, None, None]:
    """Yield a rust-backed local DeviceMesh using UNIX sockets.

    Args:
        hosts: Number of hosts in the mesh.
        gpu_per_host: Number of GPUs per host (passed as ``gpus_per_host``).
        activate: When True, the mesh is activated for the duration of the
            ``with`` body.

    On a clean exit ``dm.exit()`` is called. If an exception escapes, the
    client's ``_shutdown`` flag is set and the exception is re-raised.
    """
    with rust_local_mesh.local_mesh(
        hosts=hosts,
        gpus_per_host=gpu_per_host,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.DEFAULT,
    ) as dm:
        try:
            if not activate:
                yield dm
            else:
                with dm.activate():
                    yield dm
            dm.exit()
        except Exception:
            dm.client._shutdown = True
            raise
|
62
|
+
|
63
|
+
|
64
|
+
# Remote handle to the worker-side pdb test function. The ``propagate``
# callback presumably describes the result to the client as a 1-element
# tensor — confirm against monarch.common.remote.
remote_test_pdb_actor = remote(
    "monarch.worker._testing_function.test_pdb_actor",
    propagate=lambda: torch.zeros(1),
)
|
68
|
+
|
69
|
+
|
70
|
+
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Not enough GPUs, this test requires at least 2 GPUs",
)
# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
# out is not counted as a failure, so we set a more restrictive timeout to
# ensure we see a hard failure in CI.
@pytest.mark.timeout(120)
class TestPdbActor:
    """Drive the pdb debugger message protocol between the client and a worker."""

    @staticmethod
    def _next_debugger_message(client):
        """Block for the client's next message and check it is a debugger message."""
        msg = client.get_next_message(timeout_msec=None)
        assert isinstance(msg, ClientDebuggerMessage)
        return msg

    @staticmethod
    def _reply(client, msg, action) -> None:
        """Send ``action`` to the debugger actor that sent ``msg``."""
        client.send(
            msg.debugger_actor_id,
            PdbDebuggerMessage(action=action).serialize(),
        )

    def test_pdb_actor(self):
        # local_mesh() already activates the mesh (activate=True by default),
        # so no additional dm.activate() is needed here.
        with local_mesh(1, 1) as dm:
            client = dm.client.inner._actor
            assert isinstance(client, ClientActor)
            fut = fetch_shard(remote_test_pdb_actor())

            # The worker reports it is paused; attach to it.
            msg = self._next_debugger_message(client)
            assert isinstance(msg.action, DebuggerAction.Paused)
            self._reply(client, msg, DebuggerAction.Attach())

            # The worker requests 4 bytes of input; send b"1234".
            msg = self._next_debugger_message(client)
            assert isinstance(msg.action, DebuggerAction.Read)
            assert msg.action.requested_size == 4
            self._reply(client, msg, DebuggerAction.Write(b"1234"))

            # The worker writes back b"5678"; detach before waiting on the result.
            msg = self._next_debugger_message(client)
            assert isinstance(msg.action, DebuggerAction.Write)
            assert get_bytes_from_write_action(msg.action) == b"5678"
            self._reply(client, msg, DebuggerAction.Detach())
            fut.result()
|
@@ -0,0 +1,372 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import operator
|
8
|
+
from types import ModuleType
|
9
|
+
|
10
|
+
import torch
|
11
|
+
from monarch.actor_mesh import (
|
12
|
+
Accumulator,
|
13
|
+
Actor,
|
14
|
+
current_actor_name,
|
15
|
+
current_rank,
|
16
|
+
current_size,
|
17
|
+
endpoint,
|
18
|
+
)
|
19
|
+
|
20
|
+
from monarch.proc_mesh import local_proc_mesh, proc_mesh
|
21
|
+
from monarch.rdma import RDMABuffer
|
22
|
+
|
23
|
+
|
24
|
+
class Counter(Actor):
    """Actor holding one integer that can be incremented and read back."""

    def __init__(self, v: int):
        self.v = v

    @endpoint
    async def incr(self):
        """Add one to the stored value."""
        self.v = self.v + 1

    @endpoint
    async def value(self) -> int:
        """Return the stored value."""
        return self.v
|
35
|
+
|
36
|
+
|
37
|
+
class Indirect(Actor):
    """Actor that reads a Counter's value on behalf of its caller."""

    @endpoint
    async def call_value(self, c: Counter) -> int:
        value = await c.value.choose()
        return value
|
41
|
+
|
42
|
+
|
43
|
+
class ParameterServer(Actor):
    """Actor owning a parameter tensor and a gradient buffer exposed over RDMA."""

    def __init__(self):
        self.params = torch.rand(10, 10)
        self.grad_buffer = torch.rand(10, 10)

    @endpoint
    async def grad_handle(self) -> RDMABuffer:
        """Return an RDMABuffer over the raw bytes of ``grad_buffer``."""
        # The buffer is built from a flat uint8 view of the float tensor.
        return RDMABuffer(self.grad_buffer.view(torch.uint8).flatten())

    @endpoint
    async def update(self):
        """Apply ``params += 0.01 * grad_buffer`` in place; grad_buffer is unchanged."""
        self.params += 0.01 * self.grad_buffer

    @endpoint
    async def get_grad_buffer(self) -> torch.Tensor:
        """Return ``grad_buffer`` (only used by the tests)."""
        return self.grad_buffer
|
61
|
+
|
62
|
+
|
63
|
+
async def test_choose():
    """A value read directly and one read through another actor agree."""
    proc = await local_proc_mesh(gpus=2)
    counter = await proc.spawn("counter", Counter, 3)
    indirect = await proc.spawn("indirect", Indirect)
    counter.incr.broadcast()

    direct = await counter.value.choose()
    via_indirect = await indirect.call_value.choose(counter)
    assert direct == via_indirect
|
72
|
+
|
73
|
+
|
74
|
+
async def test_stream():
    """Streaming ``value`` yields one result per actor; each is 3 + 1."""
    proc = await local_proc_mesh(gpus=2)
    counter = await proc.spawn("counter2", Counter, 3)
    counter.incr.broadcast()

    values = [x async for x in counter.value.stream()]
    assert sum(values) == 8
|
80
|
+
|
81
|
+
|
82
|
+
class ParameterClient(Actor):
    """Actor that moves bytes to and from a ParameterServer's gradient buffer."""

    def __init__(self, server, buffer):
        self.server = server
        # Store a flat byte view; download() reads into it directly.
        self.buffer = buffer.view(torch.uint8).flatten()

    @endpoint
    async def upload(self, tensor):
        """Write ``tensor`` into the server's gradient buffer via RDMA."""
        handle = await self.server.grad_handle.call_one()
        await handle.write(tensor)

    @endpoint
    async def download(self):
        """Read the server's gradient buffer into ``self.buffer`` via RDMA."""
        handle = await self.server.grad_handle.call_one()
        await handle.read_into(self.buffer)

    @endpoint
    async def get_buffer(self):
        return self.buffer
|
101
|
+
|
102
|
+
|
103
|
+
async def test_proc_mesh_rdma():
    """Round-trip data between CPU/CUDA client buffers and a server via RDMA."""

    def as_matrix(raw):
        # Reinterpret a flat uint8 buffer as the original 10x10 float32 tensor.
        return raw.view(torch.float32).view(10, 10)

    proc = await proc_mesh(gpus=1)
    server = await proc.spawn("server", ParameterServer)

    # --- CPU client ---
    client_cpu = await proc.spawn(
        "client_cpu", ParameterClient, server, torch.ones(10, 10)
    )
    assert torch.sum(as_matrix(await client_cpu.get_buffer.call_one())) == 100

    zeros = torch.zeros(10, 10)
    await client_cpu.upload.call_one(zeros.view(torch.uint8).flatten())
    await client_cpu.download.call_one()
    assert torch.sum(as_matrix(await client_cpu.get_buffer.call_one())) == 0

    # update() changes only the server's params; grad_buffer is untouched, so
    # a fresh download must still equal the server's grad_buffer.
    await server.update.call_one()
    await client_cpu.download.call_one()
    downloaded = as_matrix(await client_cpu.get_buffer.call_one())
    remote_grad = await server.get_grad_buffer.call_one()
    assert torch.allclose(downloaded, remote_grad)

    # --- CUDA client ---
    client_gpu = await proc.spawn(
        "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda")
    )
    assert torch.sum(as_matrix(await client_gpu.get_buffer.call_one())) == 100

    zeros = torch.zeros(10, 10, device="cuda")
    await client_gpu.upload.call_one(zeros.view(torch.uint8).flatten())
    await client_gpu.download.call_one()
    buffer_gpu = as_matrix(await client_gpu.get_buffer.call_one())
    assert torch.sum(buffer_gpu) == 0
    assert buffer_gpu.device.type == "cuda"

    # Same check as above, with the destination buffer on the GPU.
    await server.update.call_one()
    await client_gpu.download.call_one()
    buffer_gpu = as_matrix(await client_gpu.get_buffer.call_one())
    remote_grad = await server.get_grad_buffer.call_one()
    assert torch.allclose(buffer_gpu.cpu(), remote_grad)
|
151
|
+
|
152
|
+
|
153
|
+
class To(Actor):
    """Actor that reports its own actor name."""

    @endpoint
    async def whoami(self):
        name = current_actor_name()
        return name
|
157
|
+
|
158
|
+
|
159
|
+
class From(Actor):
    """Actor that collects ``whoami`` from every actor in a ``To`` mesh."""

    @endpoint
    async def get(self, to: To):
        names = []
        async for name in to.whoami.stream():
            names.append(name)
        return names
|
163
|
+
|
164
|
+
|
165
|
+
async def test_mesh_passed_to_mesh():
    """An actor-mesh handle can be passed to, and used from, another actor mesh."""
    proc = await local_proc_mesh(gpus=2)
    f = await proc.spawn("from", From)
    t = await proc.spawn("to", To)
    # Flatten the per-From-actor name lists (renamed from ``all`` to avoid
    # shadowing the builtin).
    names = [name async for batch in f.get.stream(t) for name in batch]
    assert len(names) == 4
    assert names[0] != names[1]
|
172
|
+
|
173
|
+
|
174
|
+
async def test_mesh_passed_to_mesh_on_different_proc_mesh():
    """Same as test_mesh_passed_to_mesh, with the two meshes on separate proc meshes."""
    proc = await local_proc_mesh(gpus=2)
    proc2 = await local_proc_mesh(gpus=2)
    f = await proc.spawn("from", From)
    t = await proc2.spawn("to", To)
    # Flatten the per-From-actor name lists (renamed from ``all`` to avoid
    # shadowing the builtin).
    names = [name async for batch in f.get.stream(t) for name in batch]
    assert len(names) == 4
    assert names[0] != names[1]
|
182
|
+
|
183
|
+
|
184
|
+
async def test_actor_slicing():
    """Slicing an actor mesh selects individual actors, including across proc meshes."""
    proc = await local_proc_mesh(gpus=2)
    proc2 = await local_proc_mesh(gpus=2)

    f = await proc.spawn("from", From)
    t = await proc2.spawn("to", To)

    # Compare the actual names returned by each single-actor slice. Comparing
    # the two ValueMesh objects returned by .call() may only compare identity,
    # which would always be unequal and make the assertion vacuous.
    first = await t.slice(gpus=0).whoami.call_one()
    second = await t.slice(gpus=1).whoami.call_one()
    assert first != second

    result = [y async for x in f.get.stream(t.slice(gpus=0)) for y in x]
    assert len(result) == 2
    # Both From actors queried the same sliced To actor.
    assert result[0] == result[1]
|
197
|
+
|
198
|
+
|
199
|
+
async def test_aggregate():
    """Accumulator folds ``value`` across all actors with operator.add."""
    proc = await local_proc_mesh(gpus=2)
    counter = await proc.spawn("counter", Counter, 1)
    counter.incr.broadcast()
    accumulator = Accumulator(counter.value, 0, operator.add)
    total = await accumulator.accumulate()
    assert total == 4
|
206
|
+
|
207
|
+
|
208
|
+
class RunIt(Actor):
    """Actor that executes a caller-supplied zero-argument callable."""

    @endpoint
    async def run(self, fn):
        result = fn()
        return result
|
212
|
+
|
213
|
+
|
214
|
+
async def test_rank_size():
    """current_rank()/current_size() inside actors report the ``gpus`` coordinate."""
    proc = await local_proc_mesh(gpus=2)
    runner = await proc.spawn("runit", RunIt)
    acc = Accumulator(runner.run, 0, operator.add)

    rank_sum = await acc.accumulate(lambda: current_rank()["gpus"])
    assert rank_sum == 1
    size_sum = await acc.accumulate(lambda: current_size()["gpus"])
    assert size_sum == 4
|
222
|
+
|
223
|
+
|
224
|
+
class TrainerActor(Actor):
    """Trainer that exposes its Linear weight to a GeneratorActor over RDMA."""

    def __init__(self):
        super().__init__()
        self.trainer = torch.nn.Linear(10, 10).to("cuda")
        self.trainer.weight.data.zero_()

    @endpoint
    async def init(self, gen):
        # Keep only the generator actor at this trainer's own coordinates.
        self.gen = gen.slice(**current_rank())

    @endpoint
    async def exchange_metadata(self):
        """Wrap the weight bytes in an RDMABuffer and hand it to the generator."""
        weight_bytes = self.trainer.weight.data.view(torch.uint8).flatten()
        self.handle = RDMABuffer(weight_bytes)
        await self.gen.attach_weight_buffer.call(self.handle)

    @endpoint
    async def weights_ready(self):
        """Add 1.0 to every weight element in place."""
        self.trainer.weight.data.add_(1.0)
|
244
|
+
|
245
|
+
|
246
|
+
class GeneratorActor(Actor):
    """Generator that pulls the trainer's weights over RDMA and verifies them."""

    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(10, 10).to("cuda")
        self.step = 0

    @endpoint
    async def init(self, trainer):
        # Keep only the trainer actor at this generator's own coordinates.
        self.trainer = trainer.slice(**current_rank())

    @endpoint
    async def attach_weight_buffer(self, handle):
        self.handle = handle

    @endpoint
    async def update_weights(self):
        """Read the trainer's weights into ours and check their sum.

        When driven as in the tests (trainer weights start at 0 and get +1.0
        before each call), every one of the 100 weights equals ``step``.
        """
        self.step += 1
        target = self.generator.weight.data.view(torch.uint8).flatten()
        await self.handle.read_into(target)
        assert (
            torch.sum(self.generator.weight.data) == self.step * 100
        ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}"
|
269
|
+
|
270
|
+
|
271
|
+
async def test_gpu_trainer_generator():
    """Trainer and generator on separate proc meshes share weights via RDMA."""
    trainer_proc = await proc_mesh(gpus=1)
    gen_proc = await proc_mesh(gpus=1)
    trainer = await trainer_proc.spawn("trainer", TrainerActor)
    generator = await gen_proc.spawn("gen", GeneratorActor)

    # Introduce the two meshes to each other, then share the weight buffer.
    await generator.init.call(trainer)
    await trainer.init.call(generator)
    await trainer.exchange_metadata.call()

    for _step in range(3):
        await trainer.weights_ready.call()
        await generator.update_weights.call()
|
284
|
+
|
285
|
+
|
286
|
+
class SyncActor(Actor):
    """Actor whose endpoint is synchronous and blocks on another actor's reply."""

    @endpoint
    def sync_endpoint(self, a_counter: Counter):
        pending = a_counter.value.choose()
        return pending.get()
|
290
|
+
|
291
|
+
|
292
|
+
async def test_sync_actor():
    """A synchronous endpoint can query another actor and return the result."""
    proc = await local_proc_mesh(gpus=2)
    actor = await proc.spawn("actor", SyncActor)
    counter = await proc.spawn("counter", Counter, 5)
    result = await actor.sync_endpoint.choose(counter)
    assert result == 5
|
298
|
+
|
299
|
+
|
300
|
+
def test_gpu_trainer_generator_sync() -> None:
    """Synchronous-client variant of test_gpu_trainer_generator (``.get()`` everywhere)."""
    trainer_proc = proc_mesh(gpus=1).get()
    gen_proc = proc_mesh(gpus=1).get()
    trainer = trainer_proc.spawn("trainer", TrainerActor).get()
    generator = gen_proc.spawn("gen", GeneratorActor).get()

    # Introduce the two meshes to each other, then share the weight buffer.
    generator.init.call(trainer).get()
    trainer.init.call(generator).get()
    trainer.exchange_metadata.call().get()

    for _step in range(3):
        trainer.weights_ready.call().get()
        generator.update_weights.call().get()
|
313
|
+
|
314
|
+
|
315
|
+
def test_sync_actor_sync_client():
    """A synchronous client can drive a synchronous endpoint."""
    proc = local_proc_mesh(gpus=2).get()
    actor = proc.spawn("actor", SyncActor).get()
    counter = proc.spawn("counter", Counter, 5).get()
    result = actor.sync_endpoint.choose(counter).get()
    assert result == 5
|
321
|
+
|
322
|
+
|
323
|
+
def test_rank_size_sync() -> None:
    """Synchronous-client variant of test_rank_size."""
    proc = local_proc_mesh(gpus=2).get()
    runner = proc.spawn("runit", RunIt).get()

    acc = Accumulator(runner.run, 0, operator.add)
    rank_sum = acc.accumulate(lambda: current_rank()["gpus"]).get()
    assert rank_sum == 1
    size_sum = acc.accumulate(lambda: current_size()["gpus"]).get()
    assert size_sum == 4
|
330
|
+
|
331
|
+
|
332
|
+
def test_accumulate_sync() -> None:
    """Synchronous-client variant of test_aggregate."""
    proc = local_proc_mesh(gpus=2).get()
    counter = proc.spawn("counter", Counter, 1).get()
    counter.incr.broadcast()
    accumulator = Accumulator(counter.value, 0, operator.add)
    total = accumulator.accumulate().get()
    assert total == 4
|
339
|
+
|
340
|
+
|
341
|
+
class CastToCounter(Actor):
    """Actor that gathers every Counter value from a mesh into a list."""

    @endpoint
    def doit(self, c: Counter):
        value_mesh = c.value.call().get()
        return list(value_mesh)
|
345
|
+
|
346
|
+
|
347
|
+
def test_value_mesh() -> None:
    """ValueMesh results support per-coordinate item(), slice() and iteration."""
    proc = local_proc_mesh(gpus=2).get()
    counter = proc.spawn("counter", Counter, 0).get()
    # Increment only the actor at (hosts=0, gpus=1).
    counter.slice(hosts=0, gpus=1).incr.broadcast()

    values = counter.value.call().get()
    assert 0 == values.item(hosts=0, gpus=0)
    assert 1 == values.item(hosts=0, gpus=1)
    assert 1 == values.slice(hosts=0, gpus=1).item()

    # Iterating locally matches what an actor sees when it iterates the mesh.
    caster = proc.spawn("ctc", CastToCounter).get()
    assert list(values) == caster.slice(gpus=0).doit.call_one(counter).get()
|
357
|
+
|
358
|
+
|
359
|
+
def test_rust_binding_modules_correct() -> None:
    """Objects exported by the Rust bindings report their true import path.

    Recursively walks ``monarch._rust_bindings``. Every non-module attribute
    that has a ``__module__`` must have ``__name__`` equal to its attribute
    name and ``__module__`` equal to the dotted path it was found under.
    """
    import monarch._rust_bindings as bindings

    def check(module, path):
        for attr, obj in module.__dict__.items():
            if attr.startswith("__"):
                # Skip dunder entries such as __name__ / __doc__.
                continue
            if isinstance(obj, ModuleType):
                check(obj, f"{path}.{attr}")
                continue
            if not hasattr(obj, "__module__"):
                continue
            assert obj.__name__ == attr
            assert obj.__module__ == path

    check(bindings, "monarch._rust_bindings")
|