torchmonarch-nightly 2025.6.27__cp312-cp312-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +58 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +230 -0
- monarch/actor_mesh.py +761 -0
- monarch/allocator.py +220 -0
- monarch/bootstrap_main.py +59 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +68 -0
- monarch/cached_remote_function.py +257 -0
- monarch/code_sync.py +10 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +690 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +417 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +573 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +297 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +229 -0
- monarch/common/stream.py +114 -0
- monarch/common/tensor.py +814 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/debugger.py +379 -0
- monarch/fetch.py +55 -0
- monarch/future.py +76 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/mesh_controller.py +271 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/pdb_wrapper.py +135 -0
- monarch/proc_mesh.py +299 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +162 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +359 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/telemetry.py +19 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +251 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +58 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +165 -0
- monarch/tools/network.py +69 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +180 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +240 -0
- tests/test_alloc.py +25 -0
- tests/test_allocator.py +365 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +845 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +736 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +217 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- tests/test_tensor_engine.py +52 -0
- torchmonarch_nightly-2025.6.27.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.27.dist-info/RECORD +165 -0
- torchmonarch_nightly-2025.6.27.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.27.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.27.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.27.dist-info/top_level.txt +3 -0
tests/test_controller.py
ADDED
@@ -0,0 +1,845 @@
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2
|
+
# All rights reserved.
|
3
|
+
#
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
6
|
+
|
7
|
+
# pyre-unsafe
|
8
|
+
import itertools
|
9
|
+
import logging
|
10
|
+
import re
|
11
|
+
import sys
|
12
|
+
import traceback
|
13
|
+
from contextlib import contextmanager
|
14
|
+
|
15
|
+
import monarch
|
16
|
+
import monarch.random
|
17
|
+
import pytest
|
18
|
+
|
19
|
+
import torch
|
20
|
+
|
21
|
+
from monarch import (
|
22
|
+
DeviceMesh,
|
23
|
+
fetch_shard,
|
24
|
+
grad_function,
|
25
|
+
grad_generator,
|
26
|
+
no_mesh,
|
27
|
+
Stream,
|
28
|
+
Tensor,
|
29
|
+
)
|
30
|
+
|
31
|
+
from monarch._testing import BackendType, TestingContext
|
32
|
+
from monarch.common.controller_api import LogMessage
|
33
|
+
from monarch.common.invocation import DeviceException
|
34
|
+
from monarch.common.remote import remote
|
35
|
+
from monarch.common.tree import flattener
|
36
|
+
from monarch.rust_local_mesh import (
|
37
|
+
ControllerParams,
|
38
|
+
local_mesh,
|
39
|
+
local_meshes_and_bootstraps,
|
40
|
+
LoggingLocation,
|
41
|
+
SocketType,
|
42
|
+
SupervisionParams,
|
43
|
+
)
|
44
|
+
from monarch_supervisor.logging import fix_exception_lines
|
45
|
+
|
46
|
+
|
47
|
+
def custom_excepthook(exc_type, exc_value, exc_traceback):
    """Report an uncaught exception on stderr.

    The standard formatted traceback is passed through ``fix_exception_lines``
    and the resulting entries are printed joined with newlines.
    """
    raw_lines = traceback.format_exception(exc_type, exc_value, exc_traceback)
    print("\n".join(fix_exception_lines(raw_lines)), file=sys.stderr)


# Route uncaught exceptions in this module through the hook defined above.
sys.excepthook = custom_excepthook
|
55
|
+
|
56
|
+
|
57
|
+
@pytest.fixture(scope="module", autouse=True)
def testing_context():
    """Enter one ``TestingContext`` for the whole module.

    The entered context is published as the module-global ``local``, which
    ``TestController.local_device_mesh`` uses to build meshes.
    """
    global local
    with TestingContext() as ctx:
        local = ctx
        yield
|
62
|
+
|
63
|
+
|
64
|
+
@contextmanager
def local_rust_device_mesh(
    hosts,
    gpu_per_host,
    activate: bool = True,
    controller_params: ControllerParams | None = None,
):
    """Provide a local device mesh for the duration of a ``with`` block.

    The mesh is created by ``local_mesh`` with UNIX sockets and file logging.
    If ``activate`` is true, the mesh is also activated while the caller's
    block runs.

    On normal completion ``exit()`` is called on the mesh. If an exception
    escapes (from the caller's block or from ``exit()``), ``client._shutdown``
    is set to True and the exception is re-raised.
    """
    with local_mesh(
        hosts=hosts,
        gpus_per_host=gpu_per_host,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.FILE,
        controller_params=controller_params,
    ) as mesh:
        try:
            if not activate:
                yield mesh
            else:
                with mesh.activate():
                    yield mesh
            mesh.exit()
        except Exception:
            # NOTE(review): presumably marks the client as already shut down so
            # the surrounding local_mesh teardown does not attempt a clean
            # shutdown of a failed client — confirm.
            mesh.client._shutdown = True
            raise
|
88
|
+
|
89
|
+
|
90
|
+
# Remote handle for the worker-side "__test_panic" function
# (invoked by test_panicking_worker).
panic = remote(
    "__test_panic",
    propagate="inspect",
)

# Remote handle for ``time.sleep`` on the workers
# (invoked by the timeout tests).
remote_sleep = remote(
    "time.sleep",
    propagate="inspect",
)
|
93
|
+
|
94
|
+
|
95
|
+
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Not enough GPUs, this test requires at least 2 GPUs",
)
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS, "mesh"])
# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
# out is not counted as a failure, so we set a more restrictive timeout to
# ensure we see a hard failure in CI.
@pytest.mark.timeout(120)
class TestController:
    """Controller tests, run once per backend (PY, RS and "mesh").

    Each test receives ``backend_type`` from the class-level parametrization
    and creates its meshes through :meth:`local_device_mesh`.
    """

    @classmethod
    def local_device_mesh(
        cls,
        N,
        gpu_per_host,
        backend_type,
        activate=True,
    ):
        """Create a local device mesh through the module-wide ``TestingContext``.

        ``N``, ``gpu_per_host`` and ``activate`` are forwarded positionally.
        ``backend_type`` is forwarded as ``backend=str(backend_type)``.
        """
        return local.local_device_mesh(
            N,
            gpu_per_host,
            activate,
            backend=str(backend_type),
        )

    def test_errors(self, backend_type):
        """Combining tensors from the wrong place raises ``TypeError``.

        Covers a local (non-mesh) tensor, a tensor from another mesh, and a
        tensor used on the wrong stream.
        """
        t = torch.rand(3, 4)
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            y = torch.rand(3, 4)
            with pytest.raises(TypeError, match="LOCAL_TENSOR"):
                t.add(y)
            with pytest.raises(TypeError, match="WRONG_MESH"):
                sm = device_mesh.slice(host=0)
                with sm.activate():
                    x = torch.rand(3, 4)
                    x.add(y)

            other = Stream("other")
            t = torch.rand(10).cuda()
            with pytest.raises(TypeError, match="WRONG_STREAM"):
                with other.activate():
                    t = t.reduce("host", "sum")

    def test_sub_mesh(self, backend_type):
        """Tensors can be created on each single-host slice of the mesh."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            h0 = device_mesh.slice(host=0)
            h1 = device_mesh.slice(host=1)
            with h0.activate():
                _ = torch.rand(3, 4)
            with h1.activate():
                _ = torch.rand(3, 4)
                # Runs on a different mesh but should still work

    def test_fetch_result_device(self, backend_type):
        """Fetched shards are CPU tensors, whatever device they lived on."""
        with self.local_device_mesh(2, 2, backend_type):
            on_gpu = torch.ones(2, 3, device="cuda")
            on_cpu = torch.ones(2, 3, device="cpu")

            on_gpu_local = fetch_shard(on_gpu).result()
            on_cpu_local = fetch_shard(on_cpu).result()

        assert on_gpu_local.device == torch.device("cpu")
        assert on_cpu_local.device == torch.device("cpu")

    def test_dim1_mesh(self, backend_type):
        """A mesh split so that one dimension has size 1 still works."""
        with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
            mesh3d = device_mesh.split(host=("oh", "ih"), ih=1)
            with mesh3d.activate():
                x = torch.ones(3, 4)
                local_x = fetch_shard(x).result()

        assert torch.equal(local_x, torch.ones(3, 4))

    def test_sub_mesh_use_only_one(self, backend_type):
        """Using only one slice of an inactive mesh works end to end."""
        with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
            h0 = device_mesh.slice(host=0)

            with h0.activate():
                x = torch.ones(3, 4)
                local_x = fetch_shard(x)

            local_x = local_x.result(timeout=20)
            assert torch.equal(local_x, torch.ones(3, 4))

    def test_sub_mesh_process_grop(self, backend_type):
        """Two process groups requested on the same slice compare unequal."""
        with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
            h0 = device_mesh.slice(host=0)
            pg0 = h0.process_group(("gpu",))
            pg1 = h0.process_group(("gpu",))
            # Is there a way to functionally test that these two PG's aren't
            # the same in the backend?
            assert pg0 != pg1

    def test_reduce(self, backend_type):
        """Check sum/stack reductions, scatter variants and multi-dim reductions.

        Each rank holds ``24 * host + 12 * gpu + arange(12).reshape(3, 4)``.
        The expected results are recomputed locally from that formula.
        """
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            x = (
                12 * 2 * device_mesh.rank("host")
                + 12 * device_mesh.rank("gpu")
                + torch.arange(12, device="cuda").reshape(3, 4)
            )
            y = x.reduce("gpu", "sum")
            g = x.reduce("gpu", "stack")
            with pytest.raises(TypeError, match="When scattering"):
                x = x.reduce("gpu", "sum", scatter=True)
            x = x.reshape(2, 6)
            atoa = x.reduce("gpu", "stack", scatter=True)
            rs = x.reduce("gpu", "sum", scatter=True)
            rad = x.reduce((), "sum")
            rade = x.reduce(("gpu", "host"), "sum")
            with pytest.raises(
                ValueError, match="is not valid for multiple dimensions"
            ):
                x.reduce((), "sum", scatter=True)
            with pytest.raises(
                ValueError, match="is not valid for multiple dimensions"
            ):
                x.reduce((), "stack")
            with pytest.raises(
                ValueError, match="is not valid for multiple dimensions"
            ):
                x.reduce((), "stack", scatter=True)
            y_local = fetch_shard(y).result()
            g_local = fetch_shard(g).result()
            # TODO compute the expected values to compare against in the below section
            _ = fetch_shard(atoa).result()
            _ = fetch_shard(rs).result()
            rad_local = fetch_shard(rad).result()
            rade_local = fetch_shard(rade).result()

        # Per-rank inputs recomputed locally, keyed by (host, gpu).
        xs = {
            (h, g): 12 * 2 * h + 12 * g + torch.arange(12, device="cpu").reshape(3, 4)
            for h, g in itertools.product(range(2), range(2))
        }

        # fetch_shard defaults to the (0, 0) shard, so the "gpu" reductions
        # only see host 0.
        y_expected = xs[(0, 0)] + xs[(0, 1)]
        g_expected = torch.stack([xs[(0, 0)], xs[(0, 1)]])
        assert torch.equal(y_local, y_expected)
        assert torch.equal(g_local, g_expected)
        rad_expected = (xs[(0, 0)] + xs[(0, 1)] + xs[(1, 0)] + xs[(1, 1)]).reshape(
            rad_local.shape
        )
        assert torch.equal(rad_local, rad_expected)
        assert torch.equal(rade_local, rad_expected)

        # test is run on 4 GPUs, can't have mesh with 3 non-trivial dimensions
        with self.local_device_mesh(2, 2, backend_type, activate=False) as mesh2d:
            device_mesh = mesh2d.split(host=("oh", "ih"), ih=1)
            with device_mesh.activate():
                x = (
                    12 * 2 * device_mesh.rank("oh")
                    + 12 * device_mesh.rank("gpu")
                    + torch.arange(12, device="cuda").reshape(3, 4)
                )
                y = x.reduce(("ih", "gpu"), "sum")
                y_local = fetch_shard(y).result()
                z = x.reduce(("oh", "gpu"), "sum")
                z_local = fetch_shard(z).result()

        assert torch.equal(y_local, y_expected)
        assert torch.equal(z_local, rad_expected.reshape(z_local.shape))

    def test_reduce_out(self, backend_type):
        """``reduce(..., out=...)`` validates the out shape and writes into it."""
        with self.local_device_mesh(2, 2, backend_type):
            inp = torch.rand(2, 4, device="cuda")
            out_incorrect = torch.rand(2, 4, device="cuda")
            out = torch.rand(4, device="cuda")

            with pytest.raises(
                ValueError, match="Reduce expects the shape to be torch.Size."
            ):
                _ = inp.reduce("host", reduction="sum", scatter=True, out=out_incorrect)

            # Without ``out`` a fresh tensor is produced.
            reduce_out = inp.reduce("host", reduction="sum", scatter=True)
            local_out = fetch_shard(out).result()
            local_reduce_out = fetch_shard(reduce_out).result()
            assert out._fake is not reduce_out._fake
            with no_mesh.activate():
                assert not torch.equal(local_out, local_reduce_out)

            # With ``out`` the result aliases ``out``.
            reduce_out = inp.reduce("host", reduction="sum", scatter=True, out=out)
            local_out = fetch_shard(out).result()
            local_reduce_out = fetch_shard(reduce_out).result()
            assert out._fake is reduce_out._fake
            with no_mesh.activate():
                assert torch.equal(local_out, local_reduce_out)

    def test_fetch(self, backend_type):
        """``fetch_shard`` with an explicit coordinate returns that rank's values."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            h = device_mesh.rank("host")
            g = device_mesh.rank("gpu")
            for hi in range(2):
                for gi in range(2):
                    x, y = fetch_shard((h, g), {"host": hi, "gpu": gi}).result()
                    with no_mesh.activate():
                        assert (hi, gi) == (x.item(), y.item())

    def test_mutate(self, backend_type):
        """In-place mutation is rejected while an immutable borrow is live."""
        with self.local_device_mesh(2, 2, backend_type):
            x = torch.rand(3, 4).cuda()
            x.abs_()
            s = Stream("other")
            b, drop = s.borrow(x)
            with pytest.raises(TypeError, match="would be mutated"):
                x.abs_()
            with s.activate():
                _ = b.add(b)
            drop.drop()
            x.abs_()
            b, drop = s.borrow(x, mutable=True)
            with s.activate():
                b.abs_()
            drop.drop()
            # del b
            x.abs_()

    def test_movement(self, backend_type):
        """Tensors can be moved between slices with ``to_mesh``."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            sm0 = device_mesh.slice(host=0)
            sm1 = device_mesh.slice(host=1)

            with sm0.activate():
                x = torch.rand(3, 4, device="cuda")
                _ = x.to_mesh(sm1)

            a = torch.rand(3, 4, device="cuda")

            b = a.slice_mesh(host=0)
            _ = b.to_mesh(sm0)
            _ = b.to_mesh(sm1)

    def test_broadcast_one(self, backend_type):
        """Moving a tensor from a 1-D slice to the full mesh broadcasts it."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            for dim in ("host", "gpu"):
                subset = device_mesh.slice(**{dim: 1})
                with subset.activate():
                    x = torch.rand(3, device="cuda")
                    y = x.to_mesh(device_mesh)

                with subset.activate():
                    a = monarch.inspect(x)
                with device_mesh.activate():
                    b = monarch.inspect(y.reduce(dim, reduction="stack"))
                with no_mesh.activate():
                    assert torch.allclose(a.expand(2, -1), b, rtol=0, atol=0)

    def test_broadcast_two(self, backend_type):
        """Moving a tensor from a single rank to the full mesh broadcasts it."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            subset = device_mesh.slice(host=1, gpu=1)
            with subset.activate():
                x = torch.rand(3, device="cuda")
                y = x.to_mesh(device_mesh)

            with subset.activate():
                a = monarch.inspect(x)
            with device_mesh.activate():
                b = monarch.inspect(
                    y.reduce("host", reduction="stack").reduce("gpu", reduction="stack")
                )
            with no_mesh.activate():
                assert torch.allclose(a.expand(2, 2, -1), b, rtol=0, atol=0)

    def test_autograd(self, backend_type):
        """Gradients produced by ``grad_generator`` can be fetched from the mesh."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            x = torch.rand(3, 4, requires_grad=True)
            y = torch.rand(4, 3, requires_grad=True)
            z = torch.rand(3, requires_grad=True)

            foo = (x @ y + z).sum()
            with no_mesh.activate():
                # check backward restores forward mesh
                for t in grad_generator(foo, [z, y, x]):
                    with device_mesh.activate():
                        fetch_shard(t).result()

    def test_mesh_semantics(self, backend_type):
        """Ops on tensors created on different slices can be issued."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            host0 = device_mesh.slice(host=0)
            host1 = device_mesh.slice(host=1)
            with host0.activate():
                x = torch.randn(5)
                y = x * 5
            with host1.activate():
                a = torch.randn(5)
                b = a * 5
                x.cos()
                y.cos()
                b.cos()

    def test_autograd_multi_mesh(self, backend_type):
        """Backward flows across meshes through a custom ``grad_function``."""

        @grad_function
        def to_mesh(x: Tensor, mesh: DeviceMesh):
            omesh = x.mesh

            def backward(grad_x: Tensor):
                print(grad_x.mesh, omesh)
                # Send the gradient back to the tensor's original mesh.
                return grad_x.to_mesh(omesh), None

            return x.to_mesh(mesh), backward

        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            host0 = device_mesh.slice(host=0)
            host1 = device_mesh.slice(host=1)
            with host0.activate():
                x = torch.rand(3, 4, requires_grad=True, device="cuda")
                y = torch.rand(4, 3, requires_grad=True, device="cuda")
                t = x @ y
                t = to_mesh(t, host1)
            with host1.activate():
                z = torch.rand(3, requires_grad=True, device="cuda")
                foo = (t + z).sum()

                for r in grad_generator(foo, [z, y, x]):
                    with r.mesh.activate():
                        print(fetch_shard(r).result())

    def test_many(self, backend_type):
        """A long chain of 2048 dependent ops can be issued and fetched."""
        with self.local_device_mesh(2, 2, backend_type):
            x = torch.rand(3, 4)
            for _ in range(2048):
                x = x + torch.rand(3, 4)
            fetch_shard(x).result()

    def test_flattener(self, backend_type):
        """``flattener`` built from one pytree flattens same-shaped trees."""
        e = (8, 9, {"a": 10, "b": 11})
        flatten = flattener(e)
        e2 = (0, 1, {"a": 2, "b": 3})
        assert [0, 1, 2, 3] == flatten(e2)

    def test_torch_tensor(self, backend_type):
        """``torch.tensor`` on CPU and CUDA round-trips through ``fetch_shard``."""
        with self.local_device_mesh(2, 2, backend_type):
            t = torch.tensor([1, 2, 4])
            tc = torch.tensor([1, 2, 4], device="cuda")
            t2 = fetch_shard(t).result()
            tc2 = fetch_shard(tc).result()
        assert torch.allclose(t2, torch.tensor([1, 2, 4]))
        assert torch.allclose(tc2, torch.tensor([1, 2, 4], device="cpu"))

    def test_to_mesh_aliasing(self, backend_type):
        """Borrowing interacts correctly with ``to_mesh`` across pipeline stages."""
        with self.local_device_mesh(2, 2, backend_type) as mesh:
            p2p_stream = Stream("p2p_stream")

            ppmesh = mesh.flatten("all").split(
                all=(
                    "dp",
                    "pp",
                ),
                pp=2,
            )
            pp_meshes = [ppmesh.slice(pp=i) for i in range(2)]

            with ppmesh.activate():
                with pp_meshes[0].activate():
                    x = torch.randn((3, 3), device="cuda")
                    x_borrowed_tensor, x_borrow = p2p_stream.borrow(x)
                    with p2p_stream.activate():
                        y_on_mesh_1_p2p_stream = x_borrowed_tensor.to_mesh(pp_meshes[1])

                with pp_meshes[1].activate():
                    x_borrow.drop()
                    y_on_mesh_1_default_stream, y_borrow = (
                        monarch.get_active_stream().borrow(y_on_mesh_1_p2p_stream)
                    )

                    monarch.inspect(y_on_mesh_1_default_stream)
                    y_borrow.drop()

    def test_to_mesh_cow(self, backend_type):
        """Mutating the source after ``to_mesh`` does not affect the copy."""
        with self.local_device_mesh(2, 2, backend_type) as mesh:
            t = torch.zeros((), device="cuda")
            t2 = t.to_mesh(mesh)
            t.add_(1)
            assert monarch.inspect(t2).item() == 0
            assert monarch.inspect(t).item() == 1

    def test_to_mesh_stream(self, backend_type):
        """``to_mesh(..., stream=...)`` places the result on the given stream."""
        other = monarch.Stream("other")
        with self.local_device_mesh(2, 2, backend_type) as mesh:
            m0 = mesh.slice(host=0)
            m1 = mesh.slice(host=1)
            with m0.activate():
                t2 = torch.rand(3, 4, device="cuda").to_mesh(m1, stream=other)
            with m1.activate(), other.activate():
                # assert doesn't fail
                monarch.inspect(t2 + t2)

    def test_dropped_trace(self, backend_type):
        """Using a dropped borrow raises an error that names where it was dropped."""
        with self.local_device_mesh(2, 2, backend_type) as _:
            x = torch.rand(4, 4).cuda()
            s = Stream("other")
            b, drop = s.borrow(x)
            drop.drop()
            with s.activate():
                pattern = re.compile(
                    ".*tensor.*is dropped at.*.*drop.drop().*", flags=re.DOTALL
                )
                with pytest.raises(TypeError, match=pattern):
                    _ = b.abs()

    def test_sub_mesh_reduce(self, backend_type):
        """A reduction restricted to the host=1 slice sums only that slice's ranks."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            host1 = device_mesh.slice(host=1)
            with host1.activate():
                myrank = (
                    (device_mesh.rank("host") + 1) * 2 + device_mesh.rank("gpu") + 1
                )
                x = torch.ones((3, 4), device="cuda") * myrank
                reduce = x.reduce("gpu", "sum")
                local_reduce = fetch_shard(reduce).result()

        # host=1 ranks hold 5 and 6.
        assert torch.equal(local_reduce, torch.ones(3, 4) * 11)

    def test_size(self, backend_type):
        """``size`` over several dimensions returns the product of their sizes."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            assert device_mesh.size(["host", "gpu"]) == 4

    def test_random_state(self, backend_type):
        """RNG state get/set/new_state reproduce random draws exactly."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            monarch.random.make_deterministic()
            for device in ("cpu", "cuda"):
                a = monarch.random.get_state()
                monarch.inspect(a)
                first = torch.rand(1, device=device)
                monarch.random.set_state(a)
                second = torch.rand(1, device=device)
                f, s = monarch.inspect((first, second))
                with no_mesh.activate():
                    # Restoring state ``a`` must reproduce the identical draw.
                    # Compare with zero tolerance: the previous rtol=1 accepted
                    # almost any pair of values in [0, 1) and so checked nothing.
                    assert torch.allclose(f, s, atol=0, rtol=0)
                seed = device_mesh.rank(["host", "gpu"]) + 4
                s2 = monarch.random.new_state(seed)
                s3 = monarch.random.new_state(seed)
                monarch.random.set_state(s2)
                r0 = torch.rand(1, device=device)
                if device == "cuda":
                    for d in ("host", "gpu"):
                        r0 = r0.reduce(d, reduction="stack")
                monarch.random.set_state(s3)
                r1 = torch.rand(1, device=device)
                if device == "cuda":
                    for d in ("host", "gpu"):
                        r1 = r1.reduce(d, reduction="stack")
                r2, r3 = monarch.inspect((r0, r1))
                monarch.random.set_state(a)
                with no_mesh.activate():
                    assert torch.allclose(r2, r3, atol=0, rtol=0)
                    assert not torch.allclose(r2, f, atol=0, rtol=0)

    def test_torch_op_with_optional_tensors(self, backend_type):
        """
        This test ensures that for torch ops like LayerNorm, which allow for
        optional tensor arguments, the controller serializes monarch tensors
        correctly as Refs instead of as IValues.
        """
        with self.local_device_mesh(2, 2, backend_type):
            x = torch.rand(3, 4, device="cuda")
            # When bias and elementwise_affine are true, extra tensors are passed through optional
            # fields inside LayerNorm. When they are false, None is passed to the same optional fields.
            # If we are handling serialization correctly, there shouldn't be a crash in either case.
            layer_norm_with_vals = torch.nn.LayerNorm(
                4, device="cuda", bias=True, elementwise_affine=True
            )
            layer_norm_with_none = torch.nn.LayerNorm(
                4, device="cuda", bias=False, elementwise_affine=False
            )
            monarch.inspect(layer_norm_with_vals(x))
            monarch.inspect(layer_norm_with_none(x))

    def test_reduce_pytree(self, backend_type):
        """``monarch.reduce_`` and ``monarch.reduce`` accept a dict of tensors."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            a = device_mesh.rank(("gpu", "host")) + torch.zeros((1,), device="cuda")
            b = device_mesh.rank(("gpu", "host")) + torch.ones((1,), device="cuda")

            tensor_dict = {"a": a, "b": b}
            _ = monarch.reduce_(tensor_dict, dims=("gpu", "host"), reduction="sum")
            reduced_tensor_dict = monarch.reduce(
                tensor_dict, dims=("gpu", "host"), reduction="sum"
            )
            reduced_a = fetch_shard(reduced_tensor_dict["a"]).result()
            reduced_b = fetch_shard(reduced_tensor_dict["b"]).result()
            reduced_a_inplace = fetch_shard(tensor_dict["a"]).result()
            reduced_b_inplace = fetch_shard(tensor_dict["b"]).result()

        # In-place: sum of ranks 0..3 (+0 / +1). Out-of-place: 4x the in-place result.
        assert torch.equal(reduced_a_inplace, torch.tensor([6.0]))
        assert torch.equal(reduced_b_inplace, torch.tensor([10.0]))
        assert torch.equal(reduced_a, torch.tensor([24.0]))
        assert torch.equal(reduced_b, torch.tensor([40.0]))

    def test_to_mesh_pytree(self, backend_type):
        """``monarch.to_mesh`` moves every tensor of a dict to another mesh."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            host0 = device_mesh.slice(host=0)
            host1 = device_mesh.slice(host=1)

            with host0.activate():
                a = torch.zeros((1,), device="cuda")
                b = torch.ones((1,), device="cuda")
                tensor_dict = {"a": a, "b": b}
                moved_tensor_dict = monarch.to_mesh(tensor_dict, host1)

            with host1.activate():
                moved_tensor_dict["a"].add_(1)
                moved_tensor_dict["b"].add_(1)

                moved_tensor_a = monarch.inspect(moved_tensor_dict["a"])
                moved_tensor_b = monarch.inspect(moved_tensor_dict["b"])

            host0.exit()
            host1.exit()

        assert torch.equal(moved_tensor_a, torch.tensor([1.0]))
        assert torch.equal(moved_tensor_b, torch.tensor([2.0]))

    def test_hanging_error(self, backend_type):
        """A failing remote call surfaces as an error at client shutdown."""
        if backend_type != "mesh":
            pytest.skip("only relevant for mesh backend")
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            remote(lambda: torch.rand(3) + torch.rand(4), propagate=lambda: None)()

            with pytest.raises(Exception, match="The size of tensor"):
                device_mesh.client.shutdown()

    def test_slice_mesh_pytree(self, backend_type):
        """``monarch.slice_mesh`` slices every tensor of a dict."""
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            # NOTE: ("host") and ("gpu") below are plain strings, not 1-tuples.
            a = device_mesh.rank(("host")) + torch.zeros((1,), device="cuda")
            b = device_mesh.rank(("host")) + torch.ones((1,), device="cuda")

            tensor_dict = {"a": a, "b": b}
            host0_slices = monarch.slice_mesh(tensor_dict, host=0)
            host1_slices = monarch.slice_mesh(tensor_dict, host=1)

            host0 = device_mesh.slice(host=0)
            host1 = device_mesh.slice(host=1)

            host0_tensors = monarch.to_mesh(host0_slices, host0)
            host1_tensors = monarch.to_mesh(host1_slices, host1)

            with host0.activate():
                _ = monarch.reduce_(host0_tensors, dims=("gpu"), reduction="sum")
                host0_a = fetch_shard(host0_tensors["a"]).result()
                host0_b = fetch_shard(host0_tensors["b"]).result()

            with host1.activate():
                _ = monarch.reduce_(host1_tensors, dims=("gpu"), reduction="sum")
                host1_a = fetch_shard(host1_tensors["a"]).result()
                host1_b = fetch_shard(host1_tensors["b"]).result()

            host0.exit()
            host1.exit()

        assert torch.equal(host0_a, torch.tensor([0.0]))
        assert torch.equal(host0_b, torch.tensor([2.0]))
        assert torch.equal(host1_a, torch.tensor([2.0]))
        assert torch.equal(host1_b, torch.tensor([4.0]))
646
|
+
|
647
|
+
|
648
|
+
def test_panicking_worker():
    """A worker panic must surface on the client as a ``DeviceException``.

    ``panic()`` runs on a 1x1 Rust-backed mesh. The subsequent
    ``fetch_shard(...).result()`` forces a round trip so the failure can be
    reported back. The exception message is expected to match
    "__test_panic called".
    """
    expected_failure = pytest.raises(DeviceException, match="__test_panic called")
    with expected_failure, local_rust_device_mesh(1, 1) as _mesh:
        panic()
        # Induce a sync so the panic can propagate back to the client.
        fetch_shard(torch.ones(2, 3)).result()
|
654
|
+
|
655
|
+
|
656
|
+
def test_timeout_warning(caplog):
    """Operations outliving the controller timeout must log a warning.

    Uses a 1x2 Rust-backed mesh with ``ControllerParams(1, timeout, 100,
    False)``. The last flag is False here, whereas test_timeout_failure
    passes True (presumably this selects warn-only behaviour; confirm).

    The test proceeds in two phases:
      1. Three no-op nodes must produce no message within ``3 * timeout``
         seconds.
      2. After a ``2 * timeout``-second remote sleep plus three more nodes,
         the next handled message must log a warning naming both ranks, in
         either order.
    """
    timeout = 3
    params = ControllerParams(1, timeout, 100, False)
    with local_rust_device_mesh(1, 2, True, controller_params=params) as dm:
        client = dm.client

        for _ in range(3):
            client.new_node([], [])

        # The no-op nodes alone must not produce any message.
        assert client.inner.next_message(timeout * 3) is None

        remote_sleep(timeout * 2)
        for _ in range(3):
            client.new_node([], [])

        # The rank order in the warning is not deterministic.
        expected_warnings = (
            f"ranks 1, 0 have operations that have not completed after {timeout} seconds",
            f"ranks 0, 1 have operations that have not completed after {timeout} seconds",
        )
        with caplog.at_level(logging.WARNING, logger=client.__module__):
            has_message = client.handle_next_message(120)
            assert has_message
            assert any(text in caplog.text for text in expected_warnings)
|
683
|
+
|
684
|
+
|
685
|
+
def test_timeout_failure():
    """With the last ``ControllerParams`` flag set, a timeout becomes an error.

    Follows the same steps as test_timeout_warning, but on a 1x1 mesh with
    ``ControllerParams(1, timeout, 100, True)``. After a ``2 * timeout``-second
    remote sleep, the client must receive an error response. That response
    must carry a ``DeviceException`` that:
      - mentions the crashed worker;
      - has the timeout text in its first frame name.

    Raises:
        AssertionError: if no error response arrives within the polling
            window, or the error does not have the expected content.
    """
    timeout = 3
    with local_rust_device_mesh(
        1,
        1,
        True,
        controller_params=ControllerParams(1, timeout, 100, True),
    ) as dm:
        for _ in range(3):
            dm.client.new_node([], [])

        # The no-op nodes alone must not produce any message.
        assert dm.client.inner.next_message(timeout * 3) is None

        remote_sleep(timeout * 2)
        for _ in range(3):
            dm.client.new_node([], [])

        # Poll in 1-second steps. Stop at the first error response; the
        # generous budget avoids flaking if detection is slightly late.
        for _ in range(timeout * 5):
            result = dm.client.inner.next_message(1)
            if result is None or isinstance(result, LogMessage):
                continue
            if result.error is None:
                continue
            error = result.error
            assert isinstance(error, DeviceException)
            assert "crashed" in error.message
            assert "mesh_0_worker[0].worker[0]" in error.message
            assert (
                f"ranks 0 have operations that have not completed after {timeout} seconds"
                in error.frames[0].name
            )
            break
        else:
            # for/else: reached only when the loop never hit ``break``, i.e. no
            # error arrived. Fail explicitly instead of passing vacuously.
            raise AssertionError("Should have failed with a controller timeout error")
|
717
|
+
|
718
|
+
|
719
|
+
def test_supervision_heartbeat_failure():
    """Killing a worker process must surface a supervision "crashed" error.

    Builds one mesh (1 host x 2 GPUs) with short supervision intervals and
    kills ``bootstrap.processes[2]``, the first process after the system and
    controller processes. It then polls the client up to 20 times, 3 seconds
    each, for an error response whose ``DeviceException`` message contains
    "crashed".

    Raises:
        AssertionError: if no such error arrives within the polling window.
    """
    (dms, bootstrap) = local_meshes_and_bootstraps(
        meshes=1,
        hosts_per_mesh=1,
        gpus_per_host=2,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.DEFAULT,
        supervision_params=SupervisionParams(
            # Set a low timeout so heartbeat failure can be detected faster.
            update_timeout_in_sec=10,
            query_interval_in_sec=1,
            update_interval_in_sec=1,
        ),
    )
    assert len(dms) == 1
    dm = dms[0]

    # Kill a process of a worker actor. This should trigger a supervision
    # heartbeat failure event.
    # Index 0 and 1 are system process and controller process respectively.
    process = bootstrap.processes[2]
    process.kill()

    for _ in range(20):
        # Poll the next message in order to get the supervision failure.
        result = dm.client.inner.next_message(3)
        # Log messages can arrive on this channel too (test_timeout_failure
        # skips them); skip them here rather than reading ``.error`` off them.
        if result is None or isinstance(result, LogMessage):
            continue
        if result.error is None:
            continue
        assert isinstance(result.error, DeviceException)
        assert "crashed" in result.error.message
        # NOTE(review): this success path returns without dm.exit(); confirm
        # whether explicit cleanup is needed after a killed worker.
        return

    dm.exit()
    raise AssertionError("Should have failed supervision health check")
|
755
|
+
|
756
|
+
|
757
|
+
def test_supervision_system_actor_down():
    """Killing the system process must make client polling fail.

    After ``bootstrap.processes[0]`` (the system process) is killed, polling
    the client with ``next_message(3)`` is expected to raise a RuntimeError
    mentioning "actor has been stopped" within 20 attempts.

    Raises:
        AssertionError: if no such RuntimeError is raised while polling.
    """
    supervision = SupervisionParams(
        # Set a low timeout so heartbeat failure can be detected faster.
        update_timeout_in_sec=10,
        query_interval_in_sec=1,
        update_interval_in_sec=1,
    )
    dms, bootstrap = local_meshes_and_bootstraps(
        meshes=1,
        hosts_per_mesh=1,
        gpus_per_host=2,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.DEFAULT,
        supervision_params=supervision,
    )
    assert len(dms) == 1
    [dm] = dms

    # Index 0 is the system process.
    bootstrap.processes[0].kill()

    try:
        for _attempt in range(20):
            # Poll so the supervision failure has a chance to surface.
            dm.client.inner.next_message(3)
    except RuntimeError as err:
        assert "actor has been stopped" in str(err)
        return

    dm.exit()
    raise AssertionError("Should have failed supervision health check")
|
788
|
+
|
789
|
+
|
790
|
+
def test_supervision_controller_actor_down():
    """Killing the controller process must surface a controller-crash error.

    Builds one mesh (1 host x 2 GPUs) with short supervision intervals and
    kills ``bootstrap.processes[1]`` (the controller process). It then polls
    the client up to 20 times, 3 seconds each, for an error response whose
    message names "mesh_0_controller[0].controller[0] crashed".

    Raises:
        AssertionError: if no such error arrives within the polling window.
    """
    (dms, bootstrap) = local_meshes_and_bootstraps(
        meshes=1,
        hosts_per_mesh=1,
        gpus_per_host=2,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.DEFAULT,
        supervision_params=SupervisionParams(
            # Set a low timeout so heartbeat failure can be detected faster.
            update_timeout_in_sec=10,
            query_interval_in_sec=1,
            update_interval_in_sec=1,
        ),
    )
    assert len(dms) == 1
    dm = dms[0]

    # Index 1 is controller process
    process = bootstrap.processes[1]
    process.kill()

    for _ in range(20):
        # Poll the next message in order to get the supervision failure.
        result = dm.client.inner.next_message(3)
        # Log messages can arrive on this channel too (test_timeout_failure
        # skips them); skip them here rather than reading ``.error`` off them.
        if result is None or isinstance(result, LogMessage):
            continue
        if result.error is None:
            continue
        assert isinstance(result.error, DeviceException)
        assert "mesh_0_controller[0].controller[0] crashed" in result.error.message
        # NOTE(review): this success path returns without dm.exit(); confirm
        # whether explicit cleanup is needed after the controller is killed.
        return

    dm.exit()
    raise AssertionError("Should have failed supervision health check")
|
824
|
+
|
825
|
+
|
826
|
+
def a_function_called_by_a_live_function(x):
    """Return ``x`` multiplied by two, computed as ``2 * x``."""
    factor = 2
    return factor * x
|
828
|
+
|
829
|
+
|
830
|
+
def a_live_function_call_by_a_live_function(x):
    """Return ``x`` multiplied by three, computed as ``3 * x``."""
    factor = 3
    return factor * x
|
832
|
+
|
833
|
+
|
834
|
+
def test_delete_refs():
    """Deleted refs stay pending for their mesh until ``flush_deletes``.

    Two refs deleted on a 2x2 local mesh must both appear in the client's
    pending-delete list for that mesh. Flushing must empty that list.
    """
    with local_mesh(
        hosts=2,
        gpus_per_host=2,
        socket_type=SocketType.UNIX,
        logging_location=LoggingLocation.DEFAULT,
    ) as dm:
        client = dm.client
        for ref in (1, 2):
            client.delete_ref(dm, ref)
        assert len(client._pending_del[dm]) == 2

        client.flush_deletes()
        assert len(client._pending_del[dm]) == 0
|