torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
tests/test_grad_generator.py
@@ -0,0 +1,121 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from unittest import main, TestCase
+
+ import torch
+ from monarch.gradient._gradient_generator import GradientGenerator
+ from monarch.gradient_generator import gradient_execution_order
+
+
+ class TestGradIter(TestCase):
+     def checkEqual(self, r, r2):
+         self.assertEqual(len(r), len(r2))
+         for i, i2 in zip(r, r2):
+             self.assertTrue((i is None and i2 is None) or torch.allclose(i, i2))
+
+     def test_simple(self):
+         t = torch.rand(2, requires_grad=True)
+         t2 = torch.rand(2, requires_grad=True)
+
+         _ = t + t2
+         a, b = torch.std_mean(t + t2)
+
+         r2 = torch.autograd.grad([a, b], [t2, t], retain_graph=True)
+         r = list(GradientGenerator([a, b], [t2, t]))
+         print(a, b)
+         print(a.grad_fn, b.grad_fn)
+
+         print(r)
+         self.checkEqual(r, r2)
+
+     def test_pipeline_like(self):
+         t = torch.rand(3, 3, requires_grad=True)
+
+         w1 = torch.rand(3, 2, requires_grad=True)
+         w2 = torch.rand(3, 2, requires_grad=True)
+         w3 = torch.rand(3, 2, requires_grad=True)
+
+         u = torch.rand(3, 2, requires_grad=True)
+
+         _ = u * u
+
+         w4 = torch.rand(2, 3, requires_grad=True)
+         w5 = torch.rand(2, 3, requires_grad=True)
+         w6 = torch.rand(2, 3, requires_grad=True)
+
+         from torch.nn.functional import relu
+
+         a = relu(t @ (w1 @ w4))
+         a = relu(a @ (w2 @ w5))
+         a = relu(a @ (w3 @ w6))
+
+         std, mean = torch.std_mean(a)
+         loss = std + std
+
+         cgrads = torch.autograd.grad(
+             [loss], [t, w3, w6, u, w2, w5], allow_unused=True, retain_graph=True
+         )
+         gi = GradientGenerator([loss], [t, w3, w6, u, w2, w5])
+         grads = [*gi]
+         self.checkEqual(grads, cgrads)
+
+     def test_tree(self):
+         t = torch.rand(3, 3, requires_grad=True)
+
+         t2 = t + t
+         t3 = t * t
+         t4 = t / t
+         t5 = t - t
+
+         t6 = t2 * t3
+         t7 = t4 * t5
+         t8 = t2 * t4
+         t9 = t3 * t5
+         t10 = t6 + t7 + t8 + t9
+
+         t11 = t10.sum()
+
+         cgrads = torch.autograd.grad([t11], [t2, t], retain_graph=True)
+         gi = GradientGenerator([t11], [t2, t])
+         grads = [*gi]
+         self.checkEqual(grads, cgrads)
+
+     def test_broadcast(self):
+         t = torch.rand(3, 3, requires_grad=True)
+         t2 = torch.rand(3, requires_grad=True)
+         t3 = t2 / t2
+
+         r = (t * t3).sum()
+         cgrads = torch.autograd.grad([r], [t, t2], retain_graph=True)
+         gi = GradientGenerator([r], [t, t2])
+         grads = [*gi]
+         self.checkEqual(grads, cgrads)
+
+     def test_grad_order(self):
+         t = torch.rand(3, 3, requires_grad=True)
+         w1 = torch.rand(3, 3, requires_grad=True)
+         w2 = torch.rand(3, 3, requires_grad=True)
+         w3 = torch.rand(3, 3, requires_grad=True)
+
+         u = torch.rand(3, 2, requires_grad=True)
+         _ = u * u
+         from torch.nn.functional import relu
+
+         a = relu(t @ w1)
+         a = relu(a @ w2)
+         a = relu(a @ w3)
+
+         std, mean = torch.std_mean(a)
+         loss = std + std
+
+         order = gradient_execution_order([loss], [w2, w3, w1, a])
+         self.assertEqual(order, [3, 1, 0, 2])
+
+
+ if __name__ == "__main__":
+     main()
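The two APIs exercised above are GradientGenerator, which is iterated to obtain one gradient per requested input rather than returning them all at once as torch.autograd.grad does, and gradient_execution_order, which reports the order in which gradients for the given tensors become available during the backward pass. A minimal sketch of the same usage, limited to the constructor and iteration behavior the tests demonstrate (output tensors first, then the inputs to differentiate against); the surrounding tensors are illustrative only:

    import torch
    from monarch.gradient._gradient_generator import GradientGenerator
    from monarch.gradient_generator import gradient_execution_order

    w = torch.rand(3, 3, requires_grad=True)
    x = torch.rand(3, 3, requires_grad=True)
    loss = (x @ w).sum()

    # Indices of [w, x] sorted by when their gradients would be computed,
    # mirroring the assertion in test_grad_order above.
    order = gradient_execution_order([loss], [w, x])
    print(order)

    # Yields one gradient per requested input, in the same order
    # torch.autograd.grad([loss], [w, x]) would return them (see test_simple).
    for grad in GradientGenerator([loss], [w, x]):
        print(grad.shape)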
tests/test_mock_cuda.py
@@ -0,0 +1,74 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ from unittest import main, TestCase
+
+ import pytest
+ import torch
+ import monarch.common.mock_cuda  # usort: skip
+
+
+ def simple_forward_backward(device: str) -> None:
+     torch.manual_seed(123)
+     m = torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.ReLU()).to(device)
+     x = torch.rand(10, 3).to(device)
+     y = m(x)
+     loss_fn = torch.nn.CrossEntropyLoss()
+     loss = loss_fn(y, torch.randint(3, (10,)).to(device))
+     # Under the hood, enabling/disabling CUDA mocking is done with a thread-local
+     # flag. By default, backward() executes ops on a different thread than the one
+     # we enabled mocking on, which would lead to an invalid memory access. So we need
+     # to disable multithreading for backward.
+     with torch.autograd.set_multithreading_enabled(False):
+         loss.backward()
+     # pyre-ignore: Incompatible return type [7]: Expected `None` but got `Tuple[typing.Any, Union[None, Tensor, Module], Union[None, Tensor, Module]]`.
+     return y, m[0].weight.grad, m[0].bias.grad
+
+
+ # Mock cuda depends on initialization load order
+ # For OSS, run this test separately until it can be run in a subprocess.
+ @pytest.mark.oss_skip
+ class TestMockCuda(TestCase):
+     def setUp(self) -> None:
+         return super().setUp()
+
+     def test_output_is_garbage(self):
+         with monarch.common.mock_cuda.mock_cuda_guard():
+             x = torch.arange(9, device="cuda", dtype=torch.float32).reshape(3, 3)
+             y = 2 * torch.eye(3, device="cuda")
+             true_output = torch.tensor(
+                 [[0, 2, 4], [6, 8, 10], [12, 14, 16]], dtype=torch.float32
+             )
+             self.assertFalse(torch.equal((x @ y).cpu(), true_output))
+
+     def test_simple_forward_backward(self):
+         # This test just makes sure that the forward and backward pass work
+         # and don't crash.
+         simple_forward_backward("cuda")
+
+     def test_turn_mock_on_and_off(self):
+         cpu_y, cpu_dw, cpu_db = simple_forward_backward("cpu")
+
+         real_y, real_dw, real_db = simple_forward_backward("cuda")
+         self.assertTrue(torch.allclose(cpu_y, real_y.cpu()))
+         self.assertTrue(torch.allclose(cpu_dw, real_dw.cpu()))
+         self.assertTrue(torch.allclose(cpu_db, real_db.cpu()))
+
+         with monarch.common.mock_cuda.mock_cuda_guard():
+             mocked_y, mocked_dw, mocked_db = simple_forward_backward("cuda")
+             self.assertFalse(torch.allclose(cpu_y, mocked_y.cpu()))
+             self.assertFalse(torch.allclose(cpu_dw, mocked_dw.cpu()))
+             self.assertFalse(torch.allclose(cpu_db, mocked_db.cpu()))
+
+         real_y, real_dw, real_db = simple_forward_backward("cuda")
+         self.assertTrue(torch.allclose(cpu_y, real_y.cpu()))
+         self.assertTrue(torch.allclose(cpu_dw, real_dw.cpu()))
+         self.assertTrue(torch.allclose(cpu_db, real_db.cpu()))
+
+
+ if __name__ == "__main__":
+     main()
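The inline comments in this test capture the two practical constraints on monarch's CUDA mocking: the mock is toggled by a thread-local flag, so autograd's default multithreaded backward must be disabled while mocking is active, and mocked kernels return garbage values, so only shapes and dtypes are meaningful. A minimal usage sketch following the same pattern as these tests (it assumes a CUDA device is present and, per the load-order comment above, that monarch.common.mock_cuda is imported early in the process):

    import monarch.common.mock_cuda  # usort: skip
    import torch

    with monarch.common.mock_cuda.mock_cuda_guard():
        # Kernels are mocked: shapes and dtypes are real, values are garbage.
        x = torch.rand(4, 4, device="cuda", requires_grad=True)
        loss = (x @ x).sum()
        # backward() must run on the thread that enabled mocking, so disable
        # autograd's multithreading, as simple_forward_backward does above.
        with torch.autograd.set_multithreading_enabled(False):
            loss.backward()
        print(x.grad.shape)  # structure is usable even though the values are not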
tests/test_pdb_actor.py
@@ -0,0 +1,110 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-unsafe
+ import sys
+ import traceback
+ from contextlib import contextmanager
+ from typing import Generator
+
+ import pytest
+
+ import torch
+
+ from monarch import DeviceMesh, fetch_shard, remote, rust_local_mesh
+ from monarch._rust_bindings.monarch_extension.client import (  # @manual=//monarch/monarch_extension:monarch_extension
+     ClientActor,
+     DebuggerMessage as ClientDebuggerMessage,
+ )
+
+ from monarch._rust_bindings.monarch_extension.debugger import (
+     DebuggerMessage as PdbDebuggerMessage,
+     get_bytes_from_write_action,
+ )
+ from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction
+ from monarch.rust_local_mesh import LoggingLocation, SocketType
+ from monarch_supervisor.logging import fix_exception_lines
+
+
+ def custom_excepthook(exc_type, exc_value, exc_traceback):
+     tb_lines = fix_exception_lines(
+         traceback.format_exception(exc_type, exc_value, exc_traceback)
+     )
+     print("\n".join(tb_lines), file=sys.stderr)
+
+
+ sys.excepthook = custom_excepthook
+
+
+ @contextmanager
+ def local_mesh(
+     hosts: int = 1, gpu_per_host: int = 2, activate: bool = True
+ ) -> Generator[DeviceMesh, None, None]:
+     with rust_local_mesh.local_mesh(
+         hosts=hosts,
+         gpus_per_host=gpu_per_host,
+         socket_type=SocketType.UNIX,
+         logging_location=LoggingLocation.DEFAULT,
+     ) as dm:
+         try:
+             if activate:
+                 with dm.activate():
+                     yield dm
+             else:
+                 yield dm
+             dm.exit()
+         except Exception:
+             dm.client._shutdown = True
+             raise
+
+
+ remote_test_pdb_actor = remote(
+     "monarch.worker._testing_function.test_pdb_actor",
+     propagate=lambda: torch.zeros(1),
+ )
+
+
+ @pytest.mark.skipif(
+     torch.cuda.device_count() < 2,
+     reason="Not enough GPUs, this test requires at least 2 GPUs",
+ )
+ # Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
+ # out is not counted as a failure, so we set a more restrictive timeout to
+ # ensure we see a hard failure in CI.
+ @pytest.mark.timeout(120)
+ class TestPdbActor:
+     def test_pdb_actor(self):
+         with local_mesh(1, 1) as dm:
+             with dm.activate():
+                 client = dm.client.inner._actor
+                 assert isinstance(client, ClientActor)
+                 fut = fetch_shard(remote_test_pdb_actor())
+                 msg = client.get_next_message(timeout_msec=None)
+                 assert isinstance(msg, ClientDebuggerMessage)
+                 assert isinstance(msg.action, DebuggerAction.Paused)
+                 client.send(
+                     msg.debugger_actor_id,
+                     PdbDebuggerMessage(action=DebuggerAction.Attach()).serialize(),
+                 )
+                 msg = client.get_next_message(timeout_msec=None)
+                 assert isinstance(msg, ClientDebuggerMessage)
+                 assert isinstance(msg.action, DebuggerAction.Read)
+                 assert msg.action.requested_size == 4
+                 client.send(
+                     msg.debugger_actor_id,
+                     PdbDebuggerMessage(
+                         action=DebuggerAction.Write(b"1234")
+                     ).serialize(),
+                 )
+                 msg = client.get_next_message(timeout_msec=None)
+                 assert isinstance(msg, ClientDebuggerMessage)
+                 assert isinstance(msg.action, DebuggerAction.Write)
+                 assert get_bytes_from_write_action(msg.action) == b"5678"
+                 client.send(
+                     msg.debugger_actor_id,
+                     PdbDebuggerMessage(action=DebuggerAction.Detach()).serialize(),
+                 )
+                 fut.result()
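The exchange in test_pdb_actor follows a small debugger protocol: the worker reports Paused when it hits a breakpoint, the client replies with Attach, the worker's pdb then issues Read requests for input and Write actions carrying output, and the client ends the session with Detach. Below is a hedged sketch of a generic driver loop assembled only from the calls the test itself uses (ClientActor.get_next_message, client.send, the DebuggerAction variants, get_bytes_from_write_action); read_input and write_output are hypothetical frontend hooks, and ending the session on empty input is an assumption rather than documented behavior:

    from monarch._rust_bindings.monarch_extension.client import (
        ClientActor,
        DebuggerMessage as ClientDebuggerMessage,
    )
    from monarch._rust_bindings.monarch_extension.debugger import (
        DebuggerMessage as PdbDebuggerMessage,
        get_bytes_from_write_action,
    )
    from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction

    def drive_debugger(client: ClientActor, read_input, write_output) -> None:
        # read_input(n): return up to n bytes of user input (b"" ends the session).
        # write_output(data): display bytes produced by the worker's pdb.
        while True:
            msg = client.get_next_message(timeout_msec=None)
            if not isinstance(msg, ClientDebuggerMessage):
                continue
            action = msg.action
            if isinstance(action, DebuggerAction.Paused):
                # A worker hit a breakpoint: attach to start the pdb session.
                reply = DebuggerAction.Attach()
            elif isinstance(action, DebuggerAction.Read):
                # The worker's pdb wants `requested_size` bytes of input.
                data = read_input(action.requested_size)
                reply = DebuggerAction.Write(data) if data else DebuggerAction.Detach()
            elif isinstance(action, DebuggerAction.Write):
                # The worker's pdb produced output to display.
                write_output(get_bytes_from_write_action(action))
                continue
            else:
                continue
            client.send(
                msg.debugger_actor_id,
                PdbDebuggerMessage(action=reply).serialize(),
            )
            if isinstance(reply, DebuggerAction.Detach):
                return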