torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165)
  1. monarch/__init__.py +189 -0
  2. monarch/_monarch/__init__.py +5 -0
  3. monarch/_monarch/hyperactor/__init__.py +58 -0
  4. monarch/_monarch/selection/__init__.py +13 -0
  5. monarch/_monarch/worker/__init__.py +0 -0
  6. monarch/_monarch/worker/debugger.py +117 -0
  7. monarch/_monarch/worker/logging.py +107 -0
  8. monarch/_rust_bindings.so +0 -0
  9. monarch/_testing.py +230 -0
  10. monarch/actor_mesh.py +761 -0
  11. monarch/allocator.py +220 -0
  12. monarch/bootstrap_main.py +59 -0
  13. monarch/builtins/__init__.py +14 -0
  14. monarch/builtins/log.py +22 -0
  15. monarch/builtins/random.py +68 -0
  16. monarch/cached_remote_function.py +257 -0
  17. monarch/code_sync.py +10 -0
  18. monarch/common/_C.pyi +11 -0
  19. monarch/common/_C.so +0 -0
  20. monarch/common/__init__.py +0 -0
  21. monarch/common/_coalescing.py +308 -0
  22. monarch/common/_device_utils.py +18 -0
  23. monarch/common/_tensor_to_table.py +172 -0
  24. monarch/common/base_tensor.py +28 -0
  25. monarch/common/borrows.py +143 -0
  26. monarch/common/client.py +690 -0
  27. monarch/common/constants.py +10 -0
  28. monarch/common/context_manager.py +40 -0
  29. monarch/common/controller_api.py +104 -0
  30. monarch/common/device_mesh.py +417 -0
  31. monarch/common/fake.py +55 -0
  32. monarch/common/function.py +160 -0
  33. monarch/common/function_caching.py +164 -0
  34. monarch/common/future.py +168 -0
  35. monarch/common/invocation.py +125 -0
  36. monarch/common/mast.py +221 -0
  37. monarch/common/messages.py +573 -0
  38. monarch/common/mock_cuda.py +41 -0
  39. monarch/common/opaque_ref.py +98 -0
  40. monarch/common/pickle_flatten.py +48 -0
  41. monarch/common/pipe.py +152 -0
  42. monarch/common/process_group.py +55 -0
  43. monarch/common/recording.py +127 -0
  44. monarch/common/reference.py +33 -0
  45. monarch/common/remote.py +297 -0
  46. monarch/common/selection.py +9 -0
  47. monarch/common/shape.py +229 -0
  48. monarch/common/stream.py +114 -0
  49. monarch/common/tensor.py +814 -0
  50. monarch/common/tensor_factory.py +31 -0
  51. monarch/common/tree.py +73 -0
  52. monarch/controller/__init__.py +7 -0
  53. monarch/controller/backend.py +223 -0
  54. monarch/controller/controller.py +223 -0
  55. monarch/controller/debugger.py +47 -0
  56. monarch/controller/history.py +90 -0
  57. monarch/controller/rust_backend/__init__.py +7 -0
  58. monarch/controller/rust_backend/controller.py +245 -0
  59. monarch/debugger.py +379 -0
  60. monarch/fetch.py +55 -0
  61. monarch/future.py +76 -0
  62. monarch/gradient/__init__.py +11 -0
  63. monarch/gradient/_gradient_generator.pyi +22 -0
  64. monarch/gradient/_gradient_generator.so +0 -0
  65. monarch/gradient_generator.py +185 -0
  66. monarch/memory.py +43 -0
  67. monarch/mesh_controller.py +271 -0
  68. monarch/monarch_controller +0 -0
  69. monarch/notebook.py +761 -0
  70. monarch/opaque_module.py +235 -0
  71. monarch/opaque_object.py +88 -0
  72. monarch/parallel/__init__.py +9 -0
  73. monarch/parallel/pipelining/__init__.py +7 -0
  74. monarch/parallel/pipelining/runtime.py +847 -0
  75. monarch/parallel/pipelining/schedule_ir.py +692 -0
  76. monarch/parallel/pipelining/scheduler.py +249 -0
  77. monarch/pdb_wrapper.py +135 -0
  78. monarch/proc_mesh.py +299 -0
  79. monarch/profiler.py +160 -0
  80. monarch/python_local_mesh.py +107 -0
  81. monarch/random.py +61 -0
  82. monarch/rdma.py +162 -0
  83. monarch/remote_class.py +114 -0
  84. monarch/rust_backend_mesh.py +280 -0
  85. monarch/rust_local_mesh.py +1402 -0
  86. monarch/sim_mesh.py +359 -0
  87. monarch/simulator/__init__.py +7 -0
  88. monarch/simulator/command_history.py +424 -0
  89. monarch/simulator/config.py +21 -0
  90. monarch/simulator/interface.py +59 -0
  91. monarch/simulator/ir.py +770 -0
  92. monarch/simulator/mock_controller.py +214 -0
  93. monarch/simulator/profiling.py +424 -0
  94. monarch/simulator/simulator.py +1052 -0
  95. monarch/simulator/task.py +255 -0
  96. monarch/simulator/tensor.py +373 -0
  97. monarch/simulator/trace.py +395 -0
  98. monarch/simulator/utils.py +41 -0
  99. monarch/simulator/worker.py +389 -0
  100. monarch/telemetry.py +19 -0
  101. monarch/tensor_worker_main.py +260 -0
  102. monarch/tensorboard.py +84 -0
  103. monarch/timer/__init__.py +21 -0
  104. monarch/timer/example_monarch.py +78 -0
  105. monarch/timer/example_spmd.py +55 -0
  106. monarch/timer/execution_timer.py +199 -0
  107. monarch/timer/execution_timer_test.py +131 -0
  108. monarch/tools/__init__.py +7 -0
  109. monarch/tools/cli.py +167 -0
  110. monarch/tools/commands.py +251 -0
  111. monarch/tools/components/__init__.py +7 -0
  112. monarch/tools/components/hyperactor.py +58 -0
  113. monarch/tools/config/__init__.py +20 -0
  114. monarch/tools/config/defaults.py +54 -0
  115. monarch/tools/mesh_spec.py +165 -0
  116. monarch/tools/network.py +69 -0
  117. monarch/worker/__init__.py +7 -0
  118. monarch/worker/_testing_function.py +481 -0
  119. monarch/worker/compiled_block.py +270 -0
  120. monarch/worker/debugger.py +125 -0
  121. monarch/worker/lines.py +47 -0
  122. monarch/worker/monitor.py +53 -0
  123. monarch/worker/worker.py +1191 -0
  124. monarch/world_mesh.py +34 -0
  125. monarch_supervisor/__init__.py +1044 -0
  126. monarch_supervisor/_testing.py +44 -0
  127. monarch_supervisor/function_call.py +30 -0
  128. monarch_supervisor/host.py +386 -0
  129. monarch_supervisor/launchers.py +145 -0
  130. monarch_supervisor/log_pstree.py +48 -0
  131. monarch_supervisor/logging.py +103 -0
  132. monarch_supervisor/python_executable.py +42 -0
  133. tests/__init__.py +0 -0
  134. tests/dispatch_bench.py +124 -0
  135. tests/dispatch_bench_helper.py +25 -0
  136. tests/error_test_binary.py +180 -0
  137. tests/simulator/__init__.py +0 -0
  138. tests/simulator/test_profiling.py +136 -0
  139. tests/simulator/test_simulator.py +411 -0
  140. tests/simulator/test_task.py +64 -0
  141. tests/simulator/test_worker.py +102 -0
  142. tests/sleep_binary.py +35 -0
  143. tests/test_actor_error.py +240 -0
  144. tests/test_alloc.py +25 -0
  145. tests/test_allocator.py +365 -0
  146. tests/test_coalescing.py +492 -0
  147. tests/test_controller.py +845 -0
  148. tests/test_device_mesh.py +132 -0
  149. tests/test_fault_tolerance.py +398 -0
  150. tests/test_future.py +94 -0
  151. tests/test_grad_generator.py +121 -0
  152. tests/test_mock_cuda.py +74 -0
  153. tests/test_pdb_actor.py +110 -0
  154. tests/test_python_actors.py +736 -0
  155. tests/test_remote_functions.py +1271 -0
  156. tests/test_rust_backend.py +217 -0
  157. tests/test_signal_safe_block_on.py +103 -0
  158. tests/test_sim_backend.py +54 -0
  159. tests/test_tensor_engine.py +52 -0
  160. torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
  161. torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
  162. torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
  163. torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
  164. torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
  165. torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
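Of the 165 added files, three test modules are reproduced in full below: tests/test_device_mesh.py, tests/test_fault_tolerance.py, and tests/test_future.py. As a side note, the full file listing above can be reproduced locally from the wheel itself, since wheels are ordinary zip archives; a short sketch, assuming the wheel has been downloaded under its canonical filename:

    from zipfile import ZipFile

    # Wheels are zip files; namelist() returns the archived paths,
    # sorted here to match the listing above.
    with ZipFile("torchmonarch_nightly-2025.6.27-cp312-cp312-manylinux2014_x86_64.whl") as whl:
        for name in sorted(whl.namelist()):
            print(name)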
tests/test_device_mesh.py ADDED
@@ -0,0 +1,132 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ import pytest
+
+ from monarch import DeviceMesh, NDSlice
+ from monarch.common.client import Client
+ from monarch.simulator.mock_controller import MockController
+
+
+ class TestDeviceMesh:
+     def test_mesh_index(self) -> None:
+         fake_processes = NDSlice(offset=0, sizes=[2, 3, 4], strides=[12, 4, 1])
+         ctrl = MockController(1, False)
+         client = Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)
+         dm = DeviceMesh(client, fake_processes, ("a", "b", "c"))
+         assert 0 == dm(a=0, b=0, c=0).processes[0]
+         x = dm(a=0, b=0)
+         assert x.processes[:] == fake_processes[0:4]
+         assert x.names == ("c",)
+         assert x.processes.sizes[0] == 4
+         x = dm(c=slice(None, None, 2))
+         assert x.processes[:] == fake_processes[::2]
+         x = dm(b=2, c=3)
+         assert x.processes[:] == (11, 23)
+         client.shutdown()
+
+     def test_mesh_reshape(self) -> None:
+         fake_processes = NDSlice(offset=0, sizes=[60, 24], strides=[24, 1])
+         ctrl = MockController(1, False)
+         client = Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)
+         dm = DeviceMesh(client, fake_processes, ("host", "gpu"))
+         dm2 = dm.split(host=("dp", "pp"), gpu=("tp",), pp=4)
+         assert dm2.names == ("dp", "pp", "tp")
+         assert dm2.processes.sizes == [15, 4, 24]
+         assert dm2.processes.strides == [4 * 24, 24, 1]
+
+         dm3 = dm.rename(host="dp", gpu="tp")
+         assert dm3.names == ("dp", "tp")
+         assert dm.processes.strides == dm3.processes.strides
+         dm4 = dm.split(host=("dp", "pp"), gpu=("tp",), dp=4)
+         assert dm4.processes.sizes == [4, 15, 24]
+         dm5 = dm.split(host=("dp", "pp"), dp=60)
+         assert dm5.processes.sizes == [60, 1, 24]
+         dm6 = dm.split(host=("dp", "pp"), pp=60)
+         assert dm6.processes.sizes == [1, 60, 24]
+
+         with pytest.raises(ValueError, match="Cannot infer size"):
+             dm2 = dm.split(host=("dp", "pp"))
+
+         with pytest.raises(ValueError, match="unused size constraints"):
+             dm2 = dm.split(host=("dp", "pp"), pp=4, ddp=3)
+
+         dm2 = dm.rename(host="dp")
+         assert dm2.names == ("dp", "gpu")
+
+         with pytest.raises(ValueError, match="Duplicate dimension name"):
+             dm2 = dm.split(host=("dp", "pp"), gpu=("pp",), dp=3)
+
+         with pytest.raises(ValueError, match="evenly divided"):
+             dm2 = dm.split(host=("dp", "pp"), dp=7)
+
+         client.shutdown()
+
+     def test_flatten(self) -> None:
+         fake_processes = NDSlice(offset=0, sizes=[60, 24], strides=[24, 1])
+         ctrl = MockController(1, False)
+         client = Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)
+         dm = DeviceMesh(client, fake_processes, ("host", "gpu"))
+         dm2 = dm.flatten("gpu")
+         assert dm2.names == ("gpu",)
+         assert dm2.processes.sizes == [60 * 24]
+         assert dm2.processes.strides == [1]
+         client.shutdown()
+
+         good_cases = [
+             NDSlice(offset=0, sizes=[100], strides=[1]),
+             NDSlice(offset=100, sizes=[8, 4, 2], strides=[8, 2, 1]),
+             NDSlice(offset=1, sizes=[4, 2], strides=[2, 1]),
+         ]
+         for slice in good_cases:
+             ctrl = MockController(1, False)
+             client = Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)
+             dm = DeviceMesh(
+                 client, slice, tuple(f"dim{d}" for d in range(len(slice.sizes)))
+             )
+             dm2 = dm.flatten("outer")
+             assert dm2.names == ("outer",)
+             assert dm2.processes.strides == [1]
+             assert list(slice) == list(dm2.processes)
+             client.shutdown()
+
+         # Test some bad ones (sparse slices).
+         bad_cases = [
+             NDSlice(offset=0, sizes=[100], strides=[2]),
+             NDSlice(offset=0, sizes=[64, 32], strides=[64, 1]),
+         ]
+         for slice in bad_cases:
+             with pytest.raises(ValueError, match="cannot flatten sparse mesh"):
+                 ctrl = MockController(1, False)
+                 client = Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)
+                 dm = DeviceMesh(
+                     client, slice, tuple(f"dim{d}" for d in range(len(slice.sizes)))
+                 )
+                 dm.flatten("bad_dim")
+             client.shutdown()
+
+     def test_worker_mesh_init(self) -> None:
+         from monarch.worker.worker import DeviceMesh as WorkerDeviceMesh
+
+         processes = NDSlice(offset=0, sizes=[3, 4], strides=[4, 1])
+         wdm = WorkerDeviceMesh(0, ("a", "b"), processes, rank=1)
+         a, b = wdm.dims["a"], wdm.dims["b"]
+         assert b.members == [0, 1, 2, 3]
+         assert b.rank == 1
+
+         assert a.members == [1, 5, 9]
+         assert a.rank == 0
+
+         wdm = WorkerDeviceMesh(0, ("a", "b"), processes, rank=6)
+         a, b = wdm.dims["a"], wdm.dims["b"]
+         assert b.members == [4, 5, 6, 7]
+         assert b.rank == 2
+         assert a.members == [2, 6, 10]
+         assert a.rank == 1
+
+         processes = NDSlice(offset=0, sizes=[3, 4, 2], strides=[8, 2, 1])
+         wdm = WorkerDeviceMesh(0, ("a", "b", "c"), processes, rank=10)
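The assertions above all reduce to strided-rank arithmetic: a process rank is the NDSlice offset plus the dot product of its per-dimension coordinates with the strides. As a rough orientation, here is a plain-Python sketch of that arithmetic; coords_of and dim_members are illustrative helpers written for this page, not monarch APIs:

    def coords_of(rank: int, offset: int, sizes: list[int], strides: list[int]) -> list[int]:
        # Invert rank = offset + sum(coords[i] * strides[i]), assuming
        # row-major strides in decreasing order, as in these tests.
        rem = rank - offset
        coords = []
        for size, stride in zip(sizes, strides):
            c, rem = divmod(rem, stride)
            assert 0 <= c < size
            coords.append(c)
        return coords

    def dim_members(rank: int, dim: int, offset: int, sizes: list[int], strides: list[int]) -> list[int]:
        # All ranks reachable from `rank` by varying only dimension `dim`.
        coords = coords_of(rank, offset, sizes, strides)
        base = offset + sum(c * s for i, (c, s) in enumerate(zip(coords, strides)) if i != dim)
        return [base + k * strides[dim] for k in range(sizes[dim])]

    # Reproduces test_worker_mesh_init for NDSlice(offset=0, sizes=[3, 4], strides=[4, 1]), rank=6:
    assert coords_of(6, 0, [3, 4], [4, 1]) == [1, 2]            # a.rank == 1, b.rank == 2
    assert dim_members(6, 0, 0, [3, 4], [4, 1]) == [2, 6, 10]   # dim "a"
    assert dim_members(6, 1, 0, [3, 4], [4, 1]) == [4, 5, 6, 7] # dim "b"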
tests/test_fault_tolerance.py ADDED
@@ -0,0 +1,398 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import random
+ import time
+ from typing import List, Optional
+
+ import pytest
+ import torch
+
+ try:
+     from later.unittest import TestCase
+ except ModuleNotFoundError:
+     from unittest import TestCase
+
+ from monarch import fetch_shard, no_mesh, remote
+ from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus
+ from monarch.common.invocation import DeviceException, RemoteException
+ from monarch.rust_backend_mesh import MeshWorld, PoolDeviceMeshProvider
+ from monarch.rust_local_mesh import (
+     Bootstrap,
+     local_mesh_provider,
+     local_meshes_and_bootstraps,
+     LoggingLocation,
+     SocketType,
+     SupervisionParams,
+ )
+
+
+ def _do_bogus_tensor_work(
+     x: torch.Tensor, y: torch.Tensor, fail_rank: Optional[int] = None
+ ) -> torch.Tensor:
+     return x + y  # real function actually does x @ y
+
+
+ do_bogus_tensor_work = remote(
+     "monarch.worker._testing_function.do_bogus_tensor_work",
+     propagate=_do_bogus_tensor_work,
+ )
+
+
+ def mesh_provider(
+     meshes: int = 2,
+     hosts_per_mesh: int = 1,
+     gpus_per_host: int = 1,
+     # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type.
+ ) -> tuple[PoolDeviceMeshProvider, Bootstrap]:
+     return local_mesh_provider(
+         meshes=meshes,
+         hosts_per_mesh=hosts_per_mesh,
+         gpus_per_host=gpus_per_host,
+         socket_type=SocketType.UNIX,
+         logging_location=LoggingLocation.DEFAULT,
+         supervision_params=SupervisionParams(
+             update_timeout_in_sec=10,  # Fail fast
+             query_interval_in_sec=1,
+             update_interval_in_sec=1,
+         ),
+         auto_epoch=True,
+     )
+
+
+ def local_meshes(
+     meshes: int = 2,
+     hosts_per_mesh: int = 1,
+     gpus_per_host: int = 1,
+ ) -> tuple[list[DeviceMesh], Bootstrap]:
+     return local_meshes_and_bootstraps(
+         meshes=meshes,
+         hosts_per_mesh=hosts_per_mesh,
+         gpus_per_host=gpus_per_host,
+         socket_type=SocketType.UNIX,
+         logging_location=LoggingLocation.DEFAULT,
+         supervision_params=SupervisionParams(
+             update_timeout_in_sec=10,  # Fail fast
+             query_interval_in_sec=1,
+             update_interval_in_sec=1,
+         ),
+         auto_epoch=True,
+     )
+
+
+ # Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
+ # out is not counted as a failure, so we set a more restrictive timeout to
+ # ensure we see a hard failure in CI.
+ # The timeout is set to 250s as the failover takes longer than other tests.
+ @pytest.mark.timeout(250)
+ class TestFaultTolerance(TestCase):
+     def test_mesh_provider(self) -> None:
+         # Create multiple meshes using mesh provider
+         replicas = 4
+         provider, bootstrap = mesh_provider(meshes=replicas)
+         meshes: list[DeviceMesh] = []
+         while len(meshes) < replicas:
+             dm = provider.new_mesh()
+             meshes.append(dm)
+
+         statuses = provider._root_client.world_status()
+         for _, status in statuses.items():
+             assert (
+                 DeviceMeshStatus(status) != DeviceMeshStatus.UNHEALTHY
+             ), f"unexpected unhealthy mesh; world status: {statuses}"
+
+         # Check that all meshes are initially live
+         for mesh in meshes:
+             with mesh.activate():
+                 t = torch.ones(1)
+                 local_t = fetch_shard(t).result()
+                 assert torch.equal(local_t, torch.ones(1))
+
+         # Simulate a failure by killing one of the processes
+         bootstrap.processes[-1].kill()
+
+         # Find unhealthy mesh
+         # Mix user and device errors
+         unhealthy_meshes = []
+         for mesh in meshes:
+             with mesh.activate():
+                 # Send a call to trigger a failure
+                 x = torch.rand(3, 4)
+                 y = torch.rand(3, 4)
+                 z = do_bogus_tensor_work(x, y)
+                 try:
+                     _ = fetch_shard(z).result()
+                 except RemoteException:
+                     pass
+                 except DeviceException as e:
+                     # Device error
+                     unhealthy_meshes.append(mesh)
+                     mesh.exit(e)
+
+         self.assertEqual(len(unhealthy_meshes), 1)
+
+         # World status will transition to unhealthy
+         has_unhealth = False
+         unhealthy_statuses = []
+         while not has_unhealth:
+             statuses = provider._root_client.world_status()
+             for _, status in statuses.items():
+                 if DeviceMeshStatus(status) == DeviceMeshStatus.UNHEALTHY:
+                     has_unhealth = True
+                     unhealthy_statuses = statuses
+                     break
+             time.sleep(1)
+
+         # Unhealthy worlds will be evicted
+         has_unhealth = True
+         healthy_statuses = []
+         while has_unhealth:
+             has_unhealth = False
+             statuses = provider._root_client.world_status()
+             healthy_statuses = statuses
+             for _, status in statuses.items():
+                 if DeviceMeshStatus(status) == DeviceMeshStatus.UNHEALTHY:
+                     has_unhealth = True
+                     break
+             time.sleep(1)
+
+         # A worker world will be evicted
+         self.assertEqual(len(healthy_statuses), len(unhealthy_statuses) - 1)
+
+     def test_worker_supervision_failure(self) -> None:
+         meshes, bootstrap = local_meshes(meshes=1)
+         assert len(meshes) == 1
+         mesh = meshes[0]
+
+         # Check the mesh initially functional
+         with mesh.activate():
+             t = torch.ones(1)
+             local_t = fetch_shard(t).result()
+             assert torch.equal(local_t, torch.ones(1))
+
+         # Simulate a failure by killing one of the processes
+         bootstrap.processes[-1].kill()
+
+         # A device error will be raised
+         with mesh.activate():
+             t = torch.ones(1)
+             with self.assertRaisesRegex(DeviceException, r"crashed"):
+                 local_t = fetch_shard(t).result()
+
+     def test_multi_mesh_failure_isolation(self) -> None:
+         replicas = 4
+         provider, bootstrap = mesh_provider(meshes=replicas)
+         meshes: list[DeviceMesh] = []
+         while len(meshes) < replicas:
+             dm = provider.new_mesh()
+             meshes.append(dm)
+
+         # Check the meshes initially functional
+         for mesh in meshes:
+             with mesh.activate():
+                 t = torch.ones(1)
+                 local_t = fetch_shard(t).result()
+                 assert torch.equal(local_t, torch.ones(1))
+
+         initial_size = len(provider._root_client.world_status())
+
+         # Simulate a failure by killing one of the processes
+         bootstrap.processes[-1].kill()
+
+         # Mix user and device errors
+         healthy_meshes = []
+         unhealthy_meshes = []
+         for mesh in meshes:
+             with mesh.activate():
+                 # Send a call to trigger a failure
+                 x = torch.rand(3, 4)
+                 y = torch.rand(3, 4)
+                 z = do_bogus_tensor_work(x, y)
+                 try:
+                     _ = fetch_shard(z).result()
+                 except RemoteException:
+                     # User error
+                     fetch_shard(x).result()
+                     healthy_meshes.append(mesh)
+                 except DeviceException as e:
+                     # Device error
+                     unhealthy_meshes.append(mesh)
+                     mesh.exit(e)
+
+         self.assertEqual(len(healthy_meshes), replicas - 1)
+         self.assertEqual(len(unhealthy_meshes), 1)
+
+         while True:
+             size = len(provider._root_client.world_status())
+             if size == initial_size - 2:
+                 break
+
+         # The healthy meshes should still be functional
+         for mesh in healthy_meshes:
+             with mesh.activate():
+                 t = torch.ones(1)
+                 local_t = fetch_shard(t).result()
+                 assert torch.equal(local_t, torch.ones(1))
+
+     def test_out_of_order_receive(self) -> None:
+         meshes, _ = local_meshes(meshes=8)
+
+         # Check the meshes initially functional
+         ts = []
+         for i, mesh in enumerate(meshes):
+             with mesh.activate():
+                 t = torch.ones(i + 1)
+                 ts.append(t)
+
+         # Shuffle the meshes to make sure the client is able to dispatch results in order
+         indices = list(range(len(meshes)))
+         shuffled_meshes = list(zip(indices, meshes, ts))
+         random.shuffle(shuffled_meshes)
+         for i, mesh, t in shuffled_meshes:
+             with mesh.activate():
+                 local_t = fetch_shard(t).result()
+                 assert torch.equal(local_t, torch.ones(i + 1))
+
+     def test_mesh_shrink_and_grow(self) -> None:
+         # Create multiple meshes using mesh provider
+         replicas = 4
+         provider, bootstrap = mesh_provider(meshes=replicas)
+         meshes: list[DeviceMesh] = []
+         while len(meshes) < replicas:
+             dm = provider.new_mesh()
+             meshes.append(dm)
+
+         worlds = len(provider._root_client.world_status())
+         assigned_meshes = provider._mesh_map
+         assert len(assigned_meshes) == replicas
+
+         # Happy path
+         for i, mesh in enumerate(meshes):
+             with mesh.activate():
+                 t = torch.ones(i + 1)
+                 local_t = fetch_shard(t).result()
+                 assert torch.equal(local_t, torch.ones(i + 1))
+
+         # Kill a worker
+         mesh_to_kill: MeshWorld = list(bootstrap.mesh_worlds.keys())[1]
+         procs = bootstrap.mesh_worlds[mesh_to_kill]
+         assert len(procs) == 2
+         procs[-1].kill()
+
+         # The killed mesh will become unhealthy
+         healthy_meshes = []
+         unhealthy_meshes = []
+         for i, mesh in enumerate(meshes):
+             with mesh.activate():
+                 try:
+                     t = torch.ones(i + 1)
+                     local_t = fetch_shard(t).result()
+                     with no_mesh.activate():
+                         assert torch.equal(local_t, torch.ones(i + 1))
+                     healthy_meshes.append(mesh)
+                 except DeviceException as e:
+                     unhealthy_meshes.append(mesh)
+                     mesh.exit(e)
+         assert len(healthy_meshes) == replicas - 1
+         assert len(unhealthy_meshes) == 1
+
+         # Restart the mesh
+         for proc in procs:
+             proc.kill()
+
+         # We should be able to acquire a new mesh without waiting for the old mesh to be evicted
+         (worker_world, controller_id) = mesh_to_kill
+         bootstrap.launch_mesh(controller_id=controller_id, worker_world=worker_world)
+
+         dm = provider.new_mesh()
+         healthy_meshes.append(dm)
+
+         # We could have 4 or 5 meshes depending on if the unhealthy mesh is evicted
+         assigned_meshes = provider._mesh_map
+         assert len(assigned_meshes) >= replicas
+
+         # We are happy again
+         assert len(healthy_meshes) == replicas
+         for i, mesh in enumerate(healthy_meshes):
+             with mesh.activate():
+                 t = torch.ones(i + 1)
+                 local_t = fetch_shard(t).result()
+                 assert torch.equal(local_t, torch.ones(i + 1))
+
+         # Old world should be evicted and new world should be spawned. So we ended up with the same number of worlds.
+         while len((provider._root_client.world_status())) != worlds:
+             # We expect to evict both controller and worker worlds from the same mesh.
+             time.sleep(1)
+
+         # Eventually, we only have 4 healthy meshes
+         assigned_meshes = provider._mesh_map
+         while len(assigned_meshes) != replicas:
+             with self.assertRaisesRegex(
+                 TimeoutError, r"Could not find a healthy world"
+             ):
+                 _ = provider.new_mesh(timeout_in_sec=1)
+             assigned_meshes = provider._mesh_map
+             time.sleep(1)
+
+     def test_kill_controller(self) -> None:
+         # Create multiple meshes using mesh provider
+         replicas = 2
+         provider, bootstrap = mesh_provider(meshes=replicas)
+         meshes: list[DeviceMesh] = []
+         while len(meshes) < replicas:
+             dm = provider.new_mesh()
+             meshes.append(dm)
+
+         # Happy path
+         for i, mesh in enumerate(meshes):
+             with mesh.activate():
+                 t = torch.ones(i + 1)
+                 local_t = fetch_shard(t).result()
+                 assert torch.equal(local_t, torch.ones(i + 1))
+
+         # Kill a controller
+         mesh_to_kill: MeshWorld = list(bootstrap.mesh_worlds.keys())[1]
+         procs = bootstrap.mesh_worlds[mesh_to_kill]
+         assert len(procs) == 2
+         procs[0].kill()
+
+         # We should be able to detect the failure
+         healthy_meshes = []
+         detected_failure = False
+         for i, mesh in enumerate(meshes):
+             with mesh.activate():
+                 try:
+                     t = torch.ones(i + 1)
+                     local_t = fetch_shard(t).result()
+                     with no_mesh.activate():
+                         assert torch.equal(local_t, torch.ones(i + 1))
+                     healthy_meshes.append(mesh)
+                 except DeviceException:
+                     detected_failure = True
+         assert len(healthy_meshes) == replicas - 1
+         assert detected_failure
+
+     def test_late_client_attaching(self) -> None:
+         provider, _ = mesh_provider(meshes=1)
+
+         # Wait for the meshes to be healthy
+         healthy_meshes = 0
+         while healthy_meshes < 2:
+             healthy_meshes = 0
+             statuses = provider._root_client.world_status()
+             for _, status in statuses.items():
+                 if DeviceMeshStatus(status) == DeviceMeshStatus.LIVE:
+                     healthy_meshes += 1
+             time.sleep(1)
+
+         # Sleep long enough to allow those "hidden messages" to be sent
+         time.sleep(15)
+
+         # Those "hidden messages" should not cause trouble before a client is ready
+         mesh = provider.new_mesh()
+         with mesh.activate():
+             t = torch.ones(1)
+             fetch_shard(t).result()
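A pattern worth noting in test_fault_tolerance.py: world_status() is polled in open-ended while loops, with the class-level @pytest.mark.timeout(250) as the only backstop. Outside a pytest harness an explicit deadline is safer. A minimal bounded-poll sketch under that assumption; poll_until is a hypothetical helper written for this note, not part of monarch:

    import time
    from typing import Callable, Optional, TypeVar

    T = TypeVar("T")

    def poll_until(check: Callable[[], Optional[T]], timeout_s: float, interval_s: float = 1.0) -> T:
        # Run `check` until it returns a non-None value or the deadline passes.
        deadline = time.monotonic() + timeout_s
        while True:
            result = check()
            if result is not None:
                return result
            if time.monotonic() >= deadline:
                raise TimeoutError(f"condition not met within {timeout_s}s")
            time.sleep(interval_s)

For example, the eviction wait in test_mesh_provider could become poll_until(lambda: statuses if no status is UNHEALTHY else None, timeout_s=60), turning a potential hang into a clear timeout failure.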
tests/test_future.py ADDED
@@ -0,0 +1,94 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # pyre-strict
+ import time
+ from typing import Callable
+
+ import pytest
+ from monarch import Future, RemoteException
+ from monarch._rust_bindings.monarch_hyperactor.proc import (  # @manual=//monarch/monarch_extension:monarch_extension
+     ActorId,
+ )
+ from monarch.common import future
+ from monarch.common.client import Client
+
+
+ class TestFuture:
+     def test_future(self, monkeypatch: pytest.MonkeyPatch) -> None:
+         the_time: int = 0
+         the_messages: list[tuple[int | float, Callable[[], None]]] = []
+
+         class MockClient(Client):
+             def __init__(self):
+                 pass
+
+             def handle_next_message(self, timeout) -> bool:
+                 nonlocal the_time
+                 if not the_messages:
+                     return False
+                 time, action = the_messages[0]
+                 if timeout is None or time <= the_time + timeout:
+                     the_time = time
+                     action()
+                     the_messages.pop(0)
+                     return True
+                 else:
+                     the_time += timeout
+                     return False
+
+             def _request_status(self):
+                 pass
+
+         client: Client = MockClient()
+
+         def mock_time() -> int:
+             return the_time
+
+         monkeypatch.setattr(time, "time", mock_time)
+         f = Future(client)
+         the_messages = [(1, lambda: f._set_result(4))]
+         assert not f.done()
+         with pytest.raises(TimeoutError):
+             f.result(timeout=0.5)
+         assert 4 == f.result(timeout=1)
+         assert f.exception() is None
+         assert f.done()
+         f = Future(client)
+         the_messages = [(1, lambda: None), (2, lambda: f._set_result(3))]
+         the_time = 0
+         assert 3 == f.result()
+         f = Future(client)
+         re = RemoteException(
+             0, Exception(), None, [], [], ActorId.from_string("unknown[0].unknown[0]")
+         )
+
+         the_messages = [(1, lambda: None), (2, lambda: f._set_result(re))]
+         the_time = 0
+         assert f.exception() is not None
+
+         f = Future(client)
+         the_messages = [(0, lambda: None), (0.2, lambda: f._set_result(7))]
+         the_time = 0
+         assert 7 == f.result(timeout=0.3)
+
+         fs = []
+
+         def setup() -> None:
+             nonlocal fs, the_messages
+             fs = [Future(client) for _ in range(4)]
+
+             # To avoid closure binding gotcha.
+             def set_at_time(f: Future, time: int) -> tuple[int, Callable[[], None]]:
+                 return (time, lambda: f._set_result(time))
+
+             the_messages = [set_at_time(f, time) for time, f in enumerate(fs)]
+
+         setup()
+         assert {f.result() for f in future.stream(fs, timeout=2)} == {0, 1, 2}
+
+         setup()
+         assert {f.result() for f in future.stream(fs, timeout=3)} == {0, 1, 2, 3}
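test_future.py never sleeps: MockClient.handle_next_message advances the_time as it delivers queued messages, and time.time is monkeypatched to read that counter, so timeout behavior is tested deterministically. The same fake-clock technique in isolation, as a generic sketch independent of monarch:

    import time

    import pytest

    class FakeClock:
        # Deterministic stand-in for time.time, advanced manually by the test.
        def __init__(self) -> None:
            self.now: float = 0.0

        def time(self) -> float:
            return self.now

        def advance(self, dt: float) -> None:
            self.now += dt

    def test_deadline_with_fake_clock(monkeypatch: pytest.MonkeyPatch) -> None:
        clock = FakeClock()
        monkeypatch.setattr(time, "time", clock.time)
        deadline = time.time() + 0.5   # deadline computed against the fake clock
        assert time.time() < deadline  # not yet expired
        clock.advance(1.0)             # jump past the deadline instantly
        assert time.time() >= deadline # expired without any real waiting

Any code that computes deadlines via time.time() can be exercised this way, which is why the monarch test can assert on 0.2s and 0.5s timeouts without flakiness.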