torchmonarch-nightly 2025.6.4 cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +74 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +198 -0
  10. monarch/actor_mesh.py +692 -0
  11. monarch/allocator.py +62 -0
  12. monarch/bootstrap_main.py +75 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +69 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/common/_C.pyi +11 -0
  18. monarch/common/_C.so +0 -0
  19. monarch/common/__init__.py +0 -0
  20. monarch/common/_coalescing.py +308 -0
  21. monarch/common/_device_utils.py +18 -0
  22. monarch/common/_tensor_to_table.py +172 -0
  23. monarch/common/base_tensor.py +28 -0
  24. monarch/common/borrows.py +143 -0
  25. monarch/common/client.py +646 -0
  26. monarch/common/constants.py +10 -0
  27. monarch/common/context_manager.py +40 -0
  28. monarch/common/controller_api.py +104 -0
  29. monarch/common/device_mesh.py +443 -0
  30. monarch/common/fake.py +55 -0
  31. monarch/common/function.py +160 -0
  32. monarch/common/function_caching.py +164 -0
  33. monarch/common/future.py +168 -0
  34. monarch/common/invocation.py +125 -0
  35. monarch/common/mast.py +221 -0
  36. monarch/common/messages.py +572 -0
  37. monarch/common/mock_cuda.py +41 -0
  38. monarch/common/opaque_ref.py +98 -0
  39. monarch/common/pickle_flatten.py +48 -0
  40. monarch/common/pipe.py +152 -0
  41. monarch/common/process_group.py +55 -0
  42. monarch/common/recording.py +127 -0
  43. monarch/common/reference.py +33 -0
  44. monarch/common/remote.py +304 -0
  45. monarch/common/selection.py +9 -0
  46. monarch/common/shape.py +204 -0
  47. monarch/common/stream.py +111 -0
  48. monarch/common/tensor.py +793 -0
  49. monarch/common/tensor_factory.py +31 -0
  50. monarch/common/tree.py +73 -0
  51. monarch/controller/__init__.py +7 -0
  52. monarch/controller/backend.py +223 -0
  53. monarch/controller/controller.py +223 -0
  54. monarch/controller/debugger.py +47 -0
  55. monarch/controller/history.py +90 -0
  56. monarch/controller/rust_backend/__init__.py +7 -0
  57. monarch/controller/rust_backend/controller.py +245 -0
  58. monarch/fetch.py +55 -0
  59. monarch/future.py +25 -0
  60. monarch/gradient/__init__.py +11 -0
  61. monarch/gradient/_gradient_generator.pyi +22 -0
  62. monarch/gradient/_gradient_generator.so +0 -0
  63. monarch/gradient_generator.py +185 -0
  64. monarch/memory.py +43 -0
  65. monarch/monarch_controller +0 -0
  66. monarch/notebook.py +761 -0
  67. monarch/opaque_module.py +235 -0
  68. monarch/opaque_object.py +88 -0
  69. monarch/parallel/__init__.py +9 -0
  70. monarch/parallel/pipelining/__init__.py +7 -0
  71. monarch/parallel/pipelining/runtime.py +847 -0
  72. monarch/parallel/pipelining/schedule_ir.py +692 -0
  73. monarch/parallel/pipelining/scheduler.py +249 -0
  74. monarch/proc_mesh.py +188 -0
  75. monarch/profiler.py +160 -0
  76. monarch/python_local_mesh.py +107 -0
  77. monarch/random.py +61 -0
  78. monarch/rdma.py +190 -0
  79. monarch/remote_class.py +114 -0
  80. monarch/rust_backend_mesh.py +280 -0
  81. monarch/rust_local_mesh.py +1402 -0
  82. monarch/sim_mesh.py +357 -0
  83. monarch/simulator/__init__.py +7 -0
  84. monarch/simulator/command_history.py +424 -0
  85. monarch/simulator/config.py +21 -0
  86. monarch/simulator/interface.py +59 -0
  87. monarch/simulator/ir.py +770 -0
  88. monarch/simulator/mock_controller.py +214 -0
  89. monarch/simulator/profiling.py +424 -0
  90. monarch/simulator/simulator.py +1052 -0
  91. monarch/simulator/task.py +255 -0
  92. monarch/simulator/tensor.py +373 -0
  93. monarch/simulator/trace.py +395 -0
  94. monarch/simulator/utils.py +41 -0
  95. monarch/simulator/worker.py +389 -0
  96. monarch/tensor_worker_main.py +260 -0
  97. monarch/tensorboard.py +84 -0
  98. monarch/timer/__init__.py +21 -0
  99. monarch/timer/example_monarch.py +78 -0
  100. monarch/timer/example_spmd.py +55 -0
  101. monarch/timer/execution_timer.py +199 -0
  102. monarch/timer/execution_timer_test.py +131 -0
  103. monarch/tools/__init__.py +7 -0
  104. monarch/tools/cli.py +167 -0
  105. monarch/tools/commands.py +189 -0
  106. monarch/tools/components/__init__.py +7 -0
  107. monarch/tools/components/hyperactor.py +57 -0
  108. monarch/tools/config/__init__.py +20 -0
  109. monarch/tools/config/defaults.py +54 -0
  110. monarch/tools/mesh_spec.py +121 -0
  111. monarch/worker/__init__.py +7 -0
  112. monarch/worker/_testing_function.py +481 -0
  113. monarch/worker/compiled_block.py +270 -0
  114. monarch/worker/debugger.py +125 -0
  115. monarch/worker/lines.py +47 -0
  116. monarch/worker/monitor.py +53 -0
  117. monarch/worker/worker.py +1191 -0
  118. monarch/world_mesh.py +34 -0
  119. monarch_supervisor/__init__.py +1044 -0
  120. monarch_supervisor/_testing.py +44 -0
  121. monarch_supervisor/function_call.py +30 -0
  122. monarch_supervisor/host.py +386 -0
  123. monarch_supervisor/launchers.py +145 -0
  124. monarch_supervisor/log_pstree.py +48 -0
  125. monarch_supervisor/logging.py +103 -0
  126. monarch_supervisor/python_executable.py +42 -0
  127. tests/__init__.py +0 -0
  128. tests/dispatch_bench.py +124 -0
  129. tests/dispatch_bench_helper.py +25 -0
  130. tests/error_test_binary.py +139 -0
  131. tests/simulator/__init__.py +0 -0
  132. tests/simulator/test_profiling.py +136 -0
  133. tests/simulator/test_simulator.py +411 -0
  134. tests/simulator/test_task.py +64 -0
  135. tests/simulator/test_worker.py +102 -0
  136. tests/sleep_binary.py +35 -0
  137. tests/test_actor_error.py +112 -0
  138. tests/test_alloc.py +25 -0
  139. tests/test_coalescing.py +492 -0
  140. tests/test_controller.py +835 -0
  141. tests/test_device_mesh.py +132 -0
  142. tests/test_fault_tolerance.py +398 -0
  143. tests/test_future.py +94 -0
  144. tests/test_grad_generator.py +121 -0
  145. tests/test_mock_cuda.py +74 -0
  146. tests/test_pdb_actor.py +110 -0
  147. tests/test_python_actors.py +372 -0
  148. tests/test_remote_functions.py +1271 -0
  149. tests/test_rust_backend.py +182 -0
  150. tests/test_signal_safe_block_on.py +103 -0
  151. tests/test_sim_backend.py +54 -0
  152. torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
  153. torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
  154. torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
  155. torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
  156. torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
  157. torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
tests/test_grad_generator.py @@ -0,0 +1,121 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from unittest import main, TestCase
+
+ import torch
+ from monarch.gradient._gradient_generator import GradientGenerator
+ from monarch.gradient_generator import gradient_execution_order
+
+
+ class TestGradIter(TestCase):
+     def checkEqual(self, r, r2):
+         self.assertEqual(len(r), len(r2))
+         for i, i2 in zip(r, r2):
+             self.assertTrue((i is None and i2 is None) or torch.allclose(i, i2))
+
+     def test_simple(self):
+         t = torch.rand(2, requires_grad=True)
+         t2 = torch.rand(2, requires_grad=True)
+
+         _ = t + t2
+         a, b = torch.std_mean(t + t2)
+
+         r2 = torch.autograd.grad([a, b], [t2, t], retain_graph=True)
+         r = list(GradientGenerator([a, b], [t2, t]))
+         print(a, b)
+         print(a.grad_fn, b.grad_fn)
+
+         print(r)
+         self.checkEqual(r, r2)
+
+     def test_pipeline_like(self):
+         t = torch.rand(3, 3, requires_grad=True)
+
+         w1 = torch.rand(3, 2, requires_grad=True)
+         w2 = torch.rand(3, 2, requires_grad=True)
+         w3 = torch.rand(3, 2, requires_grad=True)
+
+         u = torch.rand(3, 2, requires_grad=True)
+
+         _ = u * u
+
+         w4 = torch.rand(2, 3, requires_grad=True)
+         w5 = torch.rand(2, 3, requires_grad=True)
+         w6 = torch.rand(2, 3, requires_grad=True)
+
+         from torch.nn.functional import relu
+
+         a = relu(t @ (w1 @ w4))
+         a = relu(a @ (w2 @ w5))
+         a = relu(a @ (w3 @ w6))
+
+         std, mean = torch.std_mean(a)
+         loss = std + std
+
+         cgrads = torch.autograd.grad(
+             [loss], [t, w3, w6, u, w2, w5], allow_unused=True, retain_graph=True
+         )
+         gi = GradientGenerator([loss], [t, w3, w6, u, w2, w5])
+         grads = [*gi]
+         self.checkEqual(grads, cgrads)
+
+     def test_tree(self):
+         t = torch.rand(3, 3, requires_grad=True)
+
+         t2 = t + t
+         t3 = t * t
+         t4 = t / t
+         t5 = t - t
+
+         t6 = t2 * t3
+         t7 = t4 * t5
+         t8 = t2 * t4
+         t9 = t3 * t5
+         t10 = t6 + t7 + t8 + t9
+
+         t11 = t10.sum()
+
+         cgrads = torch.autograd.grad([t11], [t2, t], retain_graph=True)
+         gi = GradientGenerator([t11], [t2, t])
+         grads = [*gi]
+         self.checkEqual(grads, cgrads)
+
+     def test_broadcast(self):
+         t = torch.rand(3, 3, requires_grad=True)
+         t2 = torch.rand(3, requires_grad=True)
+         t3 = t2 / t2
+
+         r = (t * t3).sum()
+         cgrads = torch.autograd.grad([r], [t, t2], retain_graph=True)
+         gi = GradientGenerator([r], [t, t2])
+         grads = [*gi]
+         self.checkEqual(grads, cgrads)
+
+     def test_grad_order(self):
+         t = torch.rand(3, 3, requires_grad=True)
+         w1 = torch.rand(3, 3, requires_grad=True)
+         w2 = torch.rand(3, 3, requires_grad=True)
+         w3 = torch.rand(3, 3, requires_grad=True)
+
+         u = torch.rand(3, 2, requires_grad=True)
+         _ = u * u
+         from torch.nn.functional import relu
+
+         a = relu(t @ w1)
+         a = relu(a @ w2)
+         a = relu(a @ w3)
+
+         std, mean = torch.std_mean(a)
+         loss = std + std
+
+         order = gradient_execution_order([loss], [w2, w3, w1, a])
+         self.assertEqual(order, [3, 1, 0, 2])
+
+
+ if __name__ == "__main__":
+     main()
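Note on the API exercised above: GradientGenerator is constructed with a list of outputs and a list of inputs and is then iterated, yielding one gradient per requested input, while torch.autograd.grad returns the whole tuple at once. A minimal sketch of that comparison, using only calls that appear in this file (illustrative, not an additional file in the package):

    import torch
    from monarch.gradient._gradient_generator import GradientGenerator

    t = torch.rand(3, requires_grad=True)
    loss = (t * t).sum()

    # Eager reference: all gradients computed in one call.
    expected = torch.autograd.grad([loss], [t], retain_graph=True)

    # GradientGenerator yields gradients lazily, one per requested input,
    # which is what the tests above compare against the eager result.
    for got, want in zip(GradientGenerator([loss], [t]), expected):
        assert torch.allclose(got, want)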
tests/test_mock_cuda.py @@ -0,0 +1,74 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from unittest import main, TestCase
+
+ import pytest
+ import torch
+ import monarch.common.mock_cuda  # usort: skip
+
+
+ def simple_forward_backward(device: str) -> None:
+     torch.manual_seed(123)
+     m = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU()).to(device)
+     x = torch.rand(10, 3).to(device)
+     y = m(x)
+     loss_fn = torch.nn.CrossEntropyLoss()
+     loss = loss_fn(y, torch.randint(3, (10,)).to(device))
+     # Under the hood, enabling/disabling CUDA mocking is done with a thread-local
+     # flag. By default, backward() executes ops on a different thread than the one
+     # we enabled mocking on, which would lead to an invalid memory access. So we need
+     # to disable multithreading for backward.
+     with torch.autograd.set_multithreading_enabled(False):
+         loss.backward()
+     # pyre-ignore: Incompatible return type [7]: Expected `None` but got `Tuple[typing.Any, Union[None, Tensor, Module], Union[None, Tensor, Module]]`.
+     return y, m[0].weight.grad, m[0].bias.grad
+
+
+ # Mock cuda depends on initialization load order
+ # For OSS, run this test separately until it can be run in a subprocess.
+ @pytest.mark.oss_skip
+ class TestMockCuda(TestCase):
+     def setUp(self) -> None:
+         return super().setUp()
+
+     def test_output_is_garbage(self):
+         with monarch.common.mock_cuda.mock_cuda_guard():
+             x = torch.arange(9, device="cuda", dtype=torch.float32).reshape(3, 3)
+             y = 2 * torch.eye(3, device="cuda")
+             true_output = torch.tensor(
+                 [[0, 2, 4], [6, 8, 10], [12, 14, 16]], dtype=torch.float32
+             )
+             self.assertFalse(torch.equal((x @ y).cpu(), true_output))
+
+     def test_simple_forward_backward(self):
+         # This test just makes sure that the forward and backward pass work
+         # and don't crash.
+         simple_forward_backward("cuda")
+
+     def test_turn_mock_on_and_off(self):
+         cpu_y, cpu_dw, cpu_db = simple_forward_backward("cpu")
+
+         real_y, real_dw, real_db = simple_forward_backward("cuda")
+         self.assertTrue(torch.allclose(cpu_y, real_y.cpu()))
+         self.assertTrue(torch.allclose(cpu_dw, real_dw.cpu()))
+         self.assertTrue(torch.allclose(cpu_db, real_db.cpu()))
+
+         with monarch.common.mock_cuda.mock_cuda_guard():
+             mocked_y, mocked_dw, mocked_db = simple_forward_backward("cuda")
+             self.assertFalse(torch.allclose(cpu_y, mocked_y.cpu()))
+             self.assertFalse(torch.allclose(cpu_dw, mocked_dw.cpu()))
+             self.assertFalse(torch.allclose(cpu_db, mocked_db.cpu()))
+
+         real_y, real_dw, real_db = simple_forward_backward("cuda")
+         self.assertTrue(torch.allclose(cpu_y, real_y.cpu()))
+         self.assertTrue(torch.allclose(cpu_dw, real_dw.cpu()))
+         self.assertTrue(torch.allclose(cpu_db, real_db.cpu()))
+
+
+ if __name__ == "__main__":
+     main()
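The comment in simple_forward_backward is the key operational detail here: the CUDA mock is toggled by a thread-local flag, so the backward pass must run on the thread that enabled mocking. A condensed sketch of that usage pattern, assuming the same mock_cuda_guard context manager used above (illustrative only):

    import torch
    import monarch.common.mock_cuda

    with monarch.common.mock_cuda.mock_cuda_guard():
        # Inside the guard CUDA ops run mocked, so results are garbage values
        # (exactly what test_output_is_garbage asserts above).
        x = torch.rand(4, 4, device="cuda", requires_grad=True)
        loss = (x @ x).sum()
        # Autograd normally runs backward on a separate thread, which would not
        # see the thread-local mocking flag, so multithreading is disabled.
        with torch.autograd.set_multithreading_enabled(False):
            loss.backward()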
tests/test_pdb_actor.py @@ -0,0 +1,110 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import sys
+ import traceback
+ from contextlib import contextmanager
+ from typing import Generator
+
+ import pytest
+
+ import torch
+
+ from monarch import DeviceMesh, fetch_shard, remote, rust_local_mesh
+ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+     ClientActor,
+     DebuggerMessage as ClientDebuggerMessage,
+ )
+
+ from monarch._rust_bindings.monarch_extension.debugger import (
+     DebuggerMessage as PdbDebuggerMessage,
+     get_bytes_from_write_action,
+ )
+ from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+ from monarch.rust_local_mesh import LoggingLocation, SocketType
+ from monarch_supervisor.logging import fix_exception_lines
+
+
+ def custom_excepthook(exc_type, exc_value, exc_traceback):
+     tb_lines = fix_exception_lines(
+         traceback.format_exception(exc_type, exc_value, exc_traceback)
+     )
+     print("\n".join(tb_lines), file=sys.stderr)
+
+
+ sys.excepthook = custom_excepthook
+
+
+ @contextmanager
+ def local_mesh(
+     hosts: int = 1, gpu_per_host: int = 2, activate: bool = True
+ ) -> Generator[DeviceMesh, None, None]:
+     with rust_local_mesh.local_mesh(
+         hosts=hosts,
+         gpus_per_host=gpu_per_host,
+         socket_type=SocketType.UNIX,
+         logging_location=LoggingLocation.DEFAULT,
+     ) as dm:
+         try:
+             if activate:
+                 with dm.activate():
+                     yield dm
+             else:
+                 yield dm
+             dm.exit()
+         except Exception:
+             dm.client._shutdown = True
+             raise
+
+
+ remote_test_pdb_actor = remote(
+     "monarch.worker._testing_function.test_pdb_actor",
+     propagate=lambda: torch.zeros(1),
+ )
+
+
+ @pytest.mark.skipif(
+     torch.cuda.device_count() < 2,
+     reason="Not enough GPUs, this test requires at least 2 GPUs",
+ )
+ # Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
+ # out is not counted as a failure, so we set a more restrictive timeout to
+ # ensure we see a hard failure in CI.
+ @pytest.mark.timeout(120)
+ class TestPdbActor:
+     def test_pdb_actor(self):
+         with local_mesh(1, 1) as dm:
+             with dm.activate():
+                 client = dm.client.inner._actor
+                 assert isinstance(client, ClientActor)
+                 fut = fetch_shard(remote_test_pdb_actor())
+                 msg = client.get_next_message(timeout_msec=None)
+                 assert isinstance(msg, ClientDebuggerMessage)
+                 assert isinstance(msg.action, DebuggerAction.Paused)
+                 client.send(
+                     msg.debugger_actor_id,
+                     PdbDebuggerMessage(action=DebuggerAction.Attach()).serialize(),
+                 )
+                 msg = client.get_next_message(timeout_msec=None)
+                 assert isinstance(msg, ClientDebuggerMessage)
+                 assert isinstance(msg.action, DebuggerAction.Read)
+                 assert msg.action.requested_size == 4
+                 client.send(
+                     msg.debugger_actor_id,
+                     PdbDebuggerMessage(
+                         action=DebuggerAction.Write(b"1234")
+                     ).serialize(),
+                 )
+                 msg = client.get_next_message(timeout_msec=None)
+                 assert isinstance(msg, ClientDebuggerMessage)
+                 assert isinstance(msg.action, DebuggerAction.Write)
+                 assert get_bytes_from_write_action(msg.action) == b"5678"
+                 client.send(
+                     msg.debugger_actor_id,
+                     PdbDebuggerMessage(action=DebuggerAction.Detach()).serialize(),
+                 )
+                 fut.result()
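The assertions above encode the debugger handshake: the worker reports Paused, the client Attaches, the worker asks to Read 4 bytes of pdb input, the client Writes them, the worker Writes back its pdb output, and the client Detaches. A client-side sketch of that loop, using only the bindings imported in this test; drive_debugger_once is a hypothetical helper name for illustration, not an API in the package:

    from monarch._rust_bindings.monarch_extension.debugger import (
        DebuggerMessage as PdbDebuggerMessage,
        get_bytes_from_write_action,
    )
    from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction


    def drive_debugger_once(client, pdb_input: bytes = b"1234") -> bytes:
        # Worker announces that it is paused in the debugger.
        msg = client.get_next_message(timeout_msec=None)
        client.send(
            msg.debugger_actor_id,
            PdbDebuggerMessage(action=DebuggerAction.Attach()).serialize(),
        )
        # Worker requests pdb input; send it.
        msg = client.get_next_message(timeout_msec=None)
        client.send(
            msg.debugger_actor_id,
            PdbDebuggerMessage(action=DebuggerAction.Write(pdb_input)).serialize(),
        )
        # Worker writes its pdb output back, then the client detaches.
        msg = client.get_next_message(timeout_msec=None)
        output = get_bytes_from_write_action(msg.action)
        client.send(
            msg.debugger_actor_id,
            PdbDebuggerMessage(action=DebuggerAction.Detach()).serialize(),
        )
        return output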
tests/test_python_actors.py @@ -0,0 +1,372 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import operator
+ from types import ModuleType
+
+ import torch
+ from monarch.actor_mesh import (
+     Accumulator,
+     Actor,
+     current_actor_name,
+     current_rank,
+     current_size,
+     endpoint,
+ )
+
+ from monarch.proc_mesh import local_proc_mesh, proc_mesh
+ from monarch.rdma import RDMABuffer
+
+
+ class Counter(Actor):
+     def __init__(self, v: int):
+         self.v = v
+
+     @endpoint
+     async def incr(self):
+         self.v += 1
+
+     @endpoint
+     async def value(self) -> int:
+         return self.v
+
+
+ class Indirect(Actor):
+     @endpoint
+     async def call_value(self, c: Counter) -> int:
+         return await c.value.choose()
+
+
+ class ParameterServer(Actor):
+     def __init__(self):
+         self.params = torch.rand(10, 10)
+         self.grad_buffer = torch.rand(10, 10)
+
+     @endpoint
+     async def grad_handle(self) -> RDMABuffer:
+         byte_tensor = self.grad_buffer.view(torch.uint8).flatten()
+         return RDMABuffer(byte_tensor)
+
+     @endpoint
+     async def update(self):
+         self.params += 0.01 * self.grad_buffer
+
+     @endpoint
+     async def get_grad_buffer(self) -> torch.Tensor:
+         # just used for testing
+         return self.grad_buffer
+
+
+ async def test_choose():
+     proc = await local_proc_mesh(gpus=2)
+     v = await proc.spawn("counter", Counter, 3)
+     i = await proc.spawn("indirect", Indirect)
+     v.incr.broadcast()
+     result = await v.value.choose()
+     result2 = await i.call_value.choose(v)
+
+     assert result == result2
+
+
+ async def test_stream():
+     proc = await local_proc_mesh(gpus=2)
+     v = await proc.spawn("counter2", Counter, 3)
+     v.incr.broadcast()
+
+     assert 8 == sum([x async for x in v.value.stream()])
+
+
+ class ParameterClient(Actor):
+     def __init__(self, server, buffer):
+         self.server = server
+         byte_tensor = buffer.view(torch.uint8).flatten()
+         self.buffer = byte_tensor
+
+     @endpoint
+     async def upload(self, tensor):
+         gh = await self.server.grad_handle.call_one()
+         await gh.write(tensor)
+
+     @endpoint
+     async def download(self):
+         gh = await self.server.grad_handle.call_one()
+         await gh.read_into(self.buffer)
+
+     @endpoint
+     async def get_buffer(self):
+         return self.buffer
+
+
+ async def test_proc_mesh_rdma():
+     proc = await proc_mesh(gpus=1)
+     server = await proc.spawn("server", ParameterServer)
+
+     # --- CPU TESTS ---
+     client_cpu = await proc.spawn(
+         "client_cpu", ParameterClient, server, torch.ones(10, 10)
+     )
+     x = await client_cpu.get_buffer.call_one()
+     assert torch.sum(x.view(torch.float32).view(10, 10)) == 100
+     zeros = torch.zeros(10, 10)
+     await client_cpu.upload.call_one(zeros.view(torch.uint8).flatten())
+     await client_cpu.download.call_one()
+     x = await client_cpu.get_buffer.call_one()
+     assert torch.sum(x.view(torch.float32).view(10, 10)) == 0
+
+     # --- Modify server's backing buffer directly ---
+     await server.update.call_one()
+
+     # Should reflect updated values
+     await client_cpu.download.call_one()
+
+     buffer = await client_cpu.get_buffer.call_one()
+     remote_grad = await server.get_grad_buffer.call_one()
+     assert torch.allclose(buffer.view(torch.float32).view(10, 10), remote_grad)
+
+     # --- GPU TESTS ---
+     client_gpu = await proc.spawn(
+         "client_gpu", ParameterClient, server, torch.ones(10, 10, device="cuda")
+     )
+     x = await client_gpu.get_buffer.call_one()
+     buffer = x.view(torch.float32).view(10, 10)
+     assert torch.sum(buffer) == 100
+     zeros = torch.zeros(10, 10, device="cuda")
+     await client_gpu.upload.call_one(zeros.view(torch.uint8).flatten())
+     await client_gpu.download.call_one()
+     x = await client_gpu.get_buffer.call_one()
+     buffer_gpu = x.view(torch.float32).view(10, 10)
+     assert torch.sum(buffer_gpu) == 0
+     assert buffer_gpu.device.type == "cuda"
+
+     # Modify server state again
+     await server.update.call_one()
+     await client_gpu.download.call_one()
+     x = await client_gpu.get_buffer.call_one()
+     buffer_gpu = x.view(torch.float32).view(10, 10)
+     remote_grad = await server.get_grad_buffer.call_one()
+     assert torch.allclose(buffer_gpu.cpu(), remote_grad)
+
+
+ class To(Actor):
+     @endpoint
+     async def whoami(self):
+         return current_actor_name()
+
+
+ class From(Actor):
+     @endpoint
+     async def get(self, to: To):
+         return [x async for x in to.whoami.stream()]
+
+
+ async def test_mesh_passed_to_mesh():
+     proc = await local_proc_mesh(gpus=2)
+     f = await proc.spawn("from", From)
+     t = await proc.spawn("to", To)
+     all = [y async for x in f.get.stream(t) for y in x]
+     assert len(all) == 4
+     assert all[0] != all[1]
+
+
+ async def test_mesh_passed_to_mesh_on_different_proc_mesh():
+     proc = await local_proc_mesh(gpus=2)
+     proc2 = await local_proc_mesh(gpus=2)
+     f = await proc.spawn("from", From)
+     t = await proc2.spawn("to", To)
+     all = [y async for x in f.get.stream(t) for y in x]
+     assert len(all) == 4
+     assert all[0] != all[1]
+
+
+ async def test_actor_slicing():
+     proc = await local_proc_mesh(gpus=2)
+     proc2 = await local_proc_mesh(gpus=2)
+
+     f = await proc.spawn("from", From)
+     t = await proc2.spawn("to", To)
+
+     assert await t.slice(gpus=0).whoami.call() != await t.slice(gpus=1).whoami.call()
+
+     result = [y async for x in f.get.stream(t.slice(gpus=0)) for y in x]
+     assert len(result) == 2
+
+     assert result[0] == result[1]
+
+
+ async def test_aggregate():
+     proc = await local_proc_mesh(gpus=2)
+     counter = await proc.spawn("counter", Counter, 1)
+     counter.incr.broadcast()
+     acc = Accumulator(counter.value, 0, operator.add)
+     r = await acc.accumulate()
+     assert r == 4
+
+
+ class RunIt(Actor):
+     @endpoint
+     async def run(self, fn):
+         return fn()
+
+
+ async def test_rank_size():
+     proc = await local_proc_mesh(gpus=2)
+     r = await proc.spawn("runit", RunIt)
+
+     acc = Accumulator(r.run, 0, operator.add)
+
+     assert 1 == await acc.accumulate(lambda: current_rank()["gpus"])
+     assert 4 == await acc.accumulate(lambda: current_size()["gpus"])
+
+
+ class TrainerActor(Actor):
+     def __init__(self):
+         super().__init__()
+         self.trainer = torch.nn.Linear(10, 10).to("cuda")
+         self.trainer.weight.data.zero_()
+
+     @endpoint
+     async def init(self, gen):
+         ranks = current_rank()
+         self.gen = gen.slice(**ranks)
+
+     @endpoint
+     async def exchange_metadata(self):
+         byte_tensor = self.trainer.weight.data.view(torch.uint8).flatten()
+         self.handle = RDMABuffer(byte_tensor)
+         await self.gen.attach_weight_buffer.call(self.handle)
+
+     @endpoint
+     async def weights_ready(self):
+         self.trainer.weight.data.add_(1.0)
+
+
+ class GeneratorActor(Actor):
+     def __init__(self):
+         super().__init__()
+         self.generator = torch.nn.Linear(10, 10).to("cuda")
+         self.step = 0
+
+     @endpoint
+     async def init(self, trainer):
+         ranks = current_rank()
+         self.trainer = trainer.slice(**ranks)
+
+     @endpoint
+     async def attach_weight_buffer(self, handle):
+         self.handle = handle
+
+     @endpoint
+     async def update_weights(self):
+         self.step += 1
+         byte_tensor = self.generator.weight.data.view(torch.uint8).flatten()
+         await self.handle.read_into(byte_tensor)
+         assert (
+             torch.sum(self.generator.weight.data) == self.step * 100
+         ), f"{torch.sum(self.generator.weight.data)=}, {self.step=}"
+
+
+ async def test_gpu_trainer_generator():
+     trainer_proc = await proc_mesh(gpus=1)
+     gen_proc = await proc_mesh(gpus=1)
+     trainer = await trainer_proc.spawn("trainer", TrainerActor)
+     generator = await gen_proc.spawn("gen", GeneratorActor)
+
+     await generator.init.call(trainer)
+     await trainer.init.call(generator)
+     await trainer.exchange_metadata.call()
+
+     for _ in range(3):
+         await trainer.weights_ready.call()
+         await generator.update_weights.call()
+
+
+ class SyncActor(Actor):
+     @endpoint
+     def sync_endpoint(self, a_counter: Counter):
+         return a_counter.value.choose().get()
+
+
+ async def test_sync_actor():
+     proc = await local_proc_mesh(gpus=2)
+     a = await proc.spawn("actor", SyncActor)
+     c = await proc.spawn("counter", Counter, 5)
+     r = await a.sync_endpoint.choose(c)
+     assert r == 5
+
+
+ def test_gpu_trainer_generator_sync() -> None:
+     trainer_proc = proc_mesh(gpus=1).get()
+     gen_proc = proc_mesh(gpus=1).get()
+     trainer = trainer_proc.spawn("trainer", TrainerActor).get()
+     generator = gen_proc.spawn("gen", GeneratorActor).get()
+
+     generator.init.call(trainer).get()
+     trainer.init.call(generator).get()
+     trainer.exchange_metadata.call().get()
+
+     for _ in range(3):
+         trainer.weights_ready.call().get()
+         generator.update_weights.call().get()
+
+
+ def test_sync_actor_sync_client():
+     proc = local_proc_mesh(gpus=2).get()
+     a = proc.spawn("actor", SyncActor).get()
+     c = proc.spawn("counter", Counter, 5).get()
+     r = a.sync_endpoint.choose(c).get()
+     assert r == 5
+
+
+ def test_rank_size_sync() -> None:
+     proc = local_proc_mesh(gpus=2).get()
+     r = proc.spawn("runit", RunIt).get()
+
+     acc = Accumulator(r.run, 0, operator.add)
+     assert 1 == acc.accumulate(lambda: current_rank()["gpus"]).get()
+     assert 4 == acc.accumulate(lambda: current_size()["gpus"]).get()
+
+
+ def test_accumulate_sync() -> None:
+     proc = local_proc_mesh(gpus=2).get()
+     counter = proc.spawn("counter", Counter, 1).get()
+     counter.incr.broadcast()
+     acc = Accumulator(counter.value, 0, operator.add)
+     r = acc.accumulate().get()
+     assert r == 4
+
+
+ class CastToCounter(Actor):
+     @endpoint
+     def doit(self, c: Counter):
+         return list(c.value.call().get())
+
+
+ def test_value_mesh() -> None:
+     proc = local_proc_mesh(gpus=2).get()
+     counter = proc.spawn("counter", Counter, 0).get()
+     counter.slice(hosts=0, gpus=1).incr.broadcast()
+     x = counter.value.call().get()
+     assert 0 == x.item(hosts=0, gpus=0)
+     assert 1 == x.item(hosts=0, gpus=1)
+     assert 1 == x.slice(hosts=0, gpus=1).item()
+     n = proc.spawn("ctc", CastToCounter).get()
+     assert list(x) == n.slice(gpus=0).doit.call_one(counter).get()
+
+
+ def test_rust_binding_modules_correct() -> None:
+     import monarch._rust_bindings as bindings
+
+     def check(module, path):
+         for name, value in module.__dict__.items():
+             if name.startswith("__"):
+                 continue
+             if isinstance(value, ModuleType):
+                 check(value, f"{path}.{name}")
+             elif hasattr(value, "__module__"):
+                 assert value.__name__ == name
+                 assert value.__module__ == path
+
+     check(bindings, "monarch._rust_bindings")
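Taken together, the tests above exercise the endpoint invocation styles: broadcast (fire-and-forget to every actor in the mesh), choose (a single actor), call (a value mesh from every actor), call_one (a single actor's result), stream (async iteration over results), and Accumulator for reductions. A condensed sketch of the basic flow, reusing the Counter actor defined in this file; whether it runs outside the test harness depends on the environment, so treat it as illustrative:

    import asyncio
    import operator

    from monarch.actor_mesh import Accumulator, Actor, endpoint
    from monarch.proc_mesh import local_proc_mesh


    class Counter(Actor):
        def __init__(self, v: int):
            self.v = v

        @endpoint
        async def incr(self):
            self.v += 1

        @endpoint
        async def value(self) -> int:
            return self.v


    async def demo() -> None:
        proc = await local_proc_mesh(gpus=2)                # a local mesh with two slots
        counters = await proc.spawn("counter", Counter, 0)  # one Counter per slot
        counters.incr.broadcast()                           # fire-and-forget to all
        one = await counters.value.choose()                 # value from a single actor
        total = await Accumulator(counters.value, 0, operator.add).accumulate()
        print(one, total)                                   # 1 and 2 for two actors


    if __name__ == "__main__":
        asyncio.run(demo())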