torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
|
@@ -0,0 +1,1028 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import gc
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
|
11
|
+
from datetime import timedelta
|
|
12
|
+
from typing import Any, Callable, cast, Dict, List
|
|
13
|
+
from unittest import skipIf, skipUnless, TestCase
|
|
14
|
+
from unittest.mock import Mock, patch
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import torch.distributed as dist
|
|
18
|
+
from parameterized import parameterized
|
|
19
|
+
from torch import nn
|
|
20
|
+
from torch._C._distributed_c10d import (
|
|
21
|
+
_resolve_process_group,
|
|
22
|
+
AllgatherOptions,
|
|
23
|
+
AllreduceCoalescedOptions,
|
|
24
|
+
AllreduceOptions,
|
|
25
|
+
AllToAllOptions,
|
|
26
|
+
BarrierOptions,
|
|
27
|
+
BroadcastOptions,
|
|
28
|
+
ReduceOp,
|
|
29
|
+
ReduceScatterOptions,
|
|
30
|
+
)
|
|
31
|
+
from torch.distributed import (
|
|
32
|
+
_functional_collectives,
|
|
33
|
+
get_world_size,
|
|
34
|
+
ReduceOp,
|
|
35
|
+
TCPStore,
|
|
36
|
+
)
|
|
37
|
+
from torchft.manager import Manager
|
|
38
|
+
from torchft.process_group import (
|
|
39
|
+
_ErrorSwallowingWork,
|
|
40
|
+
ErrorSwallowingProcessGroupWrapper,
|
|
41
|
+
ManagedProcessGroup,
|
|
42
|
+
ProcessGroup,
|
|
43
|
+
ProcessGroupBabyGloo,
|
|
44
|
+
ProcessGroupBabyNCCL,
|
|
45
|
+
ProcessGroupDummy,
|
|
46
|
+
ProcessGroupGloo,
|
|
47
|
+
ProcessGroupNCCL,
|
|
48
|
+
ProcessGroupWrapper,
|
|
49
|
+
)
|
|
50
|
+
from torchft.work import _DummyWork
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def dummy_init_pg() -> None:
    """Create a single-rank gloo default process group unless one already exists.

    The group is backed by an in-memory ``dist.HashStore`` (rank 0 of a world of
    size 1). Once the default group is initialized, further calls do nothing.
    """
    if dist.is_initialized():
        return
    dist.init_process_group(
        backend="gloo", rank=0, world_size=1, store=dist.HashStore()
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _test_pg(
    pg: ProcessGroup,
    example_tensor: torch.Tensor | None = None,
    skip: list[str] | None = None,
) -> Dict[str, dist._Work]:
    """
    Helper function to test a set of collective operations on a given process group.

    Each collective is issued once with tensors shaped like ``example_tensor``.
    For each one, both ``work.wait()`` and ``work.get_future().wait()`` are
    exercised. Every tensor argument is then checked to still have the example's
    shape and dtype. A collective whose ``RuntimeError`` message contains
    ``"does not support <name>"`` is treated as unsupported and skipped.
    send/recv need multiple processes and are not covered here.

    Args:
        pg: the process group under test.
        example_tensor: template tensor (shape, dtype, device) for all buffers.
            Defaults to a fresh ``torch.randn((2, 3), dtype=torch.float32)``
            created on each call.
        skip: names of collectives that must not be run.

    Returns:
        Mapping from collective name to the work object it returned. ``allreduce``
        is issued twice, so its entry holds the later work.
    """
    if example_tensor is None:
        # Built per call instead of as a default argument: a default would be
        # evaluated once at import time, consuming the global RNG on import
        # and sharing one tensor object across all calls.
        example_tensor = torch.randn((2, 3), dtype=torch.float32)
    skip_names = frozenset(skip or ())

    shape: torch.Size = example_tensor.shape
    dtype: torch.dtype = example_tensor.dtype

    # Create some dummy tensors for testing
    input_tensor = example_tensor.clone()
    output_tensors = [
        [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
    ]
    tensor_list = [torch.empty_like(input_tensor)]

    def check_tensors(arg: object) -> None:
        """Recursively check tensors for expected shape and dtype."""
        if isinstance(arg, torch.Tensor):
            assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
            assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
        elif isinstance(arg, (list, tuple)):
            for item in arg:
                check_tensors(item)

    # Test collectives. send/recv require multiple processes to test, so we skip them here
    collectives = [
        ("allreduce", ([input_tensor], AllreduceOptions())),
        ("allreduce", ([input_tensor], ReduceOp.SUM)),
        ("allreduce_coalesced", ([input_tensor], AllreduceCoalescedOptions())),
        ("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
        (
            "allgather_into_tensor_coalesced",
            (output_tensors[0], [input_tensor], AllgatherOptions()),
        ),
        (
            "alltoall_base",
            (
                output_tensors[0][0],
                input_tensor,
                [input_tensor.shape[0]],
                [input_tensor.shape[0]],
                AllToAllOptions(),
            ),
        ),
        ("barrier", (BarrierOptions(),)),
        ("broadcast", (tensor_list, BroadcastOptions())),
        ("broadcast_one", (input_tensor, 0)),
        (
            "reduce_scatter",
            (output_tensors[0], [[input_tensor]], ReduceScatterOptions()),
        ),
        (
            "reduce_scatter_tensor_coalesced",
            (output_tensors[0], [input_tensor], ReduceScatterOptions()),
        ),
    ]
    works: Dict[str, dist._Work] = {}

    for coll_str, args in collectives:
        if coll_str in skip_names:
            continue
        try:
            coll = getattr(pg, coll_str)
            work = coll(*args)
            works[coll_str] = work
            work.wait()
            fut = work.get_future()
            fut.wait()
            # Check that all tensor arguments have the expected shapes and dtypes
            check_tensors(args)
        except RuntimeError as e:
            if f"does not support {coll_str}" in str(e):
                # Skip collectives that are not supported by the backend.
                continue
            # Bare raise re-raises the active exception without adding a frame.
            raise

    return works
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def run_allgather_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Check ``allgather`` across every rank of ``pg``.

    Each rank contributes ``[tensor, tensor + 1]`` flattened to one 1-D payload
    and receives one length-2 slot per rank. Slot ``r`` must equal
    ``[r + 1, r + 2]``, so the check only passes when each rank's ``tensor`` is
    ``[rank + 1]``.
    """
    world_size = pg.size()
    # Flatten [tensor, tensor + 1] into one 1-D payload (length 2 for a (1,) tensor).
    payload = torch.stack([tensor, tensor + 1], dim=0).reshape(-1)

    # Receive layout: [[slot_0, slot_1, ..., slot_{world_size-1}]], each slot (2,).
    gathered = [
        torch.zeros(2, device=tensor.device, dtype=tensor.dtype)
        for _ in range(world_size)
    ]
    pg.allgather([gathered], [payload], AllgatherOptions()).wait()

    for src_rank, received in enumerate(gathered):
        want = torch.tensor(
            [src_rank + 1, src_rank + 2], device=tensor.device, dtype=tensor.dtype
        )
        torch.testing.assert_close(received, want)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def run_allgather_into_tensor_coalesced_test(
    pg: ProcessGroup, rank: int, tensor: torch.Tensor
) -> None:
    """Check ``allgather_into_tensor_coalesced`` with two inputs per rank.

    Rank ``r`` contributes ``T0 = [r + 1]`` and ``T1 = [r + 10]``. The two
    outputs each have one element per rank, so after the collective
    ``out0[k] == k + 1`` and ``out1[k] == k + 10`` for every rank ``k``.
    Single-rank groups are skipped.
    """
    world_size = pg.size()
    if world_size < 2:
        return

    dev, dt = tensor.device, tensor.dtype
    inputs = [
        torch.tensor([rank + 1], device=dev, dtype=dt),
        torch.tensor([rank + 10], device=dev, dtype=dt),
    ]
    outputs = [torch.zeros(world_size, device=dev, dtype=dt) for _ in range(2)]

    pg.allgather_into_tensor_coalesced(outputs, inputs, AllgatherOptions()).wait()

    out0, out1 = outputs
    for src in range(world_size):
        torch.testing.assert_close(
            out0[src], torch.tensor(src + 1, device=dev, dtype=dt)
        )
        torch.testing.assert_close(
            out1[src], torch.tensor(src + 10, device=dev, dtype=dt)
        )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def run_allreduce_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Test allreduce collective operation.

    Assumes each rank's ``tensor`` is ``[rank + 1]``. After ``allreduce(SUM)``
    every rank should hold ``1 + 2 + ... + world_sz``, i.e.
    ``world_sz * (world_sz + 1) / 2``.
    """
    tc = tensor.clone()
    world_sz = pg.size()
    work = pg.allreduce([tc], ReduceOp.SUM)
    work.wait()
    expected_val = sum(r + 1 for r in range(world_sz))
    # Build the expectation in the input's dtype (as the allgather/alltoall
    # helpers do): assert_close checks dtypes, so a bare torch.tensor([int])
    # would be int64 and always mismatch a floating-point input.
    torch.testing.assert_close(
        tc, torch.tensor([expected_val], device=tensor.device, dtype=tensor.dtype)
    )
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def run_allreduce_coalesced_test(
    pg: ProcessGroup, rank: int, tensor: torch.Tensor
) -> None:
    """Test allreduce_coalesced collective operation.

    Assumes each rank's ``tensor`` is ``[rank + 1]``. We coalesce 2 tensors:
      - t0 = [rank + 1]
      - t1 = [rank + 2]

    After the SUM reduction the two tensors are checked independently:
    ``t0 == sum(r + 1 for r in range(world_sz))`` and
    ``t1 == sum(r + 2 for r in range(world_sz))``.
    """
    world_sz = pg.size()
    t0 = tensor.clone()
    t1 = tensor.clone() + 1
    work = pg.allreduce_coalesced([t0, t1], AllreduceCoalescedOptions())
    work.wait()
    sum_t0 = sum(r + 1 for r in range(world_sz))
    sum_t1 = sum(r + 2 for r in range(world_sz))
    # Expectations use the input dtype because assert_close compares dtypes.
    torch.testing.assert_close(
        t0, torch.tensor([sum_t0], device=t0.device, dtype=t0.dtype)
    )
    torch.testing.assert_close(
        t1, torch.tensor([sum_t1], device=t1.device, dtype=t1.dtype)
    )
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def run_alltoall_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Check ``alltoall_base`` with one element sent to each peer.

    Rank ``r`` sends ``[r*ws + 1, r*ws + 2, ..., r*ws + ws]`` (``ws`` = world
    size), routing element ``j`` to rank ``j``. Rank ``r`` therefore receives,
    in source-rank order, ``[k*ws + r + 1 for k in range(ws)]``. With ws = 2:

        rank=0 sends [1, 2] -> receives [1, 3]
        rank=1 sends [3, 4] -> receives [2, 4]

    Single-rank groups are skipped.
    """
    world_size = pg.size()
    if world_size < 2:
        return

    first = rank * world_size + 1
    send_buf = torch.arange(
        start=first,
        end=first + world_size,
        device=tensor.device,
        dtype=tensor.dtype,
    )
    recv_buf = torch.empty(world_size, device=tensor.device, dtype=tensor.dtype)

    # One element to and from every rank.
    pg.alltoall_base(
        recv_buf, send_buf, [1] * world_size, [1] * world_size, AllToAllOptions()
    ).wait()

    want = torch.tensor(
        [src * world_size + (rank + 1) for src in range(world_size)],
        device=tensor.device,
        dtype=tensor.dtype,
    )
    torch.testing.assert_close(recv_buf, want)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def run_broadcast_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Test broadcast collective operation.

    Rank 0 broadcasts a copy of its ``tensor``; every other rank starts from
    zeros. Afterwards every rank must hold ``[1]``, so this check assumes rank
    0's ``tensor`` is ``[1]`` (e.g. when each rank passes ``[rank + 1]``).
    """
    broadcast_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor)
    broadcast_work = pg.broadcast([broadcast_tensor], BroadcastOptions())
    broadcast_work.wait()
    # Match the input dtype; assert_close compares dtypes by default.
    expected_broadcast = torch.tensor([1], device=tensor.device, dtype=tensor.dtype)
    torch.testing.assert_close(broadcast_tensor, expected_broadcast)
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def run_broadcast_one_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Test broadcast_one collective operation.

    Rank 0 broadcasts a copy of its ``tensor`` via ``broadcast_one(t, 0)``;
    every other rank starts from zeros. Afterwards every rank must hold ``[1]``,
    so this check assumes rank 0's ``tensor`` is ``[1]``.
    """
    broadcast_one_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor)
    broadcast_one_work = pg.broadcast_one(broadcast_one_tensor, 0)
    broadcast_one_work.wait()
    # Match the input dtype; assert_close compares dtypes by default.
    torch.testing.assert_close(
        broadcast_one_tensor,
        torch.tensor([1], device=tensor.device, dtype=tensor.dtype),
    )
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def run_barrier_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Check that ``barrier`` completes on ``pg``.

    For CUDA tensors, the barrier's ``device_ids`` is pinned to the tensor's
    CUDA device index.
    """
    options = BarrierOptions()
    if tensor.is_cuda:
        options.device_ids = [tensor.device.index]
    pg.barrier(options).wait()
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def run_send_recv_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Test send/recv point-to-point operations.

    Rank 0 sends a copy of its ``tensor`` to rank 1 (tag 0), and rank 1 checks
    that it received ``[1]``. This assumes rank 0's ``tensor`` is ``[1]``.
    Ranks other than 0 and 1, and single-rank groups, do nothing.
    """
    if pg.size() < 2:
        return
    if rank == 0:
        send_tensor = tensor.clone()
        send_work = pg.send([send_tensor], 1, 0)
        send_work.wait()
    elif rank == 1:
        recv_tensor = torch.zeros_like(tensor)
        recv_work = pg.recv([recv_tensor], 0, 0)
        recv_work.wait()
        # Match the input dtype; assert_close compares dtypes by default.
        expected = torch.tensor([1], device=tensor.device, dtype=tensor.dtype)
        torch.testing.assert_close(recv_tensor, expected)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def run_reduce_scatter_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
    """Test reduce_scatter collective operation.

    Every rank builds the same ``world_sz x world_sz`` matrix (as a list of rows)
    whose row ``r`` is ``[r * world_sz + 1, ..., r * world_sz + world_sz]``.
    For example, with world_size=2::

        [[1, 2],
         [3, 4]]

    reduce_scatter(SUM) sums corresponding rows across all ranks and scatters
    the results so rank ``r`` gets row ``r`` of the sum. Because every rank had
    the same data, rank ``r`` receives
    ``[r * world_sz + 1, ..., r * world_sz + world_sz] * world_sz``.

    For example, with 2 ranks:
        rank 0 gets: [1, 2] * 2 = [2, 4] (first row)
        rank 1 gets: [3, 4] * 2 = [6, 8] (second row)

    Skipped for CPU tensors and for single-rank groups.
    """
    # reduce scatter not supported on GLOO, so CPU tensors are skipped.
    if tensor.device.type == "cpu":
        return
    world_sz = pg.size()
    if world_sz < 2:
        return

    def row(r: int) -> torch.Tensor:
        """Return row ``r`` of the replicated input matrix."""
        return torch.arange(
            start=r * world_sz + 1,
            end=r * world_sz + world_sz + 1,
            device=tensor.device,
            dtype=torch.float32,
        )

    local_data = [row(r) for r in range(world_sz)]

    out = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
    opts = ReduceScatterOptions()
    opts.reduceOp = ReduceOp.SUM
    work = pg.reduce_scatter([out], [local_data], opts)
    work.wait()

    torch.testing.assert_close(out, row(rank) * world_sz)
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def run_reduce_scatter_tensor_coalesced_test(
    pg: ProcessGroup, rank: int, tensor: torch.Tensor
) -> None:
    """Test reduce_scatter tensor coalesced collective operation.

    We define two 2D tensors, each shaped [world_sz, world_sz] and replicated on
    every rank:
      - m0, whose row r is [r*world_sz + 1, ..., r*world_sz + world_sz]
      - m1, whose row r is [r*world_sz + 100, ..., r*world_sz + 99 + world_sz]

    reduce_scatter coalesced (SUM) reduces each row of each tensor across ranks
    and scatters row ``r`` to rank ``r``. Because the inputs are replicated, the
    reduced row is the original row times world_sz.

    For example, with 2 ranks:
        m0: rank 0 gets [1, 2] * 2 = [2, 4]; rank 1 gets [3, 4] * 2 = [6, 8]
        m1: rank 0 gets [100, 101] * 2 = [200, 202];
            rank 1 gets [102, 103] * 2 = [204, 206]

    Single-rank groups are skipped.
    """
    world_sz = pg.size()
    if world_sz < 2:
        return  # skip trivial

    def row(r: int, offset: int) -> torch.Tensor:
        """Return ``[r*world_sz + offset, ..., r*world_sz + offset + world_sz - 1]``."""
        return torch.arange(
            start=r * world_sz + offset,
            end=r * world_sz + offset + world_sz,
            device=tensor.device,
            dtype=torch.float32,
        )

    # Build m0, m1 (each world_sz rows) fully replicated on all ranks.
    m0 = torch.stack([row(r, 1) for r in range(world_sz)])
    m1 = torch.stack([row(r, 100) for r in range(world_sz)])

    # Each rank receives one row of m0 and one row of m1.
    out0 = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
    out1 = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)

    opts = ReduceScatterOptions()
    opts.reduceOp = ReduceOp.SUM

    work = pg.reduce_scatter_tensor_coalesced([out0, out1], [m0, m1], opts)
    work.wait()

    # Expectations are rebuilt from scratch rather than read back from m0/m1.
    torch.testing.assert_close(out0, row(rank, 1) * world_sz)
    torch.testing.assert_close(out1, row(rank, 100) * world_sz)
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
# Registry of multi-rank collective checkers keyed by collective name. Every
# checker has the signature (pg, rank, tensor) -> None. "send/recv" is not a
# single ProcessGroup method; its checker drives the send/recv pair.
_COLLECTIVE_TO_FUNC: Dict[str, Callable[[ProcessGroup, int, torch.Tensor], None]] = {
    "allgather": run_allgather_test,
    "allgather_into_tensor_coalesced": run_allgather_into_tensor_coalesced_test,
    "allreduce": run_allreduce_test,
    "allreduce_coalesced": run_allreduce_coalesced_test,
    "alltoall_base": run_alltoall_test,
    "barrier": run_barrier_test,
    "broadcast": run_broadcast_test,
    "broadcast_one": run_broadcast_one_test,
    "reduce_scatter": run_reduce_scatter_test,
    "reduce_scatter_tensor_coalesced": run_reduce_scatter_tensor_coalesced_test,
    "send/recv": run_send_recv_test,
}
# Every registered collective name, in registry (insertion) order.
_ALL_COLLECTIVES: List[str] = list(_COLLECTIVE_TO_FUNC)
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
class ProcessGroupTest(TestCase):
|
|
497
|
+
@parameterized.expand(["cpu", "cuda"])
|
|
498
|
+
def test_gloo_apis(self, device: str) -> None:
|
|
499
|
+
if device == "cuda" and not torch.cuda.is_available():
|
|
500
|
+
self.skipTest("CUDA is not available")
|
|
501
|
+
return
|
|
502
|
+
|
|
503
|
+
store = TCPStore(
|
|
504
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
505
|
+
)
|
|
506
|
+
|
|
507
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
508
|
+
pg = ProcessGroupGloo()
|
|
509
|
+
pg.configure(store_addr, "0", 0, 1)
|
|
510
|
+
|
|
511
|
+
self.assertEqual(pg.size(), 1)
|
|
512
|
+
|
|
513
|
+
_test_pg(
|
|
514
|
+
pg,
|
|
515
|
+
torch.tensor([2], device=device),
|
|
516
|
+
skip=(
|
|
517
|
+
# https://github.com/pytorch/pytorch/issues/152645
|
|
518
|
+
[
|
|
519
|
+
"allreduce_coalesced",
|
|
520
|
+
"allgather_into_tensor_coalesced",
|
|
521
|
+
]
|
|
522
|
+
if device == "cuda"
|
|
523
|
+
else []
|
|
524
|
+
),
|
|
525
|
+
)
|
|
526
|
+
|
|
527
|
+
m = nn.Linear(3, 4).to(device)
|
|
528
|
+
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
|
|
529
|
+
m(torch.rand(2, 3, device=device))
|
|
530
|
+
|
|
531
|
+
def test_gloo_timeout(self) -> None:
|
|
532
|
+
store = TCPStore(
|
|
533
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
534
|
+
)
|
|
535
|
+
|
|
536
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
537
|
+
pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01))
|
|
538
|
+
with self.assertRaisesRegex(
|
|
539
|
+
RuntimeError, "(timeout after 10ms|Socket Timeout)"
|
|
540
|
+
):
|
|
541
|
+
pg.configure(store_addr, "0", 0, 2)
|
|
542
|
+
|
|
543
|
+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
|
|
544
|
+
@skipUnless(torch.cuda.is_available(), "needs CUDA")
|
|
545
|
+
def test_nccl_apis(self) -> None:
|
|
546
|
+
store = TCPStore(
|
|
547
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
548
|
+
)
|
|
549
|
+
device = "cuda"
|
|
550
|
+
|
|
551
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
552
|
+
pg = ProcessGroupNCCL()
|
|
553
|
+
pg.configure(store_addr, "0", 0, 1)
|
|
554
|
+
|
|
555
|
+
self.assertEqual(pg.size(), 1)
|
|
556
|
+
|
|
557
|
+
_test_pg(pg, torch.tensor([2], device=device))
|
|
558
|
+
|
|
559
|
+
m = nn.Linear(3, 4).to(device)
|
|
560
|
+
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
|
|
561
|
+
m(torch.rand(2, 3, device=device))
|
|
562
|
+
|
|
563
|
+
# reconfigure
|
|
564
|
+
store_addr = f"localhost:{store.port}/prefix2"
|
|
565
|
+
pg.configure(store_addr, "0", 0, 1)
|
|
566
|
+
|
|
567
|
+
_test_pg(pg, torch.tensor([2], device=device))
|
|
568
|
+
|
|
569
|
+
torch.cuda.synchronize()
|
|
570
|
+
|
|
571
|
+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
|
|
572
|
+
@skipUnless(
|
|
573
|
+
torch.cuda.is_available() and torch.cuda.nccl.version() >= (2, 25),
|
|
574
|
+
"needs NCCL >=2.25",
|
|
575
|
+
)
|
|
576
|
+
@patch("torchft.process_group.stream_timeout", autospec=True)
|
|
577
|
+
@patch("torchft.process_group.context_timeout", autospec=True)
|
|
578
|
+
def test_nccl_timeouts(
|
|
579
|
+
self, mock_context_timeout: Mock, mock_stream_timeout: Mock
|
|
580
|
+
) -> None:
|
|
581
|
+
store = TCPStore(
|
|
582
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
583
|
+
)
|
|
584
|
+
device = "cuda"
|
|
585
|
+
|
|
586
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
587
|
+
pg = ProcessGroupNCCL()
|
|
588
|
+
pg.configure(store_addr, "0", 0, 1)
|
|
589
|
+
|
|
590
|
+
t = torch.tensor([2], device=device)
|
|
591
|
+
pg.allreduce([t], ReduceOp.SUM).wait()
|
|
592
|
+
self.assertEqual(mock_stream_timeout.call_count, 1)
|
|
593
|
+
self.assertEqual(mock_context_timeout.return_value.__enter__.call_count, 2)
|
|
594
|
+
|
|
595
|
+
pg.allreduce([t], ReduceOp.SUM).get_future().wait()
|
|
596
|
+
self.assertEqual(mock_stream_timeout.call_count, 2)
|
|
597
|
+
self.assertEqual(mock_context_timeout.return_value.__enter__.call_count, 4)
|
|
598
|
+
|
|
599
|
+
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
|
|
600
|
+
@skipUnless(
|
|
601
|
+
torch.cuda.is_available(),
|
|
602
|
+
"needs CUDA",
|
|
603
|
+
)
|
|
604
|
+
def test_nccl_init_timeout(self) -> None:
|
|
605
|
+
store = TCPStore(
|
|
606
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
607
|
+
)
|
|
608
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
609
|
+
del store
|
|
610
|
+
|
|
611
|
+
pg = ProcessGroupNCCL(timeout=timedelta(seconds=0.01))
|
|
612
|
+
|
|
613
|
+
with self.assertRaisesRegex(RuntimeError, "timed out after 10ms"):
|
|
614
|
+
pg.configure(store_addr, "0", 0, 2)
|
|
615
|
+
|
|
616
|
+
def test_baby_gloo_timeout(self) -> None:
|
|
617
|
+
store = TCPStore(
|
|
618
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
619
|
+
)
|
|
620
|
+
|
|
621
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
622
|
+
|
|
623
|
+
a = ProcessGroupBabyGloo(timeout=timedelta(seconds=0.01))
|
|
624
|
+
with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
|
|
625
|
+
a.configure(store_addr, "0", 0, 2)
|
|
626
|
+
|
|
627
|
+
def test_reconfigure_baby_process_group(self) -> None:
|
|
628
|
+
store = TCPStore(
|
|
629
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
630
|
+
)
|
|
631
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
632
|
+
|
|
633
|
+
a = ProcessGroupBabyGloo()
|
|
634
|
+
a.configure(store_addr, "0", 0, 1)
|
|
635
|
+
future_thread_1 = a._future_thread
|
|
636
|
+
future_pipe_1 = a._future_pipe
|
|
637
|
+
p_1 = a._p
|
|
638
|
+
|
|
639
|
+
store_addr = f"localhost:{store.port}/prefix2"
|
|
640
|
+
a.configure(store_addr, "0", 0, 1)
|
|
641
|
+
future_thread_2 = a._future_thread
|
|
642
|
+
future_pipe_2 = a._future_pipe
|
|
643
|
+
p_2 = a._p
|
|
644
|
+
|
|
645
|
+
self.assertNotEqual(future_thread_1, future_thread_2)
|
|
646
|
+
self.assertNotEqual(future_pipe_1, future_pipe_2)
|
|
647
|
+
self.assertNotEqual(p_1, p_2)
|
|
648
|
+
|
|
649
|
+
assert future_thread_1 is not None
|
|
650
|
+
self.assertFalse(future_thread_1.is_alive())
|
|
651
|
+
assert future_pipe_1 is not None
|
|
652
|
+
self.assertTrue(future_pipe_1.closed())
|
|
653
|
+
assert p_1 is not None
|
|
654
|
+
self.assertFalse(p_1.is_alive())
|
|
655
|
+
|
|
656
|
+
assert future_thread_2 is not None
|
|
657
|
+
self.assertTrue(future_thread_2.is_alive())
|
|
658
|
+
assert future_pipe_2 is not None
|
|
659
|
+
self.assertFalse(future_pipe_2.closed())
|
|
660
|
+
assert p_2 is not None
|
|
661
|
+
self.assertTrue(p_2.is_alive())
|
|
662
|
+
|
|
663
|
+
def test_baby_gloo_apis(self) -> None:
|
|
664
|
+
store = TCPStore(
|
|
665
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
669
|
+
|
|
670
|
+
a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
|
|
671
|
+
try:
|
|
672
|
+
a.configure(store_addr, "0", 0, 1)
|
|
673
|
+
|
|
674
|
+
_test_pg(a)
|
|
675
|
+
|
|
676
|
+
# force collection to ensure no BabyWork objects remain
|
|
677
|
+
gc.collect()
|
|
678
|
+
|
|
679
|
+
self.assertEqual(a.num_active_work(), 0)
|
|
680
|
+
|
|
681
|
+
finally:
|
|
682
|
+
a.shutdown()
|
|
683
|
+
|
|
684
|
+
t = torch.zeros(10)
|
|
685
|
+
with self.assertRaisesRegex(OSError, "handle is closed"):
|
|
686
|
+
a.allreduce([t], AllreduceOptions()).wait()
|
|
687
|
+
|
|
688
|
+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
|
|
689
|
+
@skipUnless(torch.cuda.is_available(), "needs CUDA")
|
|
690
|
+
def test_baby_nccl_apis(self) -> None:
|
|
691
|
+
# set to 1 if more than >=2 gpus
|
|
692
|
+
device_id = 1 % torch.cuda.device_count()
|
|
693
|
+
torch.cuda.set_device(device_id)
|
|
694
|
+
|
|
695
|
+
store = TCPStore(
|
|
696
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
697
|
+
)
|
|
698
|
+
|
|
699
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
700
|
+
|
|
701
|
+
a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
|
|
702
|
+
try:
|
|
703
|
+
a.configure(store_addr, "0", 0, 1)
|
|
704
|
+
|
|
705
|
+
_test_pg(a, torch.randn((2, 3), device="cuda"))
|
|
706
|
+
|
|
707
|
+
torch.cuda.synchronize()
|
|
708
|
+
|
|
709
|
+
# force collection to ensure no BabyWork objects remain
|
|
710
|
+
gc.collect()
|
|
711
|
+
|
|
712
|
+
self.assertEqual(a.num_active_work(), 0)
|
|
713
|
+
finally:
|
|
714
|
+
a.shutdown()
|
|
715
|
+
torch.cuda.synchronize()
|
|
716
|
+
torch.cuda.empty_cache()
|
|
717
|
+
|
|
718
|
+
t = torch.zeros(10)
|
|
719
|
+
with self.assertRaisesRegex(OSError, "handle is closed"):
|
|
720
|
+
a.allreduce([t], AllreduceOptions()).wait()
|
|
721
|
+
|
|
722
|
+
def test_dummy(self) -> None:
|
|
723
|
+
pg = ProcessGroupDummy(0, 1)
|
|
724
|
+
m = nn.Linear(3, 4)
|
|
725
|
+
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
|
|
726
|
+
m(torch.rand(2, 3))
|
|
727
|
+
|
|
728
|
+
def test_functional_collectives(self) -> None:
|
|
729
|
+
dummy_init_pg()
|
|
730
|
+
|
|
731
|
+
store = TCPStore(
|
|
732
|
+
host_name="localhost", port=0, is_master=True, wait_for_workers=False
|
|
733
|
+
)
|
|
734
|
+
store_addr = f"localhost:{store.port}/prefix"
|
|
735
|
+
|
|
736
|
+
pg = ProcessGroupGloo().register("test_func_col")
|
|
737
|
+
pg.configure(store_addr, "0", 0, 1)
|
|
738
|
+
|
|
739
|
+
self.assertEqual(pg.group_name, str(dist.get_pg_count() - 1))
|
|
740
|
+
|
|
741
|
+
self.assertIs(
|
|
742
|
+
_resolve_process_group(pg.group_name), # pyre-ignore[6]: GroupName vs str
|
|
743
|
+
pg,
|
|
744
|
+
)
|
|
745
|
+
|
|
746
|
+
try:
|
|
747
|
+
t = torch.zeros(10)
|
|
748
|
+
_functional_collectives.all_reduce(t, "sum", pg).wait()
|
|
749
|
+
finally:
|
|
750
|
+
pg.unregister()
|
|
751
|
+
|
|
752
|
+
def test_process_group_wrapper(self) -> None:
    """ProcessGroupWrapper exposes its parent, forwards configure() and has a stable repr."""
    inner = ProcessGroupDummy(0, 1)
    wrapped = ProcessGroupWrapper(pg=inner)
    self.assertIs(wrapped.parent, inner)

    # configure() on the wrapper must reach the wrapped PG exactly once.
    wrapped.configure("addr", "0", 0, 1)
    self.assertEqual(inner.configure_count, 1)

    expected_repr = "ProcessGroupWrapper(pg=ProcessGroupDummy())"
    self.assertEqual(repr(wrapped), expected_repr)
|
|
761
|
+
|
|
762
|
+
def test_error_swallowing_process_group_wrapper(self) -> None:
    """
    ErrorSwallowingProcessGroupWrapper wraps works until an error is
    reported, after which its collectives return _DummyWork.
    """
    inner = ProcessGroupDummy(0, 1)
    wrapped = ErrorSwallowingProcessGroupWrapper(inner)
    self.assertIs(wrapped.parent, inner)

    # Before any error is reported, the first returned work is wrapped.
    healthy_works = _test_pg(wrapped)
    first_work = list(healthy_works.values())[0]
    self.assertIsInstance(first_work, _ErrorSwallowingWork)

    reported = RuntimeError("test")
    wrapped.report_error(reported)
    self.assertEqual(wrapped.error(), reported)

    # Once an error is recorded, every collective yields a _DummyWork.
    errored_works = _test_pg(wrapped)
    for work in errored_works.values():
        self.assertIsInstance(work, _DummyWork)
|
|
777
|
+
|
|
778
|
+
def test_managed_process_group(self) -> None:
    """
    ManagedProcessGroup reports the manager's participant count as its size
    and routes allreduce through the manager.
    """
    mgr = Mock(spec=Manager)
    mgr.errored.return_value = None
    mgr._pg = ProcessGroupDummy(0, 1)
    managed = ManagedProcessGroup(mgr)
    mgr.num_participants.return_value = 123

    self.assertEqual(managed.size(), 123)

    _works = _test_pg(managed)

    # The test expects _test_pg's collectives to reach manager.allreduce
    # exactly twice.
    self.assertEqual(mgr.allreduce.call_count, 2)
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
class MultiPgBaseTest(TestCase):
    """
    A base test that creates WORLD_SIZE process groups, one per simulated
    rank, configured concurrently on ThreadPoolExecutor threads against a
    single shared TCPStore. Every test_* method in a subclass reuses the
    same process groups (``cls.pg_pool``).

    Subclasses can specify:
    - BACKEND: the backend name passed to ``_create_pg`` ("gloo",
      "baby_gloo", "nccl", "baby_nccl" or "dummy")
    - WORLD_SIZE: how many ranks to simulate

    Per-backend timeouts are fixed in ``_create_pg``.
    """

    BACKEND = "gloo"
    WORLD_SIZE = 2

    @classmethod
    def setUpClass(cls) -> None:
        super().setUpClass()

        cls.store = TCPStore(
            host_name="localhost", port=0, is_master=True, wait_for_workers=False
        )
        cls.store_addr = f"localhost:{cls.store.port}/prefix"

        cls.pg_pool: List[ProcessGroup] = []

        cls.executor = ThreadPoolExecutor(max_workers=cls.WORLD_SIZE)

        def init_pg(rank: int) -> ProcessGroup:
            # Bind this rank's thread to its own accelerator before creating the PG.
            if torch.accelerator.is_available():
                torch.accelerator.set_device_idx(rank)
            pg = cls._create_pg(cls.BACKEND)
            pg.configure(cls.store_addr, "0", rank, cls.WORLD_SIZE)
            return pg

        # All ranks are configured concurrently (presumably configure() waits
        # for every rank to join the store), then collected in rank order.
        futures = [cls.executor.submit(init_pg, rank) for rank in range(cls.WORLD_SIZE)]
        cls.pg_pool = [future.result() for future in futures]

    @classmethod
    def tearDownClass(cls) -> None:
        """
        Shut down every PG that has a ``shutdown`` method, then the executor.

        A ``shutdown()`` that raises does not stop cleanup: the remaining PGs
        are still shut down, the executor is always stopped and the base
        class teardown always runs. The first shutdown error is re-raised
        at the end so the failure is still reported.
        """
        errors: List[Exception] = []
        for pg in cls.pg_pool:
            # getattr guards against PGs that have no shutdown() method.
            shutdown = getattr(pg, "shutdown", None)
            if shutdown is None:
                continue
            try:
                shutdown()
            except Exception as e:
                errors.append(e)
        try:
            cls.executor.shutdown(wait=True)
        finally:
            super().tearDownClass()
        if errors:
            raise errors[0]

    @classmethod
    def _create_pg(cls, backend: str) -> ProcessGroup:
        """
        Helper that creates a new ProcessGroup of the specified type.

        NCCL groups aren't currently supported - we prefer to test
        BabyNCCLGroups as they spin up their own subprocesses.

        Raises:
            NotImplementedError: if ``backend`` is not a known backend name.
        """
        if backend == "gloo":
            return ProcessGroupGloo(timeout=timedelta(seconds=1))
        elif backend == "baby_gloo":
            return ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
        elif backend == "nccl":
            return ProcessGroupNCCL(timeout=timedelta(seconds=10))
        elif backend == "baby_nccl":
            return ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
        elif backend == "dummy":
            return ProcessGroupDummy(0, 1)
        else:
            raise NotImplementedError(f"Unsupported backend: {backend}")

    def _run_parallel(self, collective: str, device: str = "cpu") -> None:
        """
        Run ``collective`` on every rank concurrently and wait for all ranks.

        Each rank passes a one-element tensor holding ``rank + 1``. When
        ``device`` contains "cuda", rank ``r`` places its tensor on
        ``cuda:r``; otherwise ``device`` is used as-is.

        Raises:
            The first exception raised by a rank, checked in rank order.
        """
        func = _COLLECTIVE_TO_FUNC[collective]

        futures = []
        for rank in range(self.WORLD_SIZE):
            pg = self.pg_pool[rank]
            # Compute a per-rank device without rebinding the ``device`` argument.
            rank_device = f"cuda:{rank}" if "cuda" in device else device
            tensor = torch.tensor([rank + 1], device=rank_device)

            # Each worker calls ``func(pg, rank, tensor)`` positionally.
            futures.append(self.executor.submit(func, pg, rank, tensor))

        self._collect(futures)

    def _collect(self, futs: list[Future]) -> None:
        """
        Wait for each future in order and print every non-empty result.

        If a future raised, its error is printed with the rank index and
        re-raised immediately; later futures are not waited on.
        """
        for i, f in enumerate(futs):
            try:
                res = f.result()
                if res:
                    print(f"Rank {i}: {res}")
            except Exception as e:
                print(f"Rank {i}: {e}")
                raise

    def _run_with_resiliency(self, collective: str, device: str = "cpu") -> None:
        """
        Exercise a collective across a reconfigure and a simulated rank crash.

        Every rank, in parallel:
        - reconfigures its PG onto a fresh store prefix
          (``reconfig_<collective>``) with the full WORLD_SIZE and runs the
          collective once, which must succeed;
        - the fault rank (the last rank) then shuts its PG down and returns,
          simulating a crash;
        - each surviving rank sets a 1 second timeout, runs the collective
          again and must observe one of the expected errors. If the PG then
          reports an error via ``errored()``, that error must mention
          "aborted".
        """

        def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
            pg.set_timeout(timedelta(seconds=30))

            if dev == "cuda":
                torch.cuda.set_device(rank)
                # Use a separate stream to avoid deadlocks between threads.
                torch.cuda.set_stream(torch.cuda.Stream())

            fault_rank = self.WORLD_SIZE - 1
            test = _COLLECTIVE_TO_FUNC[collective]

            # Reconfigure every rank (the future fault rank included) onto a
            # store prefix dedicated to this collective.
            new_store_addr = f"localhost:{self.store.port}/reconfig_{collective}"

            pg.configure(new_store_addr, "0", rank, self.WORLD_SIZE)

            # The collective must succeed on the freshly configured PG.
            t2 = torch.tensor([rank + 1], device=dev)
            test(pg, rank, t2)

            # Simulate a failure: the fault rank shuts its PG down and leaves,
            # while the other ranks try the collective again without it.
            t1 = torch.tensor([rank + 1], device=dev)
            if rank == fault_rank:
                pg.shutdown()
                return f"Rank{rank} crashed"

            # Short timeout so survivors notice the missing peer quickly.
            pg.set_timeout(timedelta(seconds=1))

            # We hardcode the list of expected errors.
            # gloo: Connection closed by peer, timed out waiting, no error, read error
            # nccl: Tensor-likes are not equal/not close (due to abort)
            # If the collective does not raise at all, the RuntimeError("no error")
            # below matches the "no error" alternative, so that outcome is also
            # accepted.
            with self.assertRaisesRegex(
                Exception,
                r"(Connection closed by peer|timed out after|Timed out waiting|no error|Read error|not equal|not close|process group not initialized)",
            ):
                test(pg, rank, t1.clone())
                raise RuntimeError("no error")

            if err := pg.errored():
                with self.assertRaisesRegex(RuntimeError, "aborted"):
                    raise err

            return f"Rank{rank} final success."

        # run in parallel
        futs = [
            self.executor.submit(worker, self.pg_pool[r], r, device)
            for r in range(self.WORLD_SIZE)
        ]
        self._collect(futs)
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
class NormalGlooMultiPgTest(MultiPgBaseTest):
    """Collective tests for a 3-rank ProcessGroupGloo on CPU."""

    BACKEND = "gloo"
    WORLD_SIZE = 3
    # Collectives not exercised against this backend.
    SKIP = [
        "alltoall_base",
        "reduce_scatter",
        "reduce_scatter_tensor_coalesced",
    ]
    # Sorted so the parameterized test names/indices are identical in every
    # process: iteration order of a raw set of strings depends on
    # PYTHONHASHSEED.
    COLLECTIVES: List[str] = sorted(set(_ALL_COLLECTIVES) - set(SKIP))

    @parameterized.expand(COLLECTIVES)
    def test_collective(self, collective: str) -> None:
        self._run_parallel(collective, device="cpu")

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    @skipUnless(
        torch.__version__ >= "2.7",
        "torch 2.6 has a bug with destructing PyWork objects",
    )
    @parameterized.expand(COLLECTIVES)
    def test_collective_with_resiliency(self, collective: str) -> None:
        self._run_with_resiliency(collective, device="cpu")
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
@skipIf(sys.platform == "darwin", "not reliable on mac")
class BabyGlooMultiPgTest(MultiPgBaseTest):
    """Collective tests for a 3-rank ProcessGroupBabyGloo on CPU."""

    BACKEND = "baby_gloo"
    WORLD_SIZE = 3
    # Collectives not exercised against this backend.
    SKIP = [
        "alltoall_base",
        "reduce_scatter",
        "reduce_scatter_tensor_coalesced",
    ]
    # Sorted so the parameterized test names/indices are identical in every
    # process: iteration order of a raw set of strings depends on
    # PYTHONHASHSEED.
    COLLECTIVES: List[str] = sorted(set(_ALL_COLLECTIVES) - set(SKIP))

    @parameterized.expand(COLLECTIVES)
    def test_collective(self, collective: str) -> None:
        self._run_parallel(collective, device="cpu")

    @parameterized.expand(COLLECTIVES)
    def test_collective_with_resiliency(self, collective: str) -> None:
        self._run_with_resiliency(collective, device="cpu")
|
|
994
|
+
|
|
995
|
+
|
|
996
|
+
@skipUnless(
    torch.cuda.is_available() and torch.cuda.device_count() >= 2,
    "needs 2 CUDA devices",
)
class BabyNcclMultiPgTest(MultiPgBaseTest):
    """Runs every collective in _ALL_COLLECTIVES on a 2-rank ProcessGroupBabyNCCL over CUDA."""

    BACKEND = "baby_nccl"
    WORLD_SIZE = 2

    @parameterized.expand(_ALL_COLLECTIVES)
    def test_collective(self, collective: str) -> None:
        self._run_parallel(collective, device="cuda")

    # NOTE(review): the resiliency variant is intentionally disabled for this
    # backend; the reason is not recorded here. To re-enable it, restore:
    #
    # @parameterized.expand(_ALL_COLLECTIVES)
    # def test_collective_with_resiliency(self, collective: str) -> None:
    #     self._run_with_resiliency(collective, device="cuda")
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
@skipUnless(
    (
        torch.cuda.is_available()
        and torch.cuda.device_count() >= 2
        and torch.cuda.nccl.version() >= (2, 25)
    ),
    "needs 2 CUDA devices and NCCL >=2.25",
)
class NormalNcclMultiPgTest(MultiPgBaseTest):
    """Runs every collective in _ALL_COLLECTIVES on a 2-rank ProcessGroupNCCL over CUDA."""

    BACKEND = "nccl"
    WORLD_SIZE = 2

    @parameterized.expand(_ALL_COLLECTIVES)
    def test_collective(self, collective: str) -> None:
        self._run_parallel(collective, device="cuda")

    @parameterized.expand(_ALL_COLLECTIVES)
    def test_collective_with_resiliency(self, collective: str) -> None:
        # Same crash-and-recover scenario as the CPU backends, on CUDA tensors.
        self._run_with_resiliency(collective, device="cuda")
|