torchft-nightly 2026.1.3 (cp310-cp310-manylinux_2_24_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
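Since a wheel is a standard zip archive, the file list above can be reproduced locally. A minimal sketch, assuming the wheel has already been downloaded under its standard filename (derived from the dist-info entries and tags shown above):

import zipfile

# list the first few entries of the downloaded wheel (a wheel is a plain zip archive)
with zipfile.ZipFile("torchft_nightly-2026.1.3-cp310-cp310-manylinux_2_24_x86_64.whl") as whl:
    for name in whl.namelist()[:5]:
        print(name)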
@@ -0,0 +1,1028 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import gc
8
+ import os
9
+ import sys
10
+ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
11
+ from datetime import timedelta
12
+ from typing import Any, Callable, cast, Dict, List
13
+ from unittest import skipIf, skipUnless, TestCase
14
+ from unittest.mock import Mock, patch
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ from parameterized import parameterized
19
+ from torch import nn
20
+ from torch._C._distributed_c10d import (
21
+ _resolve_process_group,
22
+ AllgatherOptions,
23
+ AllreduceCoalescedOptions,
24
+ AllreduceOptions,
25
+ AllToAllOptions,
26
+ BarrierOptions,
27
+ BroadcastOptions,
28
+ ReduceOp,
29
+ ReduceScatterOptions,
30
+ )
31
+ from torch.distributed import (
32
+ _functional_collectives,
33
+ get_world_size,
34
+ ReduceOp,
35
+ TCPStore,
36
+ )
37
+ from torchft.manager import Manager
38
+ from torchft.process_group import (
39
+ _ErrorSwallowingWork,
40
+ ErrorSwallowingProcessGroupWrapper,
41
+ ManagedProcessGroup,
42
+ ProcessGroup,
43
+ ProcessGroupBabyGloo,
44
+ ProcessGroupBabyNCCL,
45
+ ProcessGroupDummy,
46
+ ProcessGroupGloo,
47
+ ProcessGroupNCCL,
48
+ ProcessGroupWrapper,
49
+ )
50
+ from torchft.work import _DummyWork
51
+
52
+
53
+ def dummy_init_pg() -> None:
54
+ if not dist.is_initialized():
55
+ dist.init_process_group(
56
+ backend="gloo", rank=0, world_size=1, store=dist.HashStore()
57
+ )
58
+
59
+
60
+ def _test_pg(
61
+ pg: ProcessGroup,
62
+ example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
63
+ skip: list[str] = [],
64
+ ) -> Dict[str, dist._Work]:
65
+ """
66
+ Helper function to test a set of collective operations on a given process group.
67
+ """
68
+
69
+ shape: torch.Size = example_tensor.shape
70
+ dtype: torch.dtype = example_tensor.dtype
71
+
72
+ # Create some dummy tensors for testing
73
+ input_tensor = example_tensor.clone()
74
+ output_tensors = [
75
+ [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
76
+ ]
77
+ tensor_list = [torch.empty_like(input_tensor)]
78
+
79
+ def check_tensors(arg: object) -> None:
80
+ """Recursively check tensors for expected shape and dtype."""
81
+ if isinstance(arg, torch.Tensor):
82
+ assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
83
+ assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
84
+ elif isinstance(arg, (list, tuple)):
85
+ for item in arg:
86
+ check_tensors(item)
87
+
88
+ # Test collectives. send/recv require multiple processes to test, so we skip them here
89
+ collectives = [
90
+ ("allreduce", ([input_tensor], AllreduceOptions())),
91
+ ("allreduce", ([input_tensor], ReduceOp.SUM)),
92
+ ("allreduce_coalesced", ([input_tensor], AllreduceCoalescedOptions())),
93
+ ("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
94
+ (
95
+ "allgather_into_tensor_coalesced",
96
+ (output_tensors[0], [input_tensor], AllgatherOptions()),
97
+ ),
98
+ (
99
+ "alltoall_base",
100
+ (
101
+ output_tensors[0][0],
102
+ input_tensor,
103
+ [input_tensor.shape[0]],
104
+ [input_tensor.shape[0]],
105
+ AllToAllOptions(),
106
+ ),
107
+ ),
108
+ ("barrier", (BarrierOptions(),)),
109
+ ("broadcast", (tensor_list, BroadcastOptions())),
110
+ ("broadcast_one", (input_tensor, 0)),
111
+ (
112
+ "reduce_scatter",
113
+ (output_tensors[0], [[input_tensor]], ReduceScatterOptions()),
114
+ ),
115
+ (
116
+ "reduce_scatter_tensor_coalesced",
117
+ (output_tensors[0], [input_tensor], ReduceScatterOptions()),
118
+ ),
119
+ ]
120
+ works: Dict[str, dist._Work] = {}
121
+
122
+ for coll_str, args in collectives:
123
+ if coll_str in skip:
124
+ continue
125
+ try:
126
+ coll = getattr(pg, coll_str)
127
+ work = coll(*args)
128
+ works[coll_str] = work
129
+ work.wait()
130
+ fut = work.get_future()
131
+ fut.wait()
132
+ # Check that all tensor arguments have the expected shapes and dtypes
133
+ check_tensors(args)
134
+ except RuntimeError as e:
135
+ if f"does not support {coll_str}" in str(e):
136
+ # Skip collectives that are not supported by the backend.
137
+ continue
138
+ raise e
139
+
140
+ return works
141
+
142
+
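For context, _test_pg is typically driven against a locally configured group. A minimal sketch mirroring test_gloo_apis later in this file, assuming a single-rank Gloo group on a local TCPStore (editorial example, not part of the diff):

from datetime import timedelta

import torch
from torch.distributed import TCPStore
from torchft.process_group import ProcessGroupGloo

store = TCPStore(host_name="localhost", port=0, is_master=True, wait_for_workers=False)
pg = ProcessGroupGloo(timeout=timedelta(seconds=10))
# same argument pattern as the tests in this file: store address, "0", rank, world size
pg.configure(f"localhost:{store.port}/prefix", "0", 0, 1)
works = _test_pg(pg, torch.tensor([2]))  # exercises each supported collective once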
143
+ def run_allgather_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
144
+ """Test allgather collective operation.
145
+
146
+ Suppose each rank contributes a local tensor [rank+1, rank+2];
147
+ allgather collects these into a list of world_sz tensors, one per rank.
148
+ """
149
+ world_sz = pg.size()
150
+ to_gather = torch.stack([tensor, tensor + 1], dim=0)
151
+ # shape: (2,)
152
+ to_gather = to_gather.reshape(-1)
153
+
154
+ # Gathers as follows: [ [ recv0 ], [ recv1 ], ... [ recv_{sz-1} ] ]
155
+ # Each recv is shape (2,)
156
+ output_list = [
157
+ torch.zeros(2, device=tensor.device, dtype=tensor.dtype)
158
+ for _ in range(world_sz)
159
+ ]
160
+
161
+ work = pg.allgather([output_list], [to_gather], AllgatherOptions())
162
+ work.wait()
163
+
164
+ for r in range(world_sz):
165
+ expected = torch.tensor(
166
+ [r + 1, r + 2], device=tensor.device, dtype=tensor.dtype
167
+ )
168
+ torch.testing.assert_close(output_list[r], expected)
169
+
170
+
171
+ def run_allgather_into_tensor_coalesced_test(
172
+ pg: ProcessGroup, rank: int, tensor: torch.Tensor
173
+ ) -> None:
174
+ """Test allgather tensor coalesced collective operation.
175
+
176
+ This example gathers two local tensors, T0 and T1, from each rank into corresponding
177
+ output tensors.
178
+
179
+ For world_sz = n, each rank r has:
180
+ T0 = [r+1],
181
+ T1 = [r+10]
182
+
183
+ After allgather_into_tensor_coalesced, we end up with two tensors, out0 and out1,
184
+ both of length n.
185
+
186
+ out0 gathers T0 from all ranks, out1 gathers T1 from all ranks.
187
+
188
+ We verify that out0[k] == [k+1] and out1[k] == [k+10] for all k.
189
+
190
+ """
191
+ world_sz = pg.size()
192
+
193
+ if world_sz < 2:
194
+ return
195
+
196
+ t0 = torch.tensor([rank + 1], device=tensor.device, dtype=tensor.dtype)
197
+ t1 = torch.tensor([rank + 10], device=tensor.device, dtype=tensor.dtype)
198
+
199
+ out0 = torch.zeros(world_sz, device=tensor.device, dtype=tensor.dtype)
200
+ out1 = torch.zeros(world_sz, device=tensor.device, dtype=tensor.dtype)
201
+
202
+ work = pg.allgather_into_tensor_coalesced(
203
+ [out0, out1], [t0, t1], AllgatherOptions()
204
+ )
205
+ work.wait()
206
+
207
+ for r in range(world_sz):
208
+ expected0 = torch.tensor([r + 1], device=t0.device, dtype=t0.dtype)
209
+ torch.testing.assert_close(out0[r], expected0[0])
210
+ expected1 = torch.tensor([r + 10], device=t1.device, dtype=t1.dtype)
211
+ torch.testing.assert_close(out1[r], expected1[0])
212
+
213
+
214
+ def run_allreduce_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
215
+ """Test allreduce collective operation.
216
+
217
+ Assume each rank's tensor has value = rank + 1.
218
+ The final result after allreduce(SUM) should be sum(r + 1 for r in range(world_sz)), i.e. 1 + 2 + ... + world_sz.
219
+ """
220
+ tc = tensor.clone()
221
+ world_sz = pg.size()
222
+ work = pg.allreduce([tc], ReduceOp.SUM)
223
+ work.wait()
224
+ expected_val = sum(r + 1 for r in range(world_sz))
225
+ torch.testing.assert_close(tc, torch.tensor([expected_val], device=tensor.device))
226
+
227
+
228
+ def run_allreduce_coalesced_test(
229
+ pg: ProcessGroup, rank: int, tensor: torch.Tensor
230
+ ) -> None:
231
+ """Test allreduce_coalesced collective operation.
232
+
233
+ Assume each rank's tensor has value = rank + 1.
234
+ We coalesce two tensors:
235
+ - t0 = [rank + 1]
236
+ - t1 = [rank + 2]
237
+
238
+ After the allreduce, t0 should equal sum(r + 1 for r in range(world_sz)) and t1 should equal sum(r + 2 for r in range(world_sz)).
239
+ """
240
+ world_sz = pg.size()
241
+ t0 = tensor.clone()
242
+ t1 = tensor.clone() + 1
243
+ work = pg.allreduce_coalesced([t0, t1], AllreduceCoalescedOptions())
244
+ work.wait()
245
+ sum_t0 = sum(r + 1 for r in range(world_sz))
246
+ sum_t1 = sum(r + 2 for r in range(world_sz))
247
+ torch.testing.assert_close(t0, torch.tensor([sum_t0], device=t0.device))
248
+ torch.testing.assert_close(t1, torch.tensor([sum_t1], device=t1.device))
249
+
250
+
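As a concrete instance of the sums described in the docstring above, a short check (editorial sketch, assuming world_sz = 3):

world_sz = 3
# t0 values across ranks are 1, 2, 3 and t1 values are 2, 3, 4
assert sum(r + 1 for r in range(world_sz)) == 6  # t0 after allreduce(SUM)
assert sum(r + 2 for r in range(world_sz)) == 9  # t1 after allreduce(SUM)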
251
+ def run_alltoall_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
252
+ """Test all-to-all collective operation.
253
+
254
+ Suppose each rank's local tensor = [rank*world_sz + 1, rank*world_sz + 2, ..., rank*world_sz + world_sz]
255
+
256
+ e.g.:
257
+ rank=0 => [1,2]
258
+ rank=1 => [3,4]
259
+
260
+ After all-to-all, rank r's output[k] is the element from rank k that is destined for rank r,
261
+ i.e. output[k] = k*world_sz + (r+1):
262
+
263
+ rank=0 => [1,3]
264
+ rank=1 => [2,4]
265
+
266
+ """
267
+ world_sz = pg.size()
268
+ if world_sz < 2:
269
+ return
270
+
271
+ input_tensor = torch.arange(
272
+ start=rank * world_sz + 1,
273
+ end=rank * world_sz + 1 + world_sz,
274
+ device=tensor.device,
275
+ dtype=tensor.dtype,
276
+ )
277
+ output_tensor = torch.empty(world_sz, device=tensor.device, dtype=tensor.dtype)
278
+
279
+ send_sz = [1] * world_sz
280
+ recv_sz = [1] * world_sz
281
+
282
+ alltoall_work = pg.alltoall_base(
283
+ output_tensor, input_tensor, send_sz, recv_sz, AllToAllOptions()
284
+ )
285
+ alltoall_work.wait()
286
+
287
+ expected = torch.empty(world_sz, device=tensor.device, dtype=tensor.dtype)
288
+ for k in range(world_sz):
289
+ val = k * world_sz + (rank + 1)
290
+ expected[k] = val
291
+
292
+ torch.testing.assert_close(output_tensor, expected)
293
+
294
+
295
+ def run_broadcast_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
296
+ """Test broadcast collective operation.
297
+
298
+ rank0 will broadcast a known value and all other ranks should get it.
299
+ """
300
+ broadcast_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor)
301
+ broadcast_work = pg.broadcast([broadcast_tensor], BroadcastOptions())
302
+ broadcast_work.wait()
303
+ expected_broadcast = torch.tensor([1], device=tensor.device)
304
+ torch.testing.assert_close(broadcast_tensor, expected_broadcast)
305
+
306
+
307
+ def run_broadcast_one_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
308
+ """Test broadcast_one collective operation.
309
+
310
+ rank0 will broadcast a known value and all other ranks should get it.
311
+ """
312
+ broadcast_one_tensor = tensor.clone() if rank == 0 else torch.zeros_like(tensor)
313
+ broadcast_one_work = pg.broadcast_one(broadcast_one_tensor, 0)
314
+ broadcast_one_work.wait()
315
+ torch.testing.assert_close(
316
+ broadcast_one_tensor, torch.tensor([1], device=tensor.device)
317
+ )
318
+
319
+
320
+ def run_barrier_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
321
+ """Test barrier collective operation."""
322
+ opts = BarrierOptions()
323
+ if tensor.is_cuda:
324
+ device_id = tensor.device.index
325
+ opts.device_ids = [device_id]
326
+ barrier_work = pg.barrier(opts)
327
+ barrier_work.wait()
328
+
329
+
330
+ def run_send_recv_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
331
+ """Test send/recv point-to-point operations.
332
+
333
+ Simple point-to-point between ranks 0 and 1, ignored for other ranks.
334
+ """
335
+ if pg.size() < 2:
336
+ return
337
+ if rank == 0:
338
+ send_tensor = tensor.clone()
339
+ send_work = pg.send([send_tensor], 1, 0)
340
+ send_work.wait()
341
+ elif rank == 1:
342
+ recv_tensor = torch.zeros_like(tensor)
343
+ recv_work = pg.recv([recv_tensor], 0, 0)
344
+ recv_work.wait()
345
+ expected = torch.tensor([1], device=tensor.device)
346
+ torch.testing.assert_close(recv_tensor, expected)
347
+
348
+
349
+ def run_reduce_scatter_test(pg: ProcessGroup, rank: int, tensor: torch.Tensor) -> None:
350
+ """Test reduce_scatter collective operation.
351
+
352
+ Assume each rank creates a matrix where each row r contains values:
353
+ [r * world_sz + 1, ..., r * world_sz + world_sz]
354
+
355
+ For example, with world_size=2:
356
+ [[1, 2],
357
+ [3, 4]]
358
+
359
+ The reduce_scatter operation then:
360
+ - Reduces (sums) corresponding rows across all ranks
361
+ - Scatters the results so each rank gets one row of the final sum
362
+ - Since all ranks had the same initial data, the expected result for each rank r is:
363
+ rank r receives: [r*world_sz + 1, ..., r*world_sz + world_sz] * world_sz
364
+
365
+ For example, with 2 ranks:
366
+ rank 0 gets: [1, 2] * 2 = [2, 4] (first row)
367
+ rank 1 gets: [3, 4] * 2 = [6, 8] (second row)
368
+ """
369
+ # reduce_scatter is not supported by the Gloo backend, so skip the CPU case
370
+ if tensor.device.type == "cpu":
371
+ return
372
+ world_sz = pg.size()
373
+ if world_sz < 2:
374
+ return
375
+
376
+ local_data = []
377
+ for r in range(world_sz):
378
+ row_vals = torch.arange(
379
+ start=r * world_sz + 1,
380
+ end=r * world_sz + world_sz + 1,
381
+ device=tensor.device,
382
+ dtype=torch.float32,
383
+ )
384
+ local_data.append(row_vals)
385
+
386
+ out = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
387
+ opts = ReduceScatterOptions()
388
+ opts.reduceOp = ReduceOp.SUM
389
+ work = pg.reduce_scatter([out], [local_data], opts)
390
+ work.wait()
391
+
392
+ expected_row = torch.arange(
393
+ start=rank * world_sz + 1,
394
+ end=rank * world_sz + world_sz + 1,
395
+ device=tensor.device,
396
+ dtype=torch.float32,
397
+ )
398
+ expected_sum = expected_row * world_sz
399
+ torch.testing.assert_close(out, expected_sum)
400
+
401
+
402
+ def run_reduce_scatter_tensor_coalesced_test(
403
+ pg: ProcessGroup, rank: int, tensor: torch.Tensor
404
+ ) -> None:
405
+ """Test reduce_scatter tensor coalesced collective operation.
406
+
407
+ We define two 2D tensors, each of shape [world_sz, world_sz], replicated on every rank.
408
+
409
+ reduce_scatter coalesced will reduce each row of each tensor, then scatter the results to each rank.
410
+ Because these are replicated on all ranks, the reduced sum for each row is:
411
+ [r*world_sz + 1, ..., r*world_sz + world_sz] * world_sz
412
+
413
+ For example, with 2 ranks:
414
+ rank 0 gets: [1, 2] * 2 = [2, 4] (first row)
415
+ rank 1 gets: [3, 4] * 2 = [6, 8] (second row)
419
+
420
+ """
421
+ world_sz = pg.size()
422
+ if world_sz < 2:
423
+ return # skip trivial
424
+
425
+ # Build m0, m1 (each is a list of n rows) fully replicated on all ranks
426
+ m0 = []
427
+ m1 = []
428
+ for r in range(world_sz):
429
+ row0 = torch.arange(
430
+ start=r * world_sz + 1,
431
+ end=r * world_sz + world_sz + 1,
432
+ device=tensor.device,
433
+ dtype=torch.float32,
434
+ )
435
+ row1 = torch.arange(
436
+ start=r * world_sz + 100,
437
+ end=r * world_sz + 100 + world_sz,
438
+ device=tensor.device,
439
+ dtype=torch.float32,
440
+ )
441
+ m0.append(row0)
442
+ m1.append(row1)
443
+
444
+ # Each rank receives one "row" for m0, one row for m1, after reduce_scatter_coalesced
445
+ out0 = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
446
+ out1 = torch.zeros(world_sz, device=tensor.device, dtype=torch.float32)
447
+
448
+ opts = ReduceScatterOptions()
449
+ opts.reduceOp = ReduceOp.SUM
450
+
451
+ m0 = torch.stack(m0)
452
+ m1 = torch.stack(m1)
453
+
454
+ work = pg.reduce_scatter_tensor_coalesced([out0, out1], [m0, m1], opts)
455
+ work.wait()
456
+
457
+ base0 = (
458
+ torch.arange(
459
+ start=rank * world_sz + 1,
460
+ end=rank * world_sz + world_sz + 1,
461
+ device=tensor.device,
462
+ dtype=torch.float32,
463
+ )
464
+ * world_sz
465
+ )
466
+ base1 = (
467
+ torch.arange(
468
+ start=rank * world_sz + 100,
469
+ end=rank * world_sz + 100 + world_sz,
470
+ device=tensor.device,
471
+ dtype=torch.float32,
472
+ )
473
+ * world_sz
474
+ )
475
+
476
+ torch.testing.assert_close(out0, base0)
477
+ torch.testing.assert_close(out1, base1)
478
+
479
+
480
+ _COLLECTIVE_TO_FUNC: Dict[str, Callable[[ProcessGroup, int, torch.Tensor], None]] = {
481
+ "allgather": run_allgather_test,
482
+ "allgather_into_tensor_coalesced": run_allgather_into_tensor_coalesced_test,
483
+ "allreduce": run_allreduce_test,
484
+ "allreduce_coalesced": run_allreduce_coalesced_test,
485
+ "alltoall_base": run_alltoall_test,
486
+ "barrier": run_barrier_test,
487
+ "broadcast": run_broadcast_test,
488
+ "broadcast_one": run_broadcast_one_test,
489
+ "reduce_scatter": run_reduce_scatter_test,
490
+ "reduce_scatter_tensor_coalesced": run_reduce_scatter_tensor_coalesced_test,
491
+ "send/recv": run_send_recv_test,
492
+ }
493
+ _ALL_COLLECTIVES: List[str] = list(_COLLECTIVE_TO_FUNC.keys())
494
+
495
+
496
+ class ProcessGroupTest(TestCase):
497
+ @parameterized.expand(["cpu", "cuda"])
498
+ def test_gloo_apis(self, device: str) -> None:
499
+ if device == "cuda" and not torch.cuda.is_available():
500
+ self.skipTest("CUDA is not available")
501
+ return
502
+
503
+ store = TCPStore(
504
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
505
+ )
506
+
507
+ store_addr = f"localhost:{store.port}/prefix"
508
+ pg = ProcessGroupGloo()
509
+ pg.configure(store_addr, "0", 0, 1)
510
+
511
+ self.assertEqual(pg.size(), 1)
512
+
513
+ _test_pg(
514
+ pg,
515
+ torch.tensor([2], device=device),
516
+ skip=(
517
+ # https://github.com/pytorch/pytorch/issues/152645
518
+ [
519
+ "allreduce_coalesced",
520
+ "allgather_into_tensor_coalesced",
521
+ ]
522
+ if device == "cuda"
523
+ else []
524
+ ),
525
+ )
526
+
527
+ m = nn.Linear(3, 4).to(device)
528
+ m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
529
+ m(torch.rand(2, 3, device=device))
530
+
531
+ def test_gloo_timeout(self) -> None:
532
+ store = TCPStore(
533
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
534
+ )
535
+
536
+ store_addr = f"localhost:{store.port}/prefix"
537
+ pg = ProcessGroupGloo(timeout=timedelta(seconds=0.01))
538
+ with self.assertRaisesRegex(
539
+ RuntimeError, "(timeout after 10ms|Socket Timeout)"
540
+ ):
541
+ pg.configure(store_addr, "0", 0, 2)
542
+
543
+ # pyre-fixme[56]: Pyre was not able to infer the type of argument
544
+ @skipUnless(torch.cuda.is_available(), "needs CUDA")
545
+ def test_nccl_apis(self) -> None:
546
+ store = TCPStore(
547
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
548
+ )
549
+ device = "cuda"
550
+
551
+ store_addr = f"localhost:{store.port}/prefix"
552
+ pg = ProcessGroupNCCL()
553
+ pg.configure(store_addr, "0", 0, 1)
554
+
555
+ self.assertEqual(pg.size(), 1)
556
+
557
+ _test_pg(pg, torch.tensor([2], device=device))
558
+
559
+ m = nn.Linear(3, 4).to(device)
560
+ m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
561
+ m(torch.rand(2, 3, device=device))
562
+
563
+ # reconfigure
564
+ store_addr = f"localhost:{store.port}/prefix2"
565
+ pg.configure(store_addr, "0", 0, 1)
566
+
567
+ _test_pg(pg, torch.tensor([2], device=device))
568
+
569
+ torch.cuda.synchronize()
570
+
571
+ # pyre-fixme[56]: Pyre was not able to infer the type of argument
572
+ @skipUnless(
573
+ torch.cuda.is_available() and torch.cuda.nccl.version() >= (2, 25),
574
+ "needs NCCL >=2.25",
575
+ )
576
+ @patch("torchft.process_group.stream_timeout", autospec=True)
577
+ @patch("torchft.process_group.context_timeout", autospec=True)
578
+ def test_nccl_timeouts(
579
+ self, mock_context_timeout: Mock, mock_stream_timeout: Mock
580
+ ) -> None:
581
+ store = TCPStore(
582
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
583
+ )
584
+ device = "cuda"
585
+
586
+ store_addr = f"localhost:{store.port}/prefix"
587
+ pg = ProcessGroupNCCL()
588
+ pg.configure(store_addr, "0", 0, 1)
589
+
590
+ t = torch.tensor([2], device=device)
591
+ pg.allreduce([t], ReduceOp.SUM).wait()
592
+ self.assertEqual(mock_stream_timeout.call_count, 1)
593
+ self.assertEqual(mock_context_timeout.return_value.__enter__.call_count, 2)
594
+
595
+ pg.allreduce([t], ReduceOp.SUM).get_future().wait()
596
+ self.assertEqual(mock_stream_timeout.call_count, 2)
597
+ self.assertEqual(mock_context_timeout.return_value.__enter__.call_count, 4)
598
+
599
+ # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
600
+ @skipUnless(
601
+ torch.cuda.is_available(),
602
+ "needs CUDA",
603
+ )
604
+ def test_nccl_init_timeout(self) -> None:
605
+ store = TCPStore(
606
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
607
+ )
608
+ store_addr = f"localhost:{store.port}/prefix"
609
+ del store
610
+
611
+ pg = ProcessGroupNCCL(timeout=timedelta(seconds=0.01))
612
+
613
+ with self.assertRaisesRegex(RuntimeError, "timed out after 10ms"):
614
+ pg.configure(store_addr, "0", 0, 2)
615
+
616
+ def test_baby_gloo_timeout(self) -> None:
617
+ store = TCPStore(
618
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
619
+ )
620
+
621
+ store_addr = f"localhost:{store.port}/prefix"
622
+
623
+ a = ProcessGroupBabyGloo(timeout=timedelta(seconds=0.01))
624
+ with self.assertRaisesRegex(TimeoutError, "timed out after 0.01 seconds"):
625
+ a.configure(store_addr, "0", 0, 2)
626
+
627
+ def test_reconfigure_baby_process_group(self) -> None:
628
+ store = TCPStore(
629
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
630
+ )
631
+ store_addr = f"localhost:{store.port}/prefix"
632
+
633
+ a = ProcessGroupBabyGloo()
634
+ a.configure(store_addr, "0", 0, 1)
635
+ future_thread_1 = a._future_thread
636
+ future_pipe_1 = a._future_pipe
637
+ p_1 = a._p
638
+
639
+ store_addr = f"localhost:{store.port}/prefix2"
640
+ a.configure(store_addr, "0", 0, 1)
641
+ future_thread_2 = a._future_thread
642
+ future_pipe_2 = a._future_pipe
643
+ p_2 = a._p
644
+
645
+ self.assertNotEqual(future_thread_1, future_thread_2)
646
+ self.assertNotEqual(future_pipe_1, future_pipe_2)
647
+ self.assertNotEqual(p_1, p_2)
648
+
649
+ assert future_thread_1 is not None
650
+ self.assertFalse(future_thread_1.is_alive())
651
+ assert future_pipe_1 is not None
652
+ self.assertTrue(future_pipe_1.closed())
653
+ assert p_1 is not None
654
+ self.assertFalse(p_1.is_alive())
655
+
656
+ assert future_thread_2 is not None
657
+ self.assertTrue(future_thread_2.is_alive())
658
+ assert future_pipe_2 is not None
659
+ self.assertFalse(future_pipe_2.closed())
660
+ assert p_2 is not None
661
+ self.assertTrue(p_2.is_alive())
662
+
663
+ def test_baby_gloo_apis(self) -> None:
664
+ store = TCPStore(
665
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
666
+ )
667
+
668
+ store_addr = f"localhost:{store.port}/prefix"
669
+
670
+ a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
671
+ try:
672
+ a.configure(store_addr, "0", 0, 1)
673
+
674
+ _test_pg(a)
675
+
676
+ # force collection to ensure no BabyWork objects remain
677
+ gc.collect()
678
+
679
+ self.assertEqual(a.num_active_work(), 0)
680
+
681
+ finally:
682
+ a.shutdown()
683
+
684
+ t = torch.zeros(10)
685
+ with self.assertRaisesRegex(OSError, "handle is closed"):
686
+ a.allreduce([t], AllreduceOptions()).wait()
687
+
688
+ # pyre-fixme[56]: Pyre was not able to infer the type of argument
689
+ @skipUnless(torch.cuda.is_available(), "needs CUDA")
690
+ def test_baby_nccl_apis(self) -> None:
691
+ # use device 1 if there are 2 or more GPUs, otherwise device 0
692
+ device_id = 1 % torch.cuda.device_count()
693
+ torch.cuda.set_device(device_id)
694
+
695
+ store = TCPStore(
696
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
697
+ )
698
+
699
+ store_addr = f"localhost:{store.port}/prefix"
700
+
701
+ a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
702
+ try:
703
+ a.configure(store_addr, "0", 0, 1)
704
+
705
+ _test_pg(a, torch.randn((2, 3), device="cuda"))
706
+
707
+ torch.cuda.synchronize()
708
+
709
+ # force collection to ensure no BabyWork objects remain
710
+ gc.collect()
711
+
712
+ self.assertEqual(a.num_active_work(), 0)
713
+ finally:
714
+ a.shutdown()
715
+ torch.cuda.synchronize()
716
+ torch.cuda.empty_cache()
717
+
718
+ t = torch.zeros(10)
719
+ with self.assertRaisesRegex(OSError, "handle is closed"):
720
+ a.allreduce([t], AllreduceOptions()).wait()
721
+
722
+ def test_dummy(self) -> None:
723
+ pg = ProcessGroupDummy(0, 1)
724
+ m = nn.Linear(3, 4)
725
+ m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
726
+ m(torch.rand(2, 3))
727
+
728
+ def test_functional_collectives(self) -> None:
729
+ dummy_init_pg()
730
+
731
+ store = TCPStore(
732
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
733
+ )
734
+ store_addr = f"localhost:{store.port}/prefix"
735
+
736
+ pg = ProcessGroupGloo().register("test_func_col")
737
+ pg.configure(store_addr, "0", 0, 1)
738
+
739
+ self.assertEqual(pg.group_name, str(dist.get_pg_count() - 1))
740
+
741
+ self.assertIs(
742
+ _resolve_process_group(pg.group_name), # pyre-ignore[6]: GroupName vs str
743
+ pg,
744
+ )
745
+
746
+ try:
747
+ t = torch.zeros(10)
748
+ _functional_collectives.all_reduce(t, "sum", pg).wait()
749
+ finally:
750
+ pg.unregister()
751
+
752
+ def test_process_group_wrapper(self) -> None:
753
+ pg = ProcessGroupDummy(0, 1)
754
+ wrapper = ProcessGroupWrapper(pg=pg)
755
+ self.assertIs(wrapper.parent, pg)
756
+
757
+ wrapper.configure("addr", "0", 0, 1)
758
+ self.assertEqual(pg.configure_count, 1)
759
+
760
+ self.assertEqual(repr(wrapper), "ProcessGroupWrapper(pg=ProcessGroupDummy())")
761
+
762
+ def test_error_swallowing_process_group_wrapper(self) -> None:
763
+ pg = ProcessGroupDummy(0, 1)
764
+ wrapper = ErrorSwallowingProcessGroupWrapper(pg)
765
+ self.assertIs(wrapper.parent, pg)
766
+
767
+ works = _test_pg(wrapper)
768
+ self.assertIsInstance(list(works.values())[0], _ErrorSwallowingWork)
769
+
770
+ err = RuntimeError("test")
771
+ wrapper.report_error(err)
772
+ self.assertEqual(wrapper.error(), err)
773
+
774
+ works = _test_pg(wrapper)
775
+ for work in works.values():
776
+ self.assertIsInstance(work, _DummyWork)
777
+
778
+ def test_managed_process_group(self) -> None:
779
+ manager = Mock(spec=Manager)
780
+ manager.errored.return_value = None
781
+ manager._pg = ProcessGroupDummy(0, 1)
782
+ pg = ManagedProcessGroup(manager)
783
+ manager.num_participants.return_value = 123
784
+
785
+ self.assertEqual(pg.size(), 123)
786
+
787
+ works = _test_pg(pg)
788
+
789
+ self.assertEqual(manager.allreduce.call_count, 2)
790
+
791
+
792
+ class MultiPgBaseTest(TestCase):
793
+ """
794
+ A base test that creates WORLD_SIZE worker threads (via ThreadPoolExecutor), each
795
+ configuring its own ProcessGroup. Each test_* method reuses the same pool of PGs.
796
+
797
+ Subclasses can specify:
798
+ - BACKEND: the backend to use for the ProcessGroup ("gloo" or "nccl")
799
+ - WORLD_SIZE: how many ranks to simulate
800
+ - Additional config for the PG, e.g. timeouts.
801
+ """
802
+
803
+ BACKEND = "gloo"
804
+ WORLD_SIZE = 2
805
+
806
+ @classmethod
807
+ def setUpClass(cls) -> None:
808
+ super().setUpClass()
809
+
810
+ cls.store = TCPStore(
811
+ host_name="localhost", port=0, is_master=True, wait_for_workers=False
812
+ )
813
+ cls.store_addr = f"localhost:{cls.store.port}/prefix"
814
+
815
+ cls.pg_pool: List[ProcessGroup] = []
816
+
817
+ cls.executor = ThreadPoolExecutor(max_workers=cls.WORLD_SIZE)
818
+
819
+ def init_pg(rank: int) -> ProcessGroup:
820
+ if torch.accelerator.is_available():
821
+ torch.accelerator.set_device_idx(rank)
822
+ pg = cls._create_pg(cls.BACKEND)
823
+ pg.configure(cls.store_addr, "0", rank, cls.WORLD_SIZE)
824
+ return pg
825
+
826
+ futures = [cls.executor.submit(init_pg, rank) for rank in range(cls.WORLD_SIZE)]
827
+ cls.pg_pool = [future.result() for future in futures]
828
+
829
+ @classmethod
830
+ def tearDownClass(cls) -> None:
831
+ # Cleanup
832
+ for pg in cls.pg_pool:
833
+ shutdown = getattr(pg, "shutdown", None)
834
+ if shutdown is not None:
835
+ shutdown()
836
+ cls.executor.shutdown(wait=True)
837
+ super().tearDownClass()
838
+
839
+ @classmethod
840
+ def _create_pg(cls, backend: str) -> ProcessGroup:
841
+ """
842
+ Helper that creates a new ProcessGroup of the specified type.
843
+
844
+ Both plain NCCL groups and BabyNCCL groups (which spin up their own
845
+ subprocesses) are exercised by the MultiPgBaseTest subclasses below.
846
+ """
847
+ if backend == "gloo":
848
+ return ProcessGroupGloo(timeout=timedelta(seconds=1))
849
+ elif backend == "baby_gloo":
850
+ return ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
851
+ elif backend == "nccl":
852
+ return ProcessGroupNCCL(timeout=timedelta(seconds=10))
853
+ elif backend == "baby_nccl":
854
+ return ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
855
+ elif backend == "dummy":
856
+ return ProcessGroupDummy(0, 1)
857
+ else:
858
+ raise NotImplementedError(f"Unsupported backend: {backend}")
859
+
860
+ def _run_parallel(self, collective: str, device: str = "cpu") -> None:
861
+ """
862
+ Helper to run on all ranks in parallel, returning a list
863
+ of results or raising an exception if any fail.
864
+ """
865
+ func = _COLLECTIVE_TO_FUNC[collective]
866
+
867
+ futures = []
868
+ for rank in range(self.WORLD_SIZE):
869
+ pg = self.pg_pool[rank]
870
+ # Each worker calls func(pg, rank, tensor) with its own rank-specific tensor
871
+ if "cuda" in device:
872
+ device = f"cuda:{rank}"
873
+ tensor = torch.tensor([rank + 1], device=device)
874
+
875
+ fut = self.executor.submit(func, pg, rank, tensor)
876
+ futures.append(fut)
877
+
878
+ self._collect(futures)
879
+
880
+ def _collect(self, futs: list[Future]) -> None:
881
+ for i, f in enumerate(futs):
882
+ try:
883
+ res = f.result() # timeout=10)
884
+ if res:
885
+ print(f"Rank {i}: {res}")
886
+ except Exception as e:
887
+ print(f"Rank {i}: {e}")
888
+ raise
889
+
890
+ def _run_with_resiliency(self, collective: str, device: str = "cpu") -> None:
891
+ """
892
+ Run a collective with resiliency:
893
+ - fault_rank (last rank) simulates a crash.
894
+ - surviving ranks detect the error, then reconfigure PG to exclude fault_rank.
895
+ - surviving ranks run the same collective again successfully.
896
+ """
897
+
898
+ def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
899
+ pg.set_timeout(timedelta(seconds=30))
900
+
901
+ if dev == "cuda":
902
+ torch.cuda.set_device(rank)
903
+ # Use a separate stream to avoid deadlocks between threads.
904
+ torch.cuda.set_stream(torch.cuda.Stream())
905
+
906
+ fault_rank = self.WORLD_SIZE - 1
907
+ test = _COLLECTIVE_TO_FUNC[collective]
908
+
909
+ # Re-configure the PG on a fresh store prefix before injecting the failure
910
+ new_store_addr = f"localhost:{self.store.port}/reconfig_{collective}"
911
+
912
+ pg.configure(new_store_addr, "0", rank, self.WORLD_SIZE)
913
+
914
+ # run the collective successfully with all ranks participating
915
+ t2 = torch.tensor([rank + 1], device=dev)
916
+ test(pg, rank, t2)
917
+
918
+ # Simulate a failure
919
+
920
+ t1 = torch.tensor([rank + 1], device=dev)
921
+ # Simulate failure on the fault rank, but other ranks should still succeed.
922
+ if rank == fault_rank:
923
+ pg.shutdown()
924
+ return f"Rank{rank} crashed"
925
+
926
+ pg.set_timeout(timedelta(seconds=1))
927
+
928
+ # We hardcode the list of expected errors.
929
+ # gloo: Connection closed by peer, timed out waiting, no error, read error
930
+ # nccl: Tensor-likes are not equal/not close (due to abort)
931
+ with self.assertRaisesRegex(
932
+ Exception,
933
+ r"(Connection closed by peer|timed out after|Timed out waiting|no error|Read error|not equal|not close|process group not initialized)",
934
+ ):
935
+ test(pg, rank, t1.clone())
936
+ raise RuntimeError("no error")
937
+
938
+ if err := pg.errored():
939
+ with self.assertRaisesRegex(RuntimeError, "aborted"):
940
+ raise err
941
+
942
+ return f"Rank{rank} final success."
943
+
944
+ # run in parallel
945
+ futs = [
946
+ self.executor.submit(worker, self.pg_pool[r], r, device)
947
+ for r in range(self.WORLD_SIZE)
948
+ ]
949
+ self._collect(futs)
950
+
951
+
952
+ class NormalGlooMultiPgTest(MultiPgBaseTest):
953
+ BACKEND = "gloo"
954
+ WORLD_SIZE = 3
955
+ SKIP = [
956
+ "alltoall_base",
957
+ "reduce_scatter",
958
+ "reduce_scatter_tensor_coalesced",
959
+ ]
960
+ COLLECTIVES: List[str] = list(set(_ALL_COLLECTIVES) - set(SKIP))
961
+
962
+ @parameterized.expand(COLLECTIVES)
963
+ def test_collective(self, collective: str) -> None:
964
+ self._run_parallel(collective, device="cpu")
965
+
966
+ # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
967
+ @skipUnless(
968
+ torch.__version__ >= "2.7",
969
+ "torch 2.6 has a bug with destructing PyWork objects",
970
+ )
971
+ @parameterized.expand(COLLECTIVES)
972
+ def test_collective_with_resiliency(self, collective: str) -> None:
973
+ self._run_with_resiliency(collective, device="cpu")
974
+
975
+
976
+ @skipIf(sys.platform == "darwin", "not reliable on mac")
977
+ class BabyGlooMultiPgTest(MultiPgBaseTest):
978
+ BACKEND = "baby_gloo"
979
+ WORLD_SIZE = 3
980
+ SKIP = [
981
+ "alltoall_base",
982
+ "reduce_scatter",
983
+ "reduce_scatter_tensor_coalesced",
984
+ ]
985
+ COLLECTIVES: List[str] = list(set(_ALL_COLLECTIVES) - set(SKIP))
986
+
987
+ @parameterized.expand(COLLECTIVES)
988
+ def test_collective(self, collective: str) -> None:
989
+ self._run_parallel(collective, device="cpu")
990
+
991
+ @parameterized.expand(COLLECTIVES)
992
+ def test_collective_with_resiliency(self, collective: str) -> None:
993
+ self._run_with_resiliency(collective, device="cpu")
994
+
995
+
996
+ @skipUnless(
997
+ torch.cuda.is_available() and torch.cuda.device_count() >= 2, "needs 2 CUDA devices"
998
+ )
999
+ class BabyNcclMultiPgTest(MultiPgBaseTest):
1000
+ BACKEND = "baby_nccl"
1001
+ WORLD_SIZE = 2
1002
+
1003
+ @parameterized.expand(_ALL_COLLECTIVES)
1004
+ def test_collective(self, collective: str) -> None:
1005
+ self._run_parallel(collective, device="cuda")
1006
+
1007
+ # @parameterized.expand(_ALL_COLLECTIVES)
1008
+ # def test_collective_with_resiliency(self, collective: str) -> None:
1009
+ # self._run_with_resiliency(collective, device="cuda")
1010
+
1011
+
1012
+ @skipUnless(
1013
+ torch.cuda.is_available()
1014
+ and torch.cuda.device_count() >= 2
1015
+ and torch.cuda.nccl.version() >= (2, 25),
1016
+ "needs 2 CUDA devices and NCCL >=2.25",
1017
+ )
1018
+ class NormalNcclMultiPgTest(MultiPgBaseTest):
1019
+ BACKEND = "nccl"
1020
+ WORLD_SIZE = 2
1021
+
1022
+ @parameterized.expand(_ALL_COLLECTIVES)
1023
+ def test_collective(self, collective: str) -> None:
1024
+ self._run_parallel(collective, device="cuda")
1025
+
1026
+ @parameterized.expand(_ALL_COLLECTIVES)
1027
+ def test_collective_with_resiliency(self, collective: str) -> None:
1028
+ self._run_with_resiliency(collective, device="cuda")
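Since the classes above are plain unittest.TestCase subclasses, a single backend can be exercised in isolation once the wheel is installed. A minimal sketch, assuming a CPU-only host (editorial example, not part of the diff):

import unittest

from torchft import process_group_test

# collect and run only the plain-Gloo multi-process-group tests from this module
suite = unittest.defaultTestLoader.loadTestsFromTestCase(process_group_test.NormalGlooMultiPgTest)
unittest.TextTestRunner(verbosity=2).run(suite)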