torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
@@ -0,0 +1,132 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-strict
|
8
|
+
import pytest
|
9
|
+
|
10
|
+
from monarch import DeviceMesh, NDSlice
|
11
|
+
from monarch.common.client import Client
|
12
|
+
from monarch.simulator.mock_controller import MockController
|
13
|
+
|
14
|
+
|
15
|
+
class TestDeviceMesh:
    """Tests for client-side ``DeviceMesh`` indexing/reshaping/flattening and
    for per-dimension rank bookkeeping in the worker-side ``DeviceMesh``.

    Client-side meshes are backed by a ``MockController``; every ``Client``
    created here is shut down in a ``finally`` so a failing assertion (or an
    expected exception) does not leak it.
    """

    @staticmethod
    def _new_client() -> Client:
        """Return a ``Client`` wired to a fresh ``MockController(1, False)``."""
        ctrl = MockController(1, False)
        return Client(ctrl, ctrl.world_size, ctrl.gpu_per_host)

    @staticmethod
    def _dim_names(nd_slice: NDSlice) -> tuple[str, ...]:
        """Name the dimensions of *nd_slice* ``dim0``, ``dim1``, ... in order."""
        return tuple(f"dim{d}" for d in range(len(nd_slice.sizes)))

    def test_mesh_index(self) -> None:
        fake_processes = NDSlice(offset=0, sizes=[2, 3, 4], strides=[12, 4, 1])
        client = self._new_client()
        try:
            dm = DeviceMesh(client, fake_processes, ("a", "b", "c"))
            assert 0 == dm(a=0, b=0, c=0).processes[0]

            # Fixing a and b leaves only dimension c.
            x = dm(a=0, b=0)
            assert x.processes[:] == fake_processes[0:4]
            assert x.names == ("c",)
            assert x.processes.sizes[0] == 4

            x = dm(c=slice(None, None, 2))
            assert x.processes[:] == fake_processes[::2]

            x = dm(b=2, c=3)
            assert x.processes[:] == (11, 23)
        finally:
            client.shutdown()

    def test_mesh_reshape(self) -> None:
        fake_processes = NDSlice(offset=0, sizes=[60, 24], strides=[24, 1])
        client = self._new_client()
        try:
            dm = DeviceMesh(client, fake_processes, ("host", "gpu"))

            # Split host (60) into dp x pp with pp given, so dp is inferred.
            dm2 = dm.split(host=("dp", "pp"), gpu=("tp",), pp=4)
            assert dm2.names == ("dp", "pp", "tp")
            assert dm2.processes.sizes == [15, 4, 24]
            assert dm2.processes.strides == [4 * 24, 24, 1]

            # Renaming does not change the layout.
            dm3 = dm.rename(host="dp", gpu="tp")
            assert dm3.names == ("dp", "tp")
            assert dm.processes.strides == dm3.processes.strides

            dm4 = dm.split(host=("dp", "pp"), gpu=("tp",), dp=4)
            assert dm4.processes.sizes == [4, 15, 24]
            dm5 = dm.split(host=("dp", "pp"), dp=60)
            assert dm5.processes.sizes == [60, 1, 24]
            dm6 = dm.split(host=("dp", "pp"), pp=60)
            assert dm6.processes.sizes == [1, 60, 24]

            # Neither dp nor pp constrained.
            with pytest.raises(ValueError, match="Cannot infer size"):
                dm.split(host=("dp", "pp"))

            # "ddp" is not a dimension produced by the split.
            with pytest.raises(ValueError, match="unused size constraints"):
                dm.split(host=("dp", "pp"), pp=4, ddp=3)

            dm2 = dm.rename(host="dp")
            assert dm2.names == ("dp", "gpu")

            with pytest.raises(ValueError, match="Duplicate dimension name"):
                dm.split(host=("dp", "pp"), gpu=("pp",), dp=3)

            # 60 is not a multiple of 7.
            with pytest.raises(ValueError, match="evenly divided"):
                dm.split(host=("dp", "pp"), dp=7)
        finally:
            client.shutdown()

    def test_flatten(self) -> None:
        fake_processes = NDSlice(offset=0, sizes=[60, 24], strides=[24, 1])
        client = self._new_client()
        try:
            dm = DeviceMesh(client, fake_processes, ("host", "gpu"))
            dm2 = dm.flatten("gpu")
            assert dm2.names == ("gpu",)
            assert dm2.processes.sizes == [60 * 24]
            assert dm2.processes.strides == [1]
        finally:
            client.shutdown()

        # Dense slices: flattening yields one unit-stride dimension that lists
        # the same ranks in the same order.
        good_cases = [
            NDSlice(offset=0, sizes=[100], strides=[1]),
            NDSlice(offset=100, sizes=[8, 4, 2], strides=[8, 2, 1]),
            NDSlice(offset=1, sizes=[4, 2], strides=[2, 1]),
        ]
        for nd_slice in good_cases:
            client = self._new_client()
            try:
                dm = DeviceMesh(client, nd_slice, self._dim_names(nd_slice))
                dm2 = dm.flatten("outer")
                assert dm2.names == ("outer",)
                assert dm2.processes.strides == [1]
                assert list(nd_slice) == list(dm2.processes)
            finally:
                client.shutdown()

        # Test some bad ones (sparse slices).
        bad_cases = [
            NDSlice(offset=0, sizes=[100], strides=[2]),
            NDSlice(offset=0, sizes=[64, 32], strides=[64, 1]),
        ]
        for nd_slice in bad_cases:
            client = self._new_client()
            try:
                dm = DeviceMesh(client, nd_slice, self._dim_names(nd_slice))
                # Only flatten() is expected to raise. Keeping it alone in the
                # raises-block means the client is still shut down afterwards.
                with pytest.raises(ValueError, match="cannot flatten sparse mesh"):
                    dm.flatten("bad_dim")
            finally:
                client.shutdown()

    def test_worker_mesh_init(self) -> None:
        from monarch.worker.worker import DeviceMesh as WorkerDeviceMesh

        # For each dimension, `members` lists the ranks reached by varying only
        # that coordinate of `rank`, and `.rank` is `rank`'s index among them.
        processes = NDSlice(offset=0, sizes=[3, 4], strides=[4, 1])
        wdm = WorkerDeviceMesh(0, ("a", "b"), processes, rank=1)
        a, b = wdm.dims["a"], wdm.dims["b"]
        assert b.members == [0, 1, 2, 3]
        assert b.rank == 1

        assert a.members == [1, 5, 9]
        assert a.rank == 0

        wdm = WorkerDeviceMesh(0, ("a", "b"), processes, rank=6)
        a, b = wdm.dims["a"], wdm.dims["b"]
        assert b.members == [4, 5, 6, 7]
        assert b.rank == 2
        assert a.members == [2, 6, 10]
        assert a.rank == 1

        # 3-D: rank 10 = 1*8 + 1*2 + 0*1, i.e. coordinate (a=1, b=1, c=0).
        processes = NDSlice(offset=0, sizes=[3, 4, 2], strides=[8, 2, 1])
        wdm = WorkerDeviceMesh(0, ("a", "b", "c"), processes, rank=10)
        a, b, c = wdm.dims["a"], wdm.dims["b"], wdm.dims["c"]
        assert a.members == [2, 10, 18]
        assert a.rank == 1
        assert b.members == [8, 10, 12, 14]
        assert b.rank == 1
        assert c.members == [10, 11]
        assert c.rank == 0
|
@@ -0,0 +1,398 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
import random
|
8
|
+
import time
|
9
|
+
from typing import List, Optional
|
10
|
+
|
11
|
+
import pytest
|
12
|
+
import torch
|
13
|
+
|
14
|
+
try:
|
15
|
+
from later.unittest import TestCase
|
16
|
+
except ModuleNotFoundError:
|
17
|
+
from unittest import TestCase
|
18
|
+
|
19
|
+
from monarch import fetch_shard, no_mesh, remote
|
20
|
+
from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus
|
21
|
+
from monarch.common.invocation import DeviceException, RemoteException
|
22
|
+
from monarch.rust_backend_mesh import MeshWorld, PoolDeviceMeshProvider
|
23
|
+
from monarch.rust_local_mesh import (
|
24
|
+
Bootstrap,
|
25
|
+
local_mesh_provider,
|
26
|
+
local_meshes_and_bootstraps,
|
27
|
+
LoggingLocation,
|
28
|
+
SocketType,
|
29
|
+
SupervisionParams,
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
def _do_bogus_tensor_work(
    x: torch.Tensor, y: torch.Tensor, fail_rank: Optional[int] = None
) -> torch.Tensor:
    """Client-side propagation stand-in for the remote ``do_bogus_tensor_work``.

    Computes ``x + y`` locally. Per the original inline note, the worker-side
    implementation computes ``x @ y`` instead.
    NOTE(review): for the (3, 4) x (3, 4) inputs used in the tests, ``x @ y``
    would presumably fail on the worker — confirm against
    ``monarch.worker._testing_function``.

    ``fail_rank`` exists only so the signature matches the remote function;
    it is ignored here.
    """
    del fail_rank  # unused locally; kept for signature parity
    summed = x + y
    return summed
|
37
|
+
|
38
|
+
|
39
|
+
# Remote handle: workers execute the function at the dotted path below, while
# `_do_bogus_tensor_work` is used on the client to propagate the result.
do_bogus_tensor_work = remote(
    "monarch.worker._testing_function.do_bogus_tensor_work",
    propagate=_do_bogus_tensor_work,
)
|
43
|
+
|
44
|
+
|
45
|
+
def _fail_fast_supervision_params() -> SupervisionParams:
    """Supervision settings shared by ``mesh_provider`` and ``local_meshes``.

    A 10s update timeout with 1s query/update intervals ("fail fast"). Both
    helpers use this single definition, so the two cannot drift apart.
    """
    return SupervisionParams(
        update_timeout_in_sec=10,  # Fail fast
        query_interval_in_sec=1,
        update_interval_in_sec=1,
    )


def mesh_provider(
    meshes: int = 2,
    hosts_per_mesh: int = 1,
    gpus_per_host: int = 1,
    # pyre-fixme[11]: Annotation `DeviceMeshProvider` is not defined as a type.
) -> tuple[PoolDeviceMeshProvider, Bootstrap]:
    """Launch local meshes behind a mesh provider.

    Args:
        meshes: number of meshes to launch.
        hosts_per_mesh: hosts in each mesh.
        gpus_per_host: GPUs per host.

    Returns:
        ``(provider, bootstrap)``. Meshes are acquired with
        ``provider.new_mesh()``; ``bootstrap`` exposes the launched processes.
    """
    return local_mesh_provider(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.DEFAULT,
        supervision_params=_fail_fast_supervision_params(),
        auto_epoch=True,
    )


def local_meshes(
    meshes: int = 2,
    hosts_per_mesh: int = 1,
    gpus_per_host: int = 1,
) -> tuple[list[DeviceMesh], Bootstrap]:
    """Launch local meshes and return them directly (no provider).

    Args:
        meshes: number of meshes to launch.
        hosts_per_mesh: hosts in each mesh.
        gpus_per_host: GPUs per host.

    Returns:
        ``(device_meshes, bootstrap)``, where ``bootstrap`` exposes the
        launched processes.
    """
    return local_meshes_and_bootstraps(
        meshes=meshes,
        hosts_per_mesh=hosts_per_mesh,
        gpus_per_host=gpus_per_host,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.DEFAULT,
        supervision_params=_fail_fast_supervision_params(),
        auto_epoch=True,
    )
|
84
|
+
|
85
|
+
|
86
|
+
# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
# out is not counted as a failure, so we set a more restrictive timeout to
# ensure we see a hard failure in CI.
# The timeout is set to 250s as the failover takes longer than other tests.
@pytest.mark.timeout(250)
class TestFaultTolerance(TestCase):
    """Fault-tolerance tests over locally launched meshes.

    Each test launches processes via ``mesh_provider`` or ``local_meshes``.
    It then injects a failure by killing processes tracked by the returned
    ``Bootstrap``. Finally it checks how the failure surfaces: as a
    ``DeviceException`` on the affected mesh, and in the root client's
    ``world_status()``.
    """

    @staticmethod
    def _has_unhealthy(statuses) -> bool:
        """Return True if any world in a ``world_status()`` result is UNHEALTHY."""
        return any(
            DeviceMeshStatus(status) == DeviceMeshStatus.UNHEALTHY
            for _, status in statuses.items()
        )

    def test_mesh_provider(self) -> None:
        # Create multiple meshes using mesh provider
        replicas = 4
        provider, bootstrap = mesh_provider(meshes=replicas)
        meshes: list[DeviceMesh] = [provider.new_mesh() for _ in range(replicas)]

        statuses = provider._root_client.world_status()
        assert not self._has_unhealthy(
            statuses
        ), f"unexpected unhealthy mesh; world status: {statuses}"

        # Check that all meshes are initially live
        for mesh in meshes:
            with mesh.activate():
                t = torch.ones(1)
                local_t = fetch_shard(t).result()
                assert torch.equal(local_t, torch.ones(1))

        # Simulate a failure by killing one of the processes
        bootstrap.processes[-1].kill()

        # Send work to every mesh to find the unhealthy one. User errors
        # (RemoteException) are ignored; a DeviceException marks the mesh
        # unhealthy and the mesh is exited.
        unhealthy_meshes = []
        for mesh in meshes:
            with mesh.activate():
                x = torch.rand(3, 4)
                y = torch.rand(3, 4)
                z = do_bogus_tensor_work(x, y)
                try:
                    _ = fetch_shard(z).result()
                except RemoteException:
                    pass
                except DeviceException as e:
                    unhealthy_meshes.append(mesh)
                    mesh.exit(e)

        self.assertEqual(len(unhealthy_meshes), 1)

        # World status will transition to unhealthy
        while True:
            unhealthy_statuses = provider._root_client.world_status()
            if self._has_unhealthy(unhealthy_statuses):
                break
            time.sleep(1)

        # Unhealthy worlds will be evicted
        while True:
            healthy_statuses = provider._root_client.world_status()
            if not self._has_unhealthy(healthy_statuses):
                break
            time.sleep(1)

        # A worker world will be evicted
        self.assertEqual(len(healthy_statuses), len(unhealthy_statuses) - 1)

    def test_worker_supervision_failure(self) -> None:
        meshes, bootstrap = local_meshes(meshes=1)
        assert len(meshes) == 1
        mesh = meshes[0]

        # Check the mesh initially functional
        with mesh.activate():
            t = torch.ones(1)
            local_t = fetch_shard(t).result()
            assert torch.equal(local_t, torch.ones(1))

        # Simulate a failure by killing one of the processes
        bootstrap.processes[-1].kill()

        # A device error will be raised
        with mesh.activate():
            t = torch.ones(1)
            with self.assertRaisesRegex(DeviceException, r"crashed"):
                fetch_shard(t).result()

    def test_multi_mesh_failure_isolation(self) -> None:
        replicas = 4
        provider, bootstrap = mesh_provider(meshes=replicas)
        meshes: list[DeviceMesh] = [provider.new_mesh() for _ in range(replicas)]

        # Check the meshes initially functional
        for mesh in meshes:
            with mesh.activate():
                t = torch.ones(1)
                local_t = fetch_shard(t).result()
                assert torch.equal(local_t, torch.ones(1))

        initial_size = len(provider._root_client.world_status())

        # Simulate a failure by killing one of the processes
        bootstrap.processes[-1].kill()

        # Mix user and device errors
        healthy_meshes = []
        unhealthy_meshes = []
        for mesh in meshes:
            with mesh.activate():
                # Send a call to trigger a failure
                x = torch.rand(3, 4)
                y = torch.rand(3, 4)
                z = do_bogus_tensor_work(x, y)
                try:
                    _ = fetch_shard(z).result()
                except RemoteException:
                    # User error: the mesh must still be usable afterwards.
                    fetch_shard(x).result()
                    healthy_meshes.append(mesh)
                except DeviceException as e:
                    # Device error
                    unhealthy_meshes.append(mesh)
                    mesh.exit(e)

        self.assertEqual(len(healthy_meshes), replicas - 1)
        self.assertEqual(len(unhealthy_meshes), 1)

        # Wait until two worlds have been evicted. Sleep between polls so the
        # root client is not queried in a tight loop, and stop as soon as the
        # size reaches (or drops past) the target instead of waiting only for
        # an exact match.
        expected_size = initial_size - 2
        size = len(provider._root_client.world_status())
        while size > expected_size:
            time.sleep(1)
            size = len(provider._root_client.world_status())
        self.assertEqual(size, expected_size)

        # The healthy meshes should still be functional
        for mesh in healthy_meshes:
            with mesh.activate():
                t = torch.ones(1)
                local_t = fetch_shard(t).result()
                assert torch.equal(local_t, torch.ones(1))

    def test_out_of_order_receive(self) -> None:
        meshes, _ = local_meshes(meshes=8)

        # Check the meshes initially functional
        ts = []
        for i, mesh in enumerate(meshes):
            with mesh.activate():
                t = torch.ones(i + 1)
                ts.append(t)

        # Shuffle the meshes to makes sure the client is able to dispatch results in order
        shuffled_meshes = list(zip(range(len(meshes)), meshes, ts))
        random.shuffle(shuffled_meshes)
        for i, mesh, t in shuffled_meshes:
            with mesh.activate():
                local_t = fetch_shard(t).result()
                assert torch.equal(local_t, torch.ones(i + 1))

    def test_mesh_shrink_and_grow(self) -> None:
        # Create multiple meshes using mesh provider
        replicas = 4
        provider, bootstrap = mesh_provider(meshes=replicas)
        meshes: list[DeviceMesh] = [provider.new_mesh() for _ in range(replicas)]

        worlds = len(provider._root_client.world_status())
        assigned_meshes = provider._mesh_map
        assert len(assigned_meshes) == replicas

        # Happy path
        for i, mesh in enumerate(meshes):
            with mesh.activate():
                t = torch.ones(i + 1)
                local_t = fetch_shard(t).result()
                assert torch.equal(local_t, torch.ones(i + 1))

        # Kill a worker
        mesh_to_kill: MeshWorld = list(bootstrap.mesh_worlds.keys())[1]
        procs = bootstrap.mesh_worlds[mesh_to_kill]
        assert len(procs) == 2
        procs[-1].kill()

        # The killed mesh will become unhealthy
        healthy_meshes = []
        unhealthy_meshes = []
        for i, mesh in enumerate(meshes):
            with mesh.activate():
                try:
                    t = torch.ones(i + 1)
                    local_t = fetch_shard(t).result()
                    with no_mesh.activate():
                        assert torch.equal(local_t, torch.ones(i + 1))
                    healthy_meshes.append(mesh)
                except DeviceException as e:
                    unhealthy_meshes.append(mesh)
                    mesh.exit(e)
        assert len(healthy_meshes) == replicas - 1
        assert len(unhealthy_meshes) == 1

        # Restart the mesh: first kill every remaining process of that mesh.
        for proc in procs:
            proc.kill()

        # We should be able to acquire a new mesh without waiting for the old mesh to be evicted
        (worker_world, controller_id) = mesh_to_kill
        bootstrap.launch_mesh(controller_id=controller_id, worker_world=worker_world)

        dm = provider.new_mesh()
        healthy_meshes.append(dm)

        # We could have 4 or 5 meshes depending on if the unhealthy mesh is evicted
        assigned_meshes = provider._mesh_map
        assert len(assigned_meshes) >= replicas

        # We are happy again
        assert len(healthy_meshes) == replicas
        for i, mesh in enumerate(healthy_meshes):
            with mesh.activate():
                t = torch.ones(i + 1)
                local_t = fetch_shard(t).result()
                assert torch.equal(local_t, torch.ones(i + 1))

        # Old world should be evicted and new world should be spawned. So we ended up with the same number of worlds.
        # We expect to evict both controller and worker worlds from the same mesh.
        while len(provider._root_client.world_status()) != worlds:
            time.sleep(1)

        # Eventually, we only have 4 healthy meshes
        assigned_meshes = provider._mesh_map
        while len(assigned_meshes) != replicas:
            with self.assertRaisesRegex(
                TimeoutError, r"Could not find a healthy world"
            ):
                _ = provider.new_mesh(timeout_in_sec=1)
            assigned_meshes = provider._mesh_map
            time.sleep(1)

    def test_kill_controller(self) -> None:
        # Create multiple meshes using mesh provider
        replicas = 2
        provider, bootstrap = mesh_provider(meshes=replicas)
        meshes: list[DeviceMesh] = [provider.new_mesh() for _ in range(replicas)]

        # Happy path
        for i, mesh in enumerate(meshes):
            with mesh.activate():
                t = torch.ones(i + 1)
                local_t = fetch_shard(t).result()
                assert torch.equal(local_t, torch.ones(i + 1))

        # Kill a controller
        mesh_to_kill: MeshWorld = list(bootstrap.mesh_worlds.keys())[1]
        procs = bootstrap.mesh_worlds[mesh_to_kill]
        assert len(procs) == 2
        procs[0].kill()

        # We should be able to detect the failure
        healthy_meshes = []
        detected_failure = False
        for i, mesh in enumerate(meshes):
            with mesh.activate():
                try:
                    t = torch.ones(i + 1)
                    local_t = fetch_shard(t).result()
                    with no_mesh.activate():
                        assert torch.equal(local_t, torch.ones(i + 1))
                    healthy_meshes.append(mesh)
                except DeviceException:
                    detected_failure = True
        assert len(healthy_meshes) == replicas - 1
        assert detected_failure

    def test_late_client_attaching(self) -> None:
        provider, _ = mesh_provider(meshes=1)

        # Wait until at least two worlds report LIVE.
        healthy_meshes = 0
        while healthy_meshes < 2:
            statuses = provider._root_client.world_status()
            healthy_meshes = sum(
                1
                for _, status in statuses.items()
                if DeviceMeshStatus(status) == DeviceMeshStatus.LIVE
            )
            time.sleep(1)

        # Sleep long enough to allow those "hidden messages" to be sent
        time.sleep(15)

        # Those "hidden messages" should not cause a trouble before a client is ready
        mesh = provider.new_mesh()
        with mesh.activate():
            t = torch.ones(1)
            fetch_shard(t).result()
|
tests/test_future.py
ADDED
@@ -0,0 +1,94 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-strict
|
8
|
+
import time
|
9
|
+
from typing import Callable
|
10
|
+
|
11
|
+
import pytest
|
12
|
+
from monarch import Future, RemoteException
|
13
|
+
from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension
|
14
|
+
ActorId,
|
15
|
+
)
|
16
|
+
from monarch.common import future
|
17
|
+
from monarch.common.client import Client
|
18
|
+
|
19
|
+
|
20
|
+
class TestFuture:
    """Tests for ``monarch.Future`` and ``future.stream`` against a fake clock.

    ``time.time`` is monkeypatched to return ``the_time``. A ``MockClient``
    replays a scripted list of ``(delivery_time, action)`` messages and moves
    the fake clock as each message is delivered or a wait times out.
    """

    def test_future(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Fake wall clock. It is a float because waits (e.g. 0.5) and message
        # timestamps (e.g. 0.2) are fractional.
        the_time: float = 0
        # Pending scripted messages, in delivery order.
        the_messages: list[tuple[int | float, Callable[[], None]]] = []

        class MockClient(Client):
            def __init__(self):
                # Deliberately skip Client.__init__; no controller is needed.
                pass

            def handle_next_message(self, timeout) -> bool:
                """Deliver the head scripted message if it is due within *timeout*.

                With ``timeout=None`` the head message is always delivered.
                Delivering sets the fake clock to the message's time, runs its
                action and returns True. Otherwise the clock advances by
                *timeout* and False is returned. With no messages left, False
                is returned immediately.
                """
                nonlocal the_time
                if not the_messages:
                    return False
                due_at, action = the_messages[0]
                if timeout is None or due_at <= the_time + timeout:
                    the_time = due_at
                    action()
                    the_messages.pop(0)
                    return True
                else:
                    the_time += timeout
                    return False

            def _request_status(self):
                pass

        client: Client = MockClient()

        def mock_time() -> float:
            return the_time

        monkeypatch.setattr(time, "time", mock_time)

        # Result scripted for t=1: a 0.5s wait times out, a following 1s wait gets it.
        f = Future(client)
        the_messages = [(1, lambda: f._set_result(4))]
        assert not f.done()
        with pytest.raises(TimeoutError):
            f.result(timeout=0.5)
        assert 4 == f.result(timeout=1)
        assert f.exception() is None
        assert f.done()

        # Without a timeout, result() keeps consuming messages until the value arrives.
        f = Future(client)
        the_messages = [(1, lambda: None), (2, lambda: f._set_result(3))]
        the_time = 0
        assert 3 == f.result()

        # A RemoteException delivered as the result is reported by exception().
        f = Future(client)
        remote_exc = RemoteException(
            0, Exception(), None, [], [], ActorId.from_string("unknown[0].unknown[0]")
        )

        the_messages = [(1, lambda: None), (2, lambda: f._set_result(remote_exc))]
        the_time = 0
        assert f.exception() is not None

        # Fractional timestamps: the value at t=0.2 arrives within a 0.3s wait.
        f = Future(client)
        the_messages = [(0, lambda: None), (0.2, lambda: f._set_result(7))]
        the_time = 0
        assert 7 == f.result(timeout=0.3)

        fs: list[Future] = []

        def setup() -> None:
            """Create four futures; future ``i`` is resolved with value ``i`` at t=i."""
            nonlocal fs, the_messages
            fs = [Future(client) for _ in range(4)]

            # To avoid closure binding gotcha.
            def set_at_time(
                fut: Future, when: int
            ) -> tuple[int, Callable[[], None]]:
                return (when, lambda: fut._set_result(when))

            the_messages = [set_at_time(fut, when) for when, fut in enumerate(fs)]

        setup()
        assert {f.result() for f in future.stream(fs, timeout=2)} == {0, 1, 2}

        setup()
        assert {f.result() for f in future.stream(fs, timeout=3)} == {0, 1, 2, 3}
|