torchmonarch_nightly-2025.6.4-cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- monarch/__init__.py +189 -0
- monarch/_monarch/__init__.py +5 -0
- monarch/_monarch/hyperactor/__init__.py +74 -0
- monarch/_monarch/selection/__init__.py +13 -0
- monarch/_monarch/worker/__init__.py +0 -0
- monarch/_monarch/worker/debugger.py +117 -0
- monarch/_monarch/worker/logging.py +107 -0
- monarch/_rust_bindings.so +0 -0
- monarch/_testing.py +198 -0
- monarch/actor_mesh.py +692 -0
- monarch/allocator.py +62 -0
- monarch/bootstrap_main.py +75 -0
- monarch/builtins/__init__.py +14 -0
- monarch/builtins/log.py +22 -0
- monarch/builtins/random.py +69 -0
- monarch/cached_remote_function.py +257 -0
- monarch/common/_C.pyi +11 -0
- monarch/common/_C.so +0 -0
- monarch/common/__init__.py +0 -0
- monarch/common/_coalescing.py +308 -0
- monarch/common/_device_utils.py +18 -0
- monarch/common/_tensor_to_table.py +172 -0
- monarch/common/base_tensor.py +28 -0
- monarch/common/borrows.py +143 -0
- monarch/common/client.py +646 -0
- monarch/common/constants.py +10 -0
- monarch/common/context_manager.py +40 -0
- monarch/common/controller_api.py +104 -0
- monarch/common/device_mesh.py +443 -0
- monarch/common/fake.py +55 -0
- monarch/common/function.py +160 -0
- monarch/common/function_caching.py +164 -0
- monarch/common/future.py +168 -0
- monarch/common/invocation.py +125 -0
- monarch/common/mast.py +221 -0
- monarch/common/messages.py +572 -0
- monarch/common/mock_cuda.py +41 -0
- monarch/common/opaque_ref.py +98 -0
- monarch/common/pickle_flatten.py +48 -0
- monarch/common/pipe.py +152 -0
- monarch/common/process_group.py +55 -0
- monarch/common/recording.py +127 -0
- monarch/common/reference.py +33 -0
- monarch/common/remote.py +304 -0
- monarch/common/selection.py +9 -0
- monarch/common/shape.py +204 -0
- monarch/common/stream.py +111 -0
- monarch/common/tensor.py +793 -0
- monarch/common/tensor_factory.py +31 -0
- monarch/common/tree.py +73 -0
- monarch/controller/__init__.py +7 -0
- monarch/controller/backend.py +223 -0
- monarch/controller/controller.py +223 -0
- monarch/controller/debugger.py +47 -0
- monarch/controller/history.py +90 -0
- monarch/controller/rust_backend/__init__.py +7 -0
- monarch/controller/rust_backend/controller.py +245 -0
- monarch/fetch.py +55 -0
- monarch/future.py +25 -0
- monarch/gradient/__init__.py +11 -0
- monarch/gradient/_gradient_generator.pyi +22 -0
- monarch/gradient/_gradient_generator.so +0 -0
- monarch/gradient_generator.py +185 -0
- monarch/memory.py +43 -0
- monarch/monarch_controller +0 -0
- monarch/notebook.py +761 -0
- monarch/opaque_module.py +235 -0
- monarch/opaque_object.py +88 -0
- monarch/parallel/__init__.py +9 -0
- monarch/parallel/pipelining/__init__.py +7 -0
- monarch/parallel/pipelining/runtime.py +847 -0
- monarch/parallel/pipelining/schedule_ir.py +692 -0
- monarch/parallel/pipelining/scheduler.py +249 -0
- monarch/proc_mesh.py +188 -0
- monarch/profiler.py +160 -0
- monarch/python_local_mesh.py +107 -0
- monarch/random.py +61 -0
- monarch/rdma.py +190 -0
- monarch/remote_class.py +114 -0
- monarch/rust_backend_mesh.py +280 -0
- monarch/rust_local_mesh.py +1402 -0
- monarch/sim_mesh.py +357 -0
- monarch/simulator/__init__.py +7 -0
- monarch/simulator/command_history.py +424 -0
- monarch/simulator/config.py +21 -0
- monarch/simulator/interface.py +59 -0
- monarch/simulator/ir.py +770 -0
- monarch/simulator/mock_controller.py +214 -0
- monarch/simulator/profiling.py +424 -0
- monarch/simulator/simulator.py +1052 -0
- monarch/simulator/task.py +255 -0
- monarch/simulator/tensor.py +373 -0
- monarch/simulator/trace.py +395 -0
- monarch/simulator/utils.py +41 -0
- monarch/simulator/worker.py +389 -0
- monarch/tensor_worker_main.py +260 -0
- monarch/tensorboard.py +84 -0
- monarch/timer/__init__.py +21 -0
- monarch/timer/example_monarch.py +78 -0
- monarch/timer/example_spmd.py +55 -0
- monarch/timer/execution_timer.py +199 -0
- monarch/timer/execution_timer_test.py +131 -0
- monarch/tools/__init__.py +7 -0
- monarch/tools/cli.py +167 -0
- monarch/tools/commands.py +189 -0
- monarch/tools/components/__init__.py +7 -0
- monarch/tools/components/hyperactor.py +57 -0
- monarch/tools/config/__init__.py +20 -0
- monarch/tools/config/defaults.py +54 -0
- monarch/tools/mesh_spec.py +121 -0
- monarch/worker/__init__.py +7 -0
- monarch/worker/_testing_function.py +481 -0
- monarch/worker/compiled_block.py +270 -0
- monarch/worker/debugger.py +125 -0
- monarch/worker/lines.py +47 -0
- monarch/worker/monitor.py +53 -0
- monarch/worker/worker.py +1191 -0
- monarch/world_mesh.py +34 -0
- monarch_supervisor/__init__.py +1044 -0
- monarch_supervisor/_testing.py +44 -0
- monarch_supervisor/function_call.py +30 -0
- monarch_supervisor/host.py +386 -0
- monarch_supervisor/launchers.py +145 -0
- monarch_supervisor/log_pstree.py +48 -0
- monarch_supervisor/logging.py +103 -0
- monarch_supervisor/python_executable.py +42 -0
- tests/__init__.py +0 -0
- tests/dispatch_bench.py +124 -0
- tests/dispatch_bench_helper.py +25 -0
- tests/error_test_binary.py +139 -0
- tests/simulator/__init__.py +0 -0
- tests/simulator/test_profiling.py +136 -0
- tests/simulator/test_simulator.py +411 -0
- tests/simulator/test_task.py +64 -0
- tests/simulator/test_worker.py +102 -0
- tests/sleep_binary.py +35 -0
- tests/test_actor_error.py +112 -0
- tests/test_alloc.py +25 -0
- tests/test_coalescing.py +492 -0
- tests/test_controller.py +835 -0
- tests/test_device_mesh.py +132 -0
- tests/test_fault_tolerance.py +398 -0
- tests/test_future.py +94 -0
- tests/test_grad_generator.py +121 -0
- tests/test_mock_cuda.py +74 -0
- tests/test_pdb_actor.py +110 -0
- tests/test_python_actors.py +372 -0
- tests/test_remote_functions.py +1271 -0
- tests/test_rust_backend.py +182 -0
- tests/test_signal_safe_block_on.py +103 -0
- tests/test_sim_backend.py +54 -0
- torchmonarch_nightly-2025.6.4.dist-info/METADATA +94 -0
- torchmonarch_nightly-2025.6.4.dist-info/RECORD +157 -0
- torchmonarch_nightly-2025.6.4.dist-info/WHEEL +5 -0
- torchmonarch_nightly-2025.6.4.dist-info/entry_points.txt +3 -0
- torchmonarch_nightly-2025.6.4.dist-info/licenses/LICENSE +29 -0
- torchmonarch_nightly-2025.6.4.dist-info/top_level.txt +3 -0
tests/test_remote_functions.py
@@ -0,0 +1,1271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import itertools
import math
import sys
import traceback
from enum import Enum
from typing import Callable, ContextManager, Tuple
from unittest.mock import patch

import monarch
import pytest

import torch
from monarch import (
    fetch_shard,
    inspect,
    no_mesh,
    OpaqueRef,
    Pipe,
    remote,
    remote_generator,
    RemoteException,
    Stream,
)
from monarch._testing import BackendType, TestingContext
from monarch.builtins.log import log_remote
from monarch.builtins.random import set_manual_seed_remote
from monarch.cached_remote_function import remote_autograd_function
from monarch.common import remote as remote_module
from monarch.common.device_mesh import DeviceMesh
from monarch.common.remote import Remote

from monarch.opaque_module import OpaqueModule
from monarch.opaque_object import opaque_method, OpaqueObject
from monarch.worker._testing_function import (
    all_gather,
    all_gather_into_tensor,
    all_reduce,
    all_to_all,
    all_to_all_single,
    barrier,
    broadcast,
    gather,
    irecv,
    isend,
    reduce,
    reduce_scatter,
    reduce_scatter_tensor,
    scatter,
)
from monarch_supervisor.logging import fix_exception_lines
from torch.distributed import ReduceOp


def custom_excepthook(exc_type, exc_value, exc_traceback):
    tb_lines = fix_exception_lines(
        traceback.format_exception(exc_type, exc_value, exc_traceback)
    )
    print("\n".join(tb_lines), file=sys.stderr)


sys.excepthook = custom_excepthook


def _set_device_udf(*args):
    return torch.zeros(1)


set_device_udf = remote(
    "monarch.worker._testing_function.set_device_udf_worker", propagate=_set_device_udf
)

rlist = remote("builtins.list", propagate=lambda elem: elem)


def _do_bogus_tensor_work(x, y, fail_rank=None):
    return x + y  # real function actually does x @ y


do_bogus_tensor_work = remote(
    "monarch.worker._testing_function.do_bogus_tensor_work",
    propagate=_do_bogus_tensor_work,
)


@remote_generator("monarch.worker._testing_function.example_echo_add")
def example_echo_add(p: "Pipe"):
    while True:
        yield p.recv() + 1


@remote_generator("monarch.worker._testing_function.example_data_loader")
def example_data_loader(p: "Pipe", x, y):
    for _i in range(x, y):
        yield torch.zeros(())


@remote_generator(
    "monarch.worker._testing_function.example_data_loader_small_pipe",
    max_messages=1,
)
def example_data_loader_small_pipe(p: "Pipe", iters: int, shape: Tuple[int, int]):
    for _i in range(iters):
        yield torch.zeros(shape)


sleep = remote("monarch.worker._testing_function.remote_sleep", propagate="inspect")

new_barrier_hackery = remote(
    "monarch.worker._testing_function.new_barrier_hackery",
    propagate=lambda threads: torch.zeros(1),
)

wait_barrier_hackery = remote(
    "monarch.worker._testing_function.wait_barrier_hackery",
    propagate=lambda t: None,
)

setup_state = remote(
    "monarch.worker._testing_function.setup_state_worker",
    propagate=lambda: [OpaqueRef(None) for _ in range(4)],
)

iteration = remote(
    "monarch.worker._testing_function.iteration_worker",
    propagate=lambda model, dataloader, criterion, optimizer, pg: torch.zeros(1),
)

opaque_ref_key_table_length = remote(
    "monarch.worker._testing_function.opaque_ref_key_table_length_worker",
    propagate=lambda: torch.zeros(1),
)

create_opaque_ref = remote(
    "monarch.worker._testing_function.create_opaque_ref_worker",
    propagate=lambda: OpaqueRef(None),
)

outer_remote_function_that_calls_inner = remote(
    "monarch.worker._testing_function.outer_remote_function_that_calls_inner",
    propagate=lambda: torch.zeros(1),
)

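# Note: each `remote(...)` wrapper above names a worker-side implementation by its
# module path, while the `propagate=` callable (or "inspect") runs only on the
# client as a stand-in for the result's shape/type (e.g. torch.zeros(1) or
# OpaqueRef(None)); this is an informal summary of how the wrappers are used below.
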
@pytest.fixture(scope="module", autouse=True)
def testing_context():
    global local
    with TestingContext() as local:
        yield


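# The autouse fixture above stores a module-scoped TestingContext in the global
# name `local`, which RemoteFunctionsTestBase.local_device_mesh relies on.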
class RemoteFunctionsTestBase:
    @classmethod
    def local_device_mesh(
        cls,
        num_hosts: int,
        gpu_per_host: int,
        backend_type: BackendType,
        activate: bool = True,
    ) -> ContextManager[DeviceMesh]:
        # pyre-fixme[10]: pytest defines this fixture.
        return local.local_device_mesh(
            num_hosts,
            gpu_per_host,
            activate,
            rust=backend_type == BackendType.RS,
        )


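# The test classes below run once per controller backend: BackendType.PY (Python
# controller) and BackendType.RS (Rust controller), selected via the `rust=` flag
# in local_device_mesh above.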
@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Not enough GPUs, this test requires at least 2 GPUs",
)
# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
# out is not counted as a failure, so we set a more restrictive timeout to
# ensure we see a hard failure in CI.
@pytest.mark.timeout(120)
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
class TestRemoteFunctions(RemoteFunctionsTestBase):
    @classmethod
    def do_test_reduce_scatter_tensor(cls, backend_type, reduce_op, expected_tensor):
        n_gpus = 2
        with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
            rank = device_mesh.rank("host") * n_gpus + device_mesh.rank("gpu")
            tensor_in = rank * torch.arange(0, 8, device="cuda", dtype=float).reshape(
                4, 2
            )
            tensor_out = torch.arange(2, device="cuda", dtype=float)
            pg = device_mesh.process_group(("host", "gpu"))

            reduce_scatter_tensor(tensor_out, tensor_in, op=reduce_op, group=pg)

            for host in range(2):
                for gpu in range(n_gpus):
                    rank = 2 * host + gpu
                    local_tensor_out = inspect(tensor_out, {"host": host, "gpu": gpu})
                    with no_mesh.activate():
                        assert torch.equal(
                            local_tensor_out,
                            expected_tensor[rank],
                        )

    @classmethod
    def do_test_reduce_scatter_tensor_subgroup(
        cls,
        backend_type: BackendType,
        reduce_op,
        expected_tensor_host_group: torch.Tensor,
        expected_tensor_gpu_group: torch.Tensor,
    ) -> None:
        n_gpus = 2
        with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
            # Use a group smaller than the world size.
            host_pg = device_mesh.process_group("host")
            gpu_pg = device_mesh.process_group("gpu")
            # host_rank = device_mesh.rank("host")
            # gpu_rank = device_mesh.rank("gpu")
            rank = device_mesh.rank(("host", "gpu"))

            tensor_in = rank * torch.arange(
                0, 8, device="cuda", dtype=torch.float32
            ).reshape(4, 2)

            gpu_tensor_out = torch.zeros(4, device="cuda", dtype=torch.float32)
            reduce_scatter_tensor(gpu_tensor_out, tensor_in, op=reduce_op, group=gpu_pg)

            tensor_in = rank * torch.arange(
                0, 8, device="cuda", dtype=torch.float32
            ).reshape(4, 2)
            host_tensor_out = torch.zeros(4, device="cuda", dtype=torch.float32)
            reduce_scatter_tensor(
                host_tensor_out, tensor_in, op=reduce_op, group=host_pg
            )

            for host in range(2):
                for gpu in range(n_gpus):
                    rank = host * 2 + gpu
                    local_gpu_tensor_out = inspect(
                        gpu_tensor_out, {"host": host, "gpu": gpu}
                    )
                    local_host_tensor_out = inspect(
                        host_tensor_out, {"host": host, "gpu": gpu}
                    )
                    with no_mesh.activate():
                        assert torch.equal(
                            local_host_tensor_out,
                            expected_tensor_host_group[rank],
                        ), f"{rank=}, {host=}, {gpu=}"
                        assert torch.equal(
                            local_gpu_tensor_out,
                            expected_tensor_gpu_group[rank],
                        ), f"{rank=}, {host=}, {gpu=}"

    @classmethod
    def do_test_reduce_scatter(
        cls,
        backend_type: BackendType,
        reduce_op: ReduceOp,
        expected_tensor: torch.Tensor,
    ) -> None:
        n_gpus = 2
        with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
            rank = device_mesh.rank("host") * n_gpus + device_mesh.rank("gpu")
            tensor_in = rank * torch.arange(0, 8, device="cuda", dtype=torch.float32)
            tensor_out = torch.arange(2, device="cuda", dtype=torch.float32)
            pg = device_mesh.process_group(("host", "gpu"))

            tensor_out = reduce_scatter(
                tensor_out,
                list(torch.chunk(tensor_in, 2 * n_gpus)),
                op=reduce_op,
                group=pg,
            )

            for host in range(2):
                for gpu in range(n_gpus):
                    rank = 2 * host + gpu
                    local_tensor_out = inspect(tensor_out, {"host": host, "gpu": gpu})
                    with no_mesh.activate():
                        assert torch.equal(
                            local_tensor_out,
                            expected_tensor[rank],
                        )

    @classmethod
    def do_test_all_reduce(cls, backend_type, reduce_op, expected_tensor):
        n_gpus = 2
        with cls.local_device_mesh(2, n_gpus, backend_type) as device_mesh:
            rank = device_mesh.rank(("host", "gpu"))
            tensor_in = rank * torch.arange(0, 8, device="cuda", dtype=float).reshape(
                4, 2
            )
            pg = device_mesh.process_group(("host", "gpu"))

            tensor_out = all_reduce(tensor_in, op=reduce_op, group=pg)

            for host in range(2):
                for gpu in range(n_gpus):
                    local_tensor_out = inspect(tensor_out, {"host": host, "gpu": gpu})
                    with no_mesh.activate():
                        assert torch.equal(
                            local_tensor_out,
                            expected_tensor,
                        )

    def test_hello(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            log_remote("hello, world")

    def test_eager_remote_function_failed(self, backend_type):
        if backend_type == BackendType.PY:
            pytest.skip("Python support not planned for this test")
        with self.local_device_mesh(1, 2, backend_type) as _:
            x = torch.rand(3, 4)
            y = torch.rand(3, 4)
            z = do_bogus_tensor_work(x, y, fail_rank=1)
            a = z + x
            with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
                # NCCL init is slow, and fails on internal RE!
                _ = fetch_shard(a).result(timeout=40)

    def test_set_device_inside_udf_fails_with_explanation(self, backend_type):
        if backend_type == BackendType.PY:
            pytest.skip("Python support not planned for this test")
        with self.local_device_mesh(2, 2, backend_type):
            t = set_device_udf(2)
            try:
                inspect(t)
            except RemoteException as e:
                backtrace = "\n".join([frame.name for frame in e.worker_frames])
                assert "are available to monarch worker" in backtrace

    def test_simple_tensors(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            x = torch.rand(3, 4)
            y = x + x
            log_remote("%s %s", x, y)
            z = torch.std_mean(x)
            log_remote("%s", z)

    def test_user_call(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type) as _:
            x = torch.rand(3, 4)
            y = rlist((x + 1, x))
            log_remote("%s", y)

    # resume monday:
    # 1. tensor ctor resource guard (done)
    # 2. __torch_dispatch__ forward of normal ops (done)
    # 3. collectives created for device mesh
    # 4. implement comms APIs
    # 5. transfer tensor back, and simple future to wait for result.

    def test_remote_function_with_comms_full_mesh(self, backend_type):
        nGPUs = 2
        with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            myrank = (
                (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
            )
            x = torch.ones((3, 4), device="cuda") * myrank

            reduce = all_reduce(x, group=pg)
            local_reduce = fetch_shard(reduce).result()
            assert torch.equal(local_reduce, torch.ones(3, 4) * 18)
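            # myrank is (host + 1) * 2 + gpu + 1, i.e. 3, 4, 5 and 6 across the
            # 2x2 mesh, so the all_reduce sum is 3 + 4 + 5 + 6 = 18.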

    def test_remote_function_with_comms_by_dimension(self, backend_type):
        nGPUs = 2
        with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
            pg = device_mesh.process_group(("gpu",))
            myrank = (
                (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
            )
            x = torch.ones((3, 4), device="cuda") * myrank
            reduce = all_reduce(x, group=pg)
            local_reduce_host_0 = fetch_shard(reduce).result()
            local_reduce_host_1 = fetch_shard(reduce, {"gpu": 1, "host": 1}).result()
            assert torch.equal(local_reduce_host_0, torch.ones(3, 4) * 7)
            assert torch.equal(local_reduce_host_1, torch.ones(3, 4) * 11)

        with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
            pg = device_mesh.process_group(("host",))
            myrank = (
                (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
            )
            x = torch.ones((3, 4), device="cuda") * myrank
            reduce = all_reduce(x, group=pg)
            local_reduce_gpu_0 = fetch_shard(reduce).result()
            local_reduce_gpu_2 = fetch_shard(reduce, {"gpu": 1, "host": 0}).result()
            assert torch.equal(local_reduce_gpu_0, torch.ones(3, 4) * 8)

            assert torch.equal(local_reduce_gpu_2, torch.ones(3, 4) * 10)

    def test_remote_function_with_comms_sub_mesh(self, backend_type):
        nGPUs = 2
        with self.local_device_mesh(
            2, nGPUs, backend_type, activate=False
        ) as device_mesh:
            host1 = device_mesh(host=1)
            with host1.activate():
                pg = device_mesh.process_group(("gpu",))
                myrank = (
                    (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
                )
                x = torch.ones((3, 4), device="cuda") * myrank
                reduce = all_reduce(x, group=pg)
                local_reduce = fetch_shard(reduce).result()

            assert torch.equal(local_reduce, torch.ones(3, 4) * 11)

            host0 = device_mesh(host=0)
            with host0.activate():
                pg = device_mesh.process_group(("gpu",))
                myrank = (
                    (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
                )
                x = torch.ones((3, 4), device="cuda") * myrank
                reduce = all_reduce(x, group=pg)
                local_reduce = fetch_shard(reduce).result()

            assert torch.equal(local_reduce, torch.ones(3, 4) * 7)

    def test_remote_exception(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type) as _:
            x = torch.rand(3, 4)
            y = torch.rand(3, 4)
            z = do_bogus_tensor_work(x, y)
            a = z + x
            b = x + y
            with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
                # NCCL init is slow, and fails on internal RE!
                _ = fetch_shard(a).result(timeout=20)
            # but values not dependent on z are fine
            fetch_shard(b).result(timeout=10)

    def test_remote_function_barrier(self, backend_type):
        if backend_type == BackendType.PY:
            pytest.skip("FIXME: Python support for this function")
        nGPUs = 2
        with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            finished = barrier(group=pg)
            local = fetch_shard(finished).result()
            assert local.item() == 1.0

    def test_remote_function_all_gather(self, backend_type: BackendType) -> None:
        nGPUs = 2
        with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
            myrank = (
                (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
            )
            # Don't start at zero to ensure there are no leftover zeros.
            tensor_in = torch.arange(1, 3, device="cuda") * myrank
            world_size = 2 * nGPUs
            tensor_out = list(
                torch.zeros(2 * world_size, dtype=torch.int64, device="cuda").chunk(
                    world_size
                )
            )
            pg = device_mesh.process_group(("host", "gpu"))

            tensor_out = all_gather(tensor_out, tensor_in, group=pg)
            local_tensor_out = inspect(tensor_out)

            t0, t1, t2, t3 = local_tensor_out
            assert torch.equal(t0, torch.tensor([3, 6]))
            assert torch.equal(t1, torch.tensor([4, 8]))
            assert torch.equal(t2, torch.tensor([5, 10]))
            assert torch.equal(t3, torch.tensor([6, 12]))

    def test_remote_function_all_gather_into_tensor(self, backend_type):
        nGPUs = 2
        with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
            myrank = (
                (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank("gpu") + 1
            )
            # Don't start at zero to ensure there are no leftover zeros.
            tensor_in = torch.arange(1, 3, device="cuda") * myrank
            tensor_out = torch.zeros(2 * nGPUs * 2, dtype=torch.int64, device="cuda")
            pg = device_mesh.process_group(("host", "gpu"))

            finished = all_gather_into_tensor(tensor_out, tensor_in, group=pg)
            local_finished = inspect(finished)
            local_tensor_out = inspect(tensor_out)

            assert local_finished.item() == 1.0
            assert torch.equal(local_tensor_out, torch.tensor([3, 6, 4, 8, 5, 10, 6, 12]))

    def test_remote_function_isend(self, backend_type):
        nGPUs = 2
        with self.local_device_mesh(2, nGPUs, backend_type) as device_mesh:
            pg = device_mesh.process_group(("host",))
            host_0_mesh = device_mesh(host=0)
            host_1_mesh = device_mesh(host=1)
            with host_0_mesh.activate():
                to_rank = (device_mesh.rank("host") + 1) * nGPUs + device_mesh.rank(
                    "gpu"
                )
                t0 = torch.ones(1, device="cuda")
                finished0 = isend(t0, to_rank, group=pg)
            with host_1_mesh.activate():
                from_rank = (device_mesh.rank("host") - 1) * nGPUs + device_mesh.rank(
                    "gpu"
                )
                t1 = torch.zeros(1, device="cuda")
                finished1 = irecv(t1, from_rank, group=pg)

            with host_0_mesh.activate():
                local_finished_0 = inspect(finished0)
            with host_1_mesh.activate():
                local_finished_1 = inspect(finished1)
            assert local_finished_0.item() == 1.0
            assert local_finished_1.item() == 1.0

    def test_distributed_error(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type) as _:
            x = torch.rand(3, 4).cuda()
            y = torch.rand(3, 4).cuda()
            # z is broken on rank 1 but not others
            z = do_bogus_tensor_work(x, y, fail_rank=1)
            # test that rank 1 is still doing work despite z failing
            a = (x + y).reduce("gpu")
            fetch_shard(a).result()
            # but z itself should fail, even if we do not fetch it from rank 1
            # (since fetch shard says we first want to assert the whole tensor is correct)
            with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
                fetch_shard(z).result()
            # try to reduce z, which should fail, but ranks that are not 1 do not
            # know about the failure. Rank 1 should still participate in the reduce
            # to unblock work.
            rz = z.reduce("gpu")
            # but we should see the error message still retrieving it because it is
            # dependent on an error.
            with pytest.raises(RemoteException, match="do_bogus_tensor_work"):
                fetch_shard(rz).result()
            # however, we should still be able to compute and get a result back
            # from host 1, signaling that the reduction didn't get cuda compute stuck.
            fetch_shard(2 * x, gpu=1, host=0).result()

    def test_pipe(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            p = example_echo_add()
            for _i in range(10):
                x = torch.rand(3, 4)
                p.send(x)
                y = p.recv()
                x, y = fetch_shard((x, y)).result()
                with no_mesh.activate():
                    assert torch.allclose(x + 1, y)

    def test_loader(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            p = example_data_loader(3, 7)
            for i in range(3, 7):
                x = fetch_shard(p.recv()).result()
                with no_mesh.activate():
                    assert x.item() == i

    def test_loader_blocks_with_small_pipe(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            iters = 10
            p = example_data_loader_small_pipe(iters, (1000, 1000))
            # timeout should proc on pipe process
            sleep(0.6)
            # it takes a few iters of reasonably sized tensors to fill up OS buffer
            # max_messages (SNDHWM) only affects the zmq buffer
            for _ in range(iters - 1):
                p.recv()
            t = fetch_shard(p.recv()).result()
            assert t[0][0].item() == -1.0

    def test_streams_run_parallel(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            # test that these two streams do in fact run in parallel
            # on the worker by having each stream wait on a barrier.
            # The Tensor t is just used as a data-dependency so that
            # we can make sure new_barrier_hackery is called before
            # the wait on 'other'.
            other = Stream("other")
            t = new_barrier_hackery(2)
            t_other, borrow = other.borrow(t)
            with borrow:
                with other.activate():
                    wait_barrier_hackery(t_other)
                wait_barrier_hackery(t)
            fetch_shard(t).result()

    def test_debug(self, backend_type):
        gonna_pdb = remote(
            "monarch.worker._testing_function.gonna_pdb", propagate="inspect"
        )

        with self.local_device_mesh(2, 2, backend_type):
            writes = []

            def dw(s):
                writes.append(s)

            def dr(n):
                buffer = "".join(["print(x)\n", "c\n"]).encode()
                assert len(buffer) <= n
                return buffer

            if backend_type == BackendType.RS:
                patch_read = patch(
                    "monarch.controller.rust_backend.controller.debugger_read", new=dr
                )
                patch_write = patch(
                    "monarch.controller.rust_backend.controller.debugger_write", new=dw
                )
            else:
                patch_read = patch("monarch.controller.debugger.read", new=dr)
                patch_write = patch("monarch.controller.debugger.write", new=dw)
            with patch_read, patch_write:
                gonna_pdb()
                # xxx: we do not process messages from workers
                # unless fetching a result
                fetch_shard(None).result()
            assert "".join(writes).count("7\n") == 4

    def test_fetch_preprocess(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            assert (
                "an argument processed"
                == remote("monarch.worker._testing_function.do_some_processing")
                .call_on_shard_and_fetch(
                    "an argument",
                )
                .result()
            )

    def test_cached_remote_function(self, backend_type):
        fn = remote("monarch.worker._testing_function.how_many_of_these_do_you_want")
        start_hits = remote_module._hit
        with self.local_device_mesh(2, 2, backend_type):
            x = torch.ones(3, 4)
            y = torch.rand(3, 4)

            a, _, _ = fn(3, x)
            b, _, _ = fn(3, x)
            assert len(a._aliases.aliases) == 1
            assert len(b._aliases.aliases) == 1
            _, _, _ = fn(3, y)
            t0, t1 = fn(2, x)
            t0.add(t1)
            local_a = fetch_shard(a).result()
            with no_mesh.activate():
                assert torch.all(local_a == 1.0)

        end_hits = remote_module._hit
        assert end_hits - start_hits == 2

    def test_remote_autograd_function(self, backend_type):
        from monarch.worker import _testing_function

        remote_fn = remote_autograd_function(
            _testing_function.TestRemoteAutogradFunction
        )

        with self.local_device_mesh(1, 1, backend_type):
            x = torch.ones(1, requires_grad=True)
            y = torch.ones_like(x).requires_grad_(True)
            outs = remote_fn.apply(x, y)
            assert outs[3] == 4
            local_0 = fetch_shard(outs[0]).result()
            local_1 = fetch_shard(outs[1]).result()
            (outs[0] + outs[1]).sum().backward()
            # unfortunately, grad_fn of local tensor is always None
            # regardless of whether we set `no_grad` on the worker
            # so we can test only requires_grad
            for ll in (local_0, local_1):
                assert not ll.requires_grad
            grad_local_0 = fetch_shard(x.grad).result()
            grad_local_1 = fetch_shard(x.grad).result()
            x = x.detach()
            x.grad = None
            y.grad = None
            outs = remote_fn.apply(x, y)
            local_0_f = fetch_shard(outs[0]).result()
            (outs[0] + outs[1]).sum().backward()
            assert x.grad is None
            grad_local_1_f = fetch_shard(y.grad).result()

            assert torch.equal(local_0_f, torch.full_like(local_0_f, 2))
            assert torch.equal(local_0, torch.ones_like(local_0))
            assert torch.equal(grad_local_0, torch.ones_like(local_0))
            assert torch.equal(grad_local_1, torch.ones_like(local_0))
            assert torch.equal(grad_local_1_f, torch.ones_like(local_0))

    def test_cached_remote_aliases(self, backend_type):
        fn = remote("monarch.worker._testing_function.remote_chunk")
        with self.local_device_mesh(1, 1, backend_type):
            x = torch.randn(16, 5, device="cuda")
            outs = fn(x)
            aliases = outs[0]._aliases.aliases
            # x and 4 results of x.chunk(4)
            assert len(aliases) == 5
            assert outs[2]._fake.storage_offset() == 40

    def test_live_function(self, backend_type):
        def bar(x, y):
            return (
                a_function_called_by_a_live_function(x)
                + a_live_function_call_by_a_live_function(y)
                + math.pi
            )

        @remote
        def check(x):
            return torch.allclose(x, torch.zeros(()) + math.pi + 5)

        y = 7

        @monarch.remote
        def close():
            return y

        @monarch.remote
        def cuda_works(x):
            return x.cuda()

        with self.local_device_mesh(2, 2, backend_type):
            a = torch.ones(())
            assert check.call_on_shard_and_fetch(bar(a, a)).result()
            # ensure we do not attempt to pickle closures
            close()

            b = cuda_works(a)
            fetch_shard(b).result()

            @monarch.remote
            def something_else():
                raise Exception("No")  # this line appears

            # check that the stack trace has correct line numbers
            with pytest.raises(Exception, match=r"this line appears"):
                something_else()

    def test_setting_random_seed(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            set_manual_seed_remote(12345)
            t = torch.randn(3, 4)
            t_d = torch.randn(3, 4, device="cuda")
            ref = fetch_shard(t).result()
            ref_d = fetch_shard(t_d).result()
            vals = {
                (h, d): fetch_shard(t, {"host": h, "gpu": d}).result()
                for h, d in itertools.product(range(2), repeat=2)
            }

            vals_d = {
                (h, d): fetch_shard(t_d, {"host": h, "gpu": d}).result()
                for h, d in itertools.product(range(2), repeat=2)
            }

            for v, v_d in zip(vals.values(), vals_d.values()):
                assert torch.equal(v, ref)
                assert torch.equal(v_d, ref_d)

    def test_return_exception(self, backend_type):
        @monarch.remote
        def simple():
            return Exception("is a valid value to return")

        with self.local_device_mesh(1, 1, backend_type):
            # This should be a valid return than an exception to raise
            simple.call_on_shard_and_fetch().result()

    def test_opaque_object(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):

            class Foo(OpaqueObject):
                @opaque_method
                def add(self, x: torch.Tensor):
                    return x + x

            f = Foo("monarch.worker._testing_function.WorkerFoo", 4.0)

            result = monarch.inspect(f.add(torch.ones(3, 4)))
            with monarch.no_mesh.activate():
                assert torch.allclose(torch.full((3, 4), 5.0), result)

            f.hi = 4
            assert f.hi == 4

    def test_opaqueRef_setup_state_and_iteration(self, backend_type):
        with self.local_device_mesh(1, 2, backend_type) as mesh:
            pg = mesh.process_group(("gpu",))
            model, dataloader, criterion, optimizer = setup_state()
            num_epochs = 5
            for _ in range(num_epochs):
                loss = iteration(model, dataloader, criterion, optimizer, pg)
                assert inspect(loss).item() > 0

    def test_opaqueRef_key_deleted(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type):
            ref = create_opaque_ref()
            assert inspect(opaque_ref_key_table_length()).item() == 1
            del ref
            assert inspect(opaque_ref_key_table_length()).item() == 0

    def test_opaque_module(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            linear = OpaqueModule("torch.nn.Linear", 3, 3, device="cuda")
            with torch.no_grad():
                for p in linear.parameters():
                    p.zero_()
            input_ = torch.rand(4, 3, device="cuda")
            # we should have been able to clear the parameters and have that result
            # affect how the linear works.
            output = linear.call_method("forward", lambda self, x: x.clone(), input_)
            assert monarch.inspect(output.sum()).item() == 0

    def test_opaque_module_autograd(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            input_ = torch.rand(3, 3, device="cuda", requires_grad=True)

            linear = OpaqueModule("torch.nn.Linear", 3, 3, device="cuda")
            output = linear(input_, propagator=lambda self, x: x.clone())
            r = output.sum()
            with torch.no_grad():
                r.backward()

            weight, bias = linear.parameters()
            ig0, wg0, bg0 = monarch.inspect((input_.grad, weight.grad, bias.grad))

            input_.grad = None
            weight.grad = None
            bias.grad = None

            (input_ @ weight.T + bias).sum().backward()

            ig1, wg1, bg1 = monarch.inspect((input_.grad, weight.grad, bias.grad))

            with monarch.no_mesh.activate():
                assert torch.allclose(ig0, ig1)
                assert torch.allclose(wg0, wg1)
                assert torch.allclose(bg0, bg1)

    def test_remote_function_reduce_scatter_tensor_sum(self, backend_type):
        self.do_test_reduce_scatter_tensor(
            backend_type,
            torch.distributed.ReduceOp.SUM,
            (
                torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
            ).sum(0),
        )

    def test_remote_function_reduce_scatter_tensor_subgroup_sum(
        self, backend_type: BackendType
    ) -> None:
        self.do_test_reduce_scatter_tensor_subgroup(
            backend_type,
            torch.distributed.ReduceOp.SUM,
            expected_tensor_host_group=torch.tensor(
                [[0, 2, 4, 6], [0, 4, 8, 12], [8, 10, 12, 14], [16, 20, 24, 28]],
                dtype=torch.float32,
            ),
            expected_tensor_gpu_group=torch.tensor(
                [[0, 1, 2, 3], [4, 5, 6, 7], [0, 5, 10, 15], [20, 25, 30, 35]],
                dtype=torch.float32,
            ),
        )

    def test_remote_function_reduce_scatter_tensor_avg(self, backend_type):
        self.do_test_reduce_scatter_tensor(
            backend_type,
            torch.distributed.ReduceOp.AVG,
            (
                torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
            ).mean(0),
        )

    def test_remote_function_reduce_scatter_sum(
        self, backend_type: BackendType
    ) -> None:
        self.do_test_reduce_scatter(
            backend_type,
            torch.distributed.ReduceOp.SUM,
            (
                torch.arange(0, 8, dtype=torch.float32).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=torch.float32).unsqueeze(-1).unsqueeze(-1)
            ).sum(0),
        )

    def test_remote_function_reduce_scatter_avg(
        self, backend_type: BackendType
    ) -> None:
        self.do_test_reduce_scatter(
            backend_type,
            torch.distributed.ReduceOp.AVG,
            (
                torch.arange(0, 8, dtype=torch.float32).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=torch.float32).unsqueeze(-1).unsqueeze(-1)
            ).mean(0),
        )

    def test_remote_function_all_reduce_sum(self, backend_type):
        self.do_test_all_reduce(
            backend_type,
            torch.distributed.ReduceOp.SUM,
            (
                torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
            ).sum(0),
        )

    def test_remote_function_all_reduce_avg(self, backend_type):
        self.do_test_all_reduce(
            backend_type,
            torch.distributed.ReduceOp.AVG,
            (
                torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
            ).mean(0),
        )

    def test_remote_function_all_reduce_max(self, backend_type):
        self.do_test_all_reduce(
            backend_type,
            torch.distributed.ReduceOp.MAX,
            (
                torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
            ).max(0)[0],
        )

    def test_remote_function_all_reduce_min(self, backend_type):
        self.do_test_all_reduce(
            backend_type,
            torch.distributed.ReduceOp.MIN,
            (
                torch.arange(0, 8, dtype=float).reshape(1, 4, 2).repeat(4, 1, 1)
                * torch.arange(4, dtype=float).unsqueeze(-1).unsqueeze(-1)
            ).min(0)[0],
        )

    def test_remote_function_failure_message_contains_traceback(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type):
            x = outer_remote_function_that_calls_inner()
            try:
                inspect(x)
            except RemoteException as e:
                backtrace = "\n".join([frame.name for frame in e.worker_frames])
                assert "outer_remote_function" in backtrace
                assert "inner_remote_function" in backtrace

    def test_remote_function_broadcast(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            for i in range(4):
                rank = 2 * device_mesh.rank("host") + device_mesh.rank("gpu")
                rank = rank.cuda()
                broadcast(rank, src=i, group=pg)
                for host in range(2):
                    for gpu in range(2):
                        with no_mesh.activate():
                            assert inspect(rank, {"host": host, "gpu": gpu}).item() == i

    def test_remote_function_all_to_all_single(self, backend_type):
        with self.local_device_mesh(2, 2, backend_type) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            tensor_in = torch.arange(4, device="cuda", dtype=float)
            tensor_out = torch.empty(4, device="cuda", dtype=float)
            all_to_all_single(tensor_out, tensor_in, group=pg)
            for host in range(2):
                for gpu in range(2):
                    rank = 2 * host + gpu
                    with no_mesh.activate():
                        assert torch.equal(
                            inspect(tensor_out, {"host": host, "gpu": gpu}),
                            rank * torch.ones(4),
                        )

    def test_remote_function_all_to_all(self, backend_type: BackendType) -> None:
        world_size = 2
        n_gpus = 2
        size = world_size * n_gpus
        expected_tensors = [
            torch.tensor([0, 4, 8, 12], dtype=torch.float32),
            torch.tensor([1, 5, 9, 13], dtype=torch.float32),
            torch.tensor([2, 6, 10, 14], dtype=torch.float32),
            torch.tensor([3, 7, 11, 15], dtype=torch.float32),
        ]

        with self.local_device_mesh(world_size, n_gpus, backend_type) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            rank = n_gpus * device_mesh.rank("host") + device_mesh.rank("gpu")
            in_tensors = list(
                torch.chunk(
                    torch.arange(size, device="cuda", dtype=torch.float32)
                    + (rank * size),
                    size,
                )
            )
            # These values will be replaced, just used for shape.
            out_tensors = list(torch.zeros(size, device="cuda").chunk(size))
            out_tensors = all_to_all(out_tensors, in_tensors, group=pg)
            for host in range(world_size):
                for gpu in range(n_gpus):
                    local_tensor_out = inspect(out_tensors, {"host": host, "gpu": gpu})
                    rank = host * n_gpus + gpu
                    with no_mesh.activate():
                        # Combine the tensor list together for a better comparison
                        # message.
                        local_tensor_out = torch.cat(local_tensor_out)
                        assert torch.equal(
                            local_tensor_out, expected_tensors[rank]
                        ), f"For {rank=}, {host=}, {gpu=}"


@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Not enough GPUs, this test requires at least 2 GPUs",
)
# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
# out is not counted as a failure, so we set a more restrictive timeout to
# ensure we see a hard failure in CI.
@pytest.mark.timeout(120)
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
class TestComm(RemoteFunctionsTestBase):
    N_GPUS: int = 2
    N_HOSTS: int = 2

    @property
    def world_size(self) -> int:
        return self.N_GPUS * self.N_HOSTS

    @property
    def device(self):
        self.fail("test subclass didn't override device")

    def _test_tensor_dtype_complex(self, backend_type: BackendType) -> None:
        with self.local_device_mesh(
            self.N_HOSTS, self.N_GPUS, backend_type
        ) as device_mesh:
            group = device_mesh.process_group(("host", "gpu"))
            tensor = torch.rand(2, device="cuda")
            tensor_c = torch.view_as_complex(tensor)
            tensor_list = [
                torch.rand(2, device="cuda") for _ in range(self.N_HOSTS * self.N_GPUS)
            ]
            tensor_list_c = list(tensor_list)
            tensor_list_c[1] = torch.view_as_complex(tensor_list_c[1])

            inspect(all_gather(tensor_list, tensor, group=group))
            inspect(all_gather(tensor_list, tensor_c, group=group))
            inspect(all_gather(tensor_list_c, tensor, group=group))
            inspect(all_gather(tensor_list_c, tensor_c, group=group))

    def test_nccl_barrier(self, backend_type: BackendType) -> None:
        with self.local_device_mesh(
            self.N_HOSTS, self.N_GPUS, backend_type
        ) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            rank = device_mesh.rank(("host", "gpu"))
            t = torch.tensor([1] * 10, device="cuda") + rank
            all_reduce(t, group=pg)

            for host in range(self.N_HOSTS):
                for gpu in range(self.N_GPUS):
                    rank = 2 * host + gpu
                    with no_mesh.activate():
                        # all reduce will sum rank + 1 across all ranks.
                        expected_tensor = torch.tensor(
                            [sum(range(1, self.world_size + 1))] * 10
                        )
                        assert torch.equal(
                            expected_tensor,
                            inspect(t, {"host": host, "gpu": gpu}),
                        )

    def test_tensor_dtype_complex(self, backend_type: BackendType) -> None:
        self._test_tensor_dtype_complex(backend_type)

    def test_reduce_scatter_base_k(self, backend_type: BackendType) -> None:
        expected_tensor = (
            torch.arange(self.N_HOSTS * self.N_GPUS * 2, dtype=torch.float32)
            .reshape(1, self.N_HOSTS * self.N_GPUS, 2)
            .repeat(self.N_HOSTS * self.N_GPUS, 1, 1)
        ).sum(0)
        with self.local_device_mesh(
            self.N_HOSTS, self.N_GPUS, backend_type
        ) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            output_tensor = torch.zeros(2, dtype=torch.int64, device="cuda")
            input_tensors = torch.arange(
                self.N_HOSTS * self.N_GPUS * 2, dtype=torch.int64, device="cuda"
            )
            input_tensors = torch.reshape(
                input_tensors, (self.N_HOSTS * self.N_GPUS, 2)
            )
            # Input is [[0, 1], [2, 3], [4, 5], [6, 7]] across 4 ranks.
            # After reduce + scatter, output_tensor should be [0 * 4, 1 * 4] on the 0th rank
            # and [2 * 4, 3 * 4] on the 1st rank, and so on
            reduce_scatter_tensor(output_tensor, input_tensors, group=pg)

            for host in range(self.N_HOSTS):
                for gpu in range(self.N_GPUS):
                    rank = 2 * host + gpu
                    output_tensor_local = inspect(
                        output_tensor, {"host": host, "gpu": gpu}
                    )
                    with no_mesh.activate():
                        assert torch.equal(output_tensor_local, expected_tensor[rank])


@pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Not enough GPUs, this test requires at least 2 GPUs",
)
# Set global timeout--sandcastle's timeout is 600s. A test that sandcastle times
# out is not counted as a failure, so we set a more restrictive timeout to
# ensure we see a hard failure in CI.
@pytest.mark.timeout(120)
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
class TestNcclProcessGroupWithDispatchedCollectives(RemoteFunctionsTestBase):
    """This test is copied from test_c10d_nccl.py::NcclProcessGroupWithDispatchedCollectivesTests
    in torch, but modified to setup a Monarch device mesh and use remote functions"""

    N_GPUS: int = 2
    N_HOSTS: int = 2

    def _call_collective_with_varying_tensors(
        self,
        world_size: int,
        # pyre-fixme[24]: Incorrect ParamsSpec annotation.
        collective: Remote[..., torch.Tensor],
        *args,
        **kwargs,
    ) -> None:
        # call collective with varying tensors to ensure that the tensors are
        # correctly dispatched

        # ensure supported devices (cpu, cuda) succeeds during dispatch call
        tensor = torch.zeros(2, 2, device=torch.device("cuda"))
        # multi tensor collectives
        if collective == barrier:
            fetch_shard(collective(*args, **kwargs)).result()
        elif collective == all_gather:
            output_list = list(
                torch.zeros(world_size * 2, 2, device=torch.device("cuda")).chunk(
                    world_size
                )
            )
            fetch_shard(collective(output_list, tensor, *args, **kwargs)).result()
        elif collective == reduce_scatter:
            fetch_shard(
                collective(tensor, [tensor] * world_size, *args, **kwargs)
            ).result()
        elif collective == gather:
            gather_list = list(
                torch.zeros(world_size * 2, 2, device=torch.device("cuda")).chunk(
                    world_size
                )
            )
            fetch_shard(collective(tensor, gather_list, *args, **kwargs)).result()
        elif collective == scatter:
            fetch_shard(
                collective(tensor, [tensor] * world_size, *args, **kwargs)
            ).result()
        elif collective == all_to_all:
            fetch_shard(
                collective(
                    [tensor] * world_size, [tensor] * world_size, *args, **kwargs
                )
            ).result()
        else:
            fetch_shard(collective(tensor, *args, **kwargs)).result()

    @pytest.mark.parametrize(
        "collective",
        [
            reduce,
            broadcast,
            all_reduce,
            all_gather,
            reduce_scatter,
            barrier,
            all_to_all,
            gather,
            scatter,
        ],
        ids=[
            "reduce",
            "broadcast",
            "all_reduce",
            "all_gather",
            "reduce_scatter",
            "barrier",
            "all_to_all",
            "gather",
            "scatter",
        ],
    )
    def test_collectives(
        self, backend_type: BackendType, collective: Callable[..., torch.Tensor]
    ) -> None:
        world_size = self.N_HOSTS * self.N_GPUS
        with self.local_device_mesh(
            self.N_HOSTS, self.N_GPUS, backend_type
        ) as device_mesh:
            rank = device_mesh.rank(("host", "gpu"))
            pg = device_mesh.process_group(("host", "gpu"))

            kwargs: dict[str, object] = {"group": pg}
            if collective == reduce:
                kwargs["group_dst"] = 0
            elif collective == broadcast:
                kwargs["group_src"] = rank
            elif collective == gather:
                kwargs["group_dst"] = 0
            elif collective == scatter:
                kwargs["group_src"] = 0
            self._call_collective_with_varying_tensors(world_size, collective, **kwargs)

    def test_all_to_all_single(self, backend_type: BackendType) -> None:
        with self.local_device_mesh(
            self.N_HOSTS, self.N_GPUS, backend_type
        ) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            # test alltoall_base
            tensor_in = torch.arange(4, device="cuda", dtype=torch.float32)
            tensor_out = torch.empty(4, device="cuda", dtype=torch.float32)
            all_to_all_single(tensor_out, tensor_in, group=pg)

            for host in range(self.N_HOSTS):
                for gpu in range(self.N_GPUS):
                    rank = 2 * host + gpu
                    with no_mesh.activate():
                        assert torch.equal(
                            inspect(tensor_out, {"host": host, "gpu": gpu}),
                            rank * torch.ones(4),
                        )

    def test_allgather_base(self, backend_type: BackendType) -> None:
        with self.local_device_mesh(
            self.N_HOSTS, self.N_GPUS, backend_type
        ) as device_mesh:
            pg = device_mesh.process_group(("host", "gpu"))
            rank = (
                (device_mesh.rank("host") + 1) * self.N_GPUS
                + device_mesh.rank("gpu")
                + 1
            )
            tensor_in = torch.arange(2, device="cuda") * rank
            tensor_out = torch.zeros(
                self.N_HOSTS * self.N_GPUS * 2, dtype=torch.int64, device="cuda"
            )
            all_gather_into_tensor(tensor_out, tensor_in, group=pg)
            local_tensor_out = inspect(tensor_out)
            with no_mesh.activate():
                assert torch.equal(
                    local_tensor_out, torch.tensor([0, 3, 0, 4, 0, 5, 0, 6])
                )
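                # Each rank r in {3, 4, 5, 6} contributes tensor_in = [0, r],
                # so the gathered result is [0, 3, 0, 4, 0, 5, 0, 6].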


def a_function_called_by_a_live_function(x):
    return 2 * x


def a_live_function_call_by_a_live_function(x):
    return 3 * x