torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/collectives_test.py ADDED
@@ -0,0 +1,212 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import unittest
+ from typing import Callable
+ from unittest import skipUnless, TestCase
+
+ import torch
+ import torch.distributed as dist
+ from parameterized import parameterized
+ from torch import cuda
+ from torch.distributed import ReduceOp, ReduceScatterOptions
+
+ from torchft import _test_utils
+ from torchft.process_group import ProcessGroup
+ from torchft.process_group_test import MultiPgBaseTest
+
+ try:
+     # pyre-ignore[21]: Could not find a module corresponding to import `triton`
+     import triton
+ except ImportError:
+     pass
+ else:
+     from torchft.collectives import (
+         allocate_reduce_scatter_output,
+         allreduce_quantized,
+         get_padded_sizes,
+         reduce_scatter_quantized,
+     )
+
+     def _check_result_tolerance(
+         actual: torch.Tensor, expected: torch.Tensor, tolerance: float
+     ) -> None:
+         diff = torch.abs(
+             (expected - actual).div(expected.to(torch.float32) + 0.0000001)
+         )
+         mean_diff = diff.mean().item()
+
+         if mean_diff > tolerance:
+             print(f"Diff: {diff=}\n{expected=}\n{actual=}")
+             raise AssertionError(f"Results not within tolerance {tolerance}")
+
+     @skipUnless(
+         torch.cuda.is_available() and torch.cuda.device_count() >= 2,
+         "2 CUDA devices are required for this test",
+     )
+     class QuantizedAllReduceTest(MultiPgBaseTest):
+         BACKEND = "nccl"
+         WORLD_SIZE = 2
+
+         def _run_parallel_collectives(
+             self, collective: Callable[[ProcessGroup, int, str], None]
+         ) -> None:
+             futures = []
+             for rank in range(self.WORLD_SIZE):
+                 pg = self.pg_pool[rank]
+                 device = f"cuda:{rank}"
+                 fut = self.executor.submit(collective, pg, rank, device)
+                 futures.append(fut)
+
+             self._collect(futures)
+
+         def _run_all_reduce_collective(
+             self,
+             pg: ProcessGroup,
+             device: str,
+             tensors_num: int,
+             tensor_size: int,
+             multiplier: float,
+             tolerance: float,
+             reduce_op: ReduceOp,
+             dtype: torch.dtype,
+         ) -> None:
+             cuda.set_device(device)
+             inp = (
+                 torch.rand(
+                     tensors_num * tensor_size,
+                     dtype=dtype,
+                     device=device,
+                 )
+                 * multiplier
+             )
+             for split in _test_utils.gen_splits(inp, tensor_size):
+                 actual = inp.clone()
+                 expected = inp.clone()
+                 tensors = [
+                     i.view(*s)
+                     for s, i in zip(
+                         split,
+                         torch.split(actual, tensor_size),
+                     )
+                 ]
+
+                 work = allreduce_quantized(tensors, reduce_op, pg)
+                 work.wait()
+
+                 work = pg.allreduce([expected], reduce_op)
+                 work.get_future().wait()
+
+                 _check_result_tolerance(actual, expected, tolerance)
+
+         def _run_reduce_scatter_collective(
+             self,
+             pg: ProcessGroup,
+             device: str,
+             tensors_num: int,
+             tensor_size: int,
+             multiplier: float,
+             tolerance: float,
+             reduce_op: ReduceOp,
+             dtype: torch.dtype,
+         ) -> None:
+             cuda.set_device(device)
+             inp = (
+                 torch.rand(
+                     tensors_num * tensor_size,
+                     dtype=dtype,
+                     device=device,
+                 )
+                 * multiplier
+             )
+             world_size = pg.size()
+             for split in _test_utils.gen_splits(inp, tensor_size):
+                 actual = inp.clone()
+                 tensors = [
+                     i.view(*s)
+                     for s, i in zip(
+                         split,
+                         torch.split(actual, tensor_size),
+                     )
+                 ]
+
+                 actual_output, _ = allocate_reduce_scatter_output(
+                     tensors,
+                     world_size,
+                 )
+
+                 opts = ReduceScatterOptions()
+                 opts.reduceOp = reduce_op
+
+                 work = reduce_scatter_quantized(actual_output, tensors, opts, pg)
+                 work.get_future().wait()
+
+                 padded_sizes = get_padded_sizes(tensors, world_size)
+                 padded_numel = sum(s.numel() for s in padded_sizes)
+
+                 padded_input = torch.empty(padded_numel, dtype=dtype, device=device)
+                 torch._chunk_cat(
+                     tensors, dim=0, num_chunks=world_size, out=padded_input
+                 )
+
+                 expected_output = torch.empty(
+                     padded_numel // world_size, dtype=dtype, device=device
+                 )
+
+                 work = pg.reduce_scatter([expected_output], [[padded_input]], opts)
+                 work.get_future().wait()
+
+                 _check_result_tolerance(actual_output, expected_output, tolerance)
+
+         END_TO_END_CONFIGS: list[tuple[int, float, ReduceOp, torch.dtype]] = [
+             (ts, m, o, t)
+             for ts in [256, 1024, 2048]
+             for m in [1.0, 100.0, 1000.0]
+             for o in [ReduceOp.AVG, ReduceOp.SUM]
+             for t in [torch.float32, torch.float16, torch.bfloat16]
+         ]
+
+         @parameterized.expand(END_TO_END_CONFIGS)
+         def test_all_reduce_collective(
+             self,
+             tensor_size: int,
+             multiplier: float,
+             reduce_op: ReduceOp,
+             dtype: torch.dtype,
+         ) -> None:
+             self._run_parallel_collectives(
+                 lambda pg, _, device: self._run_all_reduce_collective(
+                     pg,
+                     device,
+                     2,
+                     tensor_size,
+                     multiplier,
+                     0.04,
+                     reduce_op,
+                     dtype,
+                 )
+             )
+
+         @parameterized.expand(END_TO_END_CONFIGS)
+         def test_reduce_scatter_collective(
+             self,
+             tensor_size: int,
+             multiplier: float,
+             reduce_op: ReduceOp,
+             dtype: torch.dtype,
+         ) -> None:
+             self._run_parallel_collectives(
+                 lambda pg, _, device: self._run_reduce_scatter_collective(
+                     pg,
+                     device,
+                     2,
+                     tensor_size,
+                     multiplier,
+                     0.05,
+                     reduce_op,
+                     dtype,
+                 )
+             )
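Note on the tests above: because the quantized collectives are lossy, results are compared against the exact process-group collectives by mean relative error rather than exact equality. A minimal standalone sketch of that check (hypothetical helper name `mean_relative_error`, CPU-only, no process group needed), mirroring `_check_result_tolerance`:

import torch

def mean_relative_error(actual: torch.Tensor, expected: torch.Tensor) -> float:
    # A small epsilon keeps near-zero expected values from blowing up the ratio,
    # as in _check_result_tolerance above.
    eps = 1e-7
    diff = torch.abs((expected - actual) / (expected.to(torch.float32) + eps))
    return diff.mean().item()

actual = torch.tensor([0.99, 2.03, 2.97])
expected = torch.tensor([1.0, 2.0, 3.0])
assert mean_relative_error(actual, expected) < 0.04  # same bound the all-reduce tests use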
torchft/coordination.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Coordination (Low Level API)
+ ============================
+
+ .. warning::
+     As torchft is still in development, the APIs in this module are subject to change.
+
+ This module exposes low level coordination APIs to allow you to build your own
+ custom fault tolerance algorithms on top of torchft.
+
+ If you're looking for a more complete solution, please use the other modules in
+ torchft.
+
+ This provides direct access to the Lighthouse and Manager servers and clients.
+ """
+
+ from torchft._torchft import (
+     LighthouseClient,
+     LighthouseServer,
+     ManagerClient,
+     ManagerServer,
+     Quorum,
+     QuorumMember,
+ )
+
+ __all__ = [
+     "LighthouseClient",
+     "LighthouseServer",
+     "ManagerServer",
+     "ManagerClient",
+     "Quorum",
+     "QuorumMember",
+ ]
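For orientation, a hedged sketch of the simplest use of these low-level APIs: standing up a standalone Lighthouse. This is based on the torchft documentation rather than on this diff; the bind address and `min_replicas` value are illustrative.

from torchft.coordination import LighthouseServer

# Bind to an ephemeral port and require a single replica group for quorum.
lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)
print(lighthouse.address())  # replica groups point TORCHFT_LIGHTHOUSE at this address
lighthouse.shutdown()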
torchft/coordination_test.py ADDED
@@ -0,0 +1,29 @@
+ import inspect
+ from unittest import TestCase
+
+ from torchft.coordination import (
+     LighthouseClient,
+     LighthouseServer,
+     ManagerClient,
+     ManagerServer,
+     Quorum,
+     QuorumMember,
+ )
+
+
+ class TestCoordination(TestCase):
+     def test_coordination_docs(self) -> None:
+         classes = [
+             ManagerClient,
+             ManagerServer,
+             LighthouseServer,
+             LighthouseClient,
+             Quorum,
+             QuorumMember,
+         ]
+         for cls in classes:
+             self.assertIn("Args:", str(cls.__doc__), cls)
+             for name, method in inspect.getmembers(cls, predicate=inspect.ismethod):
+                 if name.startswith("_"):
+                     continue
+                 self.assertIn("Args:", str(method.__doc__), cls)
torchft/data.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Data
+ ====
+
+ This module provides helper classes to implement fault tolerant data loaders.
+
+ We recommend using torchdata's StatefulDataLoader to checkpoint each replica's
+ dataloader frequently to avoid duplicate batches.
+ """
+
+ from typing import Optional
+
+ import torch.distributed as dist
+ from torch.utils import data
+
+
+ # pyre-fixme[24]: expected generic parameter
+ class DistributedSampler(data.distributed.DistributedSampler):
+     """
+     DistributedSampler extends the standard PyTorch DistributedSampler with a
+     `num_replica_groups` argument that is used to shard the data across the fault
+     tolerance replica groups.
+
+     torchft doesn't know the number of replica groups ahead of time, so this must
+     be set to the maximum number of groups.
+
+     This sampler is inherently lossy when used with torchft. torchft
+     occasionally drops batches on rejoining, and if a replica group is down, that
+     group's examples will never be used. This can lead to imbalances if using a
+     small dataset.
+
+     This will shard the input dataset into ``num_replicas*num_replica_groups``
+     number of shards.
+
+     Each shard rank is calculated via: ``rank + num_replicas*replica_rank``
+
+     num_replicas and replica_rank must be the same on all workers.
+     """
+
+     def __init__(
+         self,
+         dataset: data.Dataset,
+         replica_rank: int,
+         num_replica_groups: int,
+         group_rank: Optional[int] = None,
+         num_replicas: Optional[int] = None,
+         **kwargs: object,
+     ) -> None:
+         """
+         Args:
+             dataset: the dataset to use
+             replica_rank: the group ID (0-num_replica_groups) to use for this shard of data.
+             num_replica_groups: the max number of global replica groups
+             group_rank: the local group rank
+             num_replicas: the local group world size
+         """
+         if group_rank is None:
+             group_rank = dist.get_rank()
+         if num_replicas is None:
+             num_replicas = dist.get_world_size()
+
+         self.global_rank: int = group_rank + num_replicas * replica_rank
+         self.global_world_size: int = num_replicas * num_replica_groups
+
+         super().__init__(
+             dataset,
+             rank=self.global_rank,
+             num_replicas=self.global_world_size,
+             # pyre-fixme[6]: got object
+             **kwargs,
+         )
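A hedged usage sketch of the sharding described in the docstring above (assumed training-script context, not part of this package; the REPLICA_GROUP_ID and NUM_REPLICA_GROUPS environment variables are illustrative). Each process ends up reading shard `group_rank + num_replicas * replica_rank` out of `num_replicas * num_replica_groups`:

import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from torchft.data import DistributedSampler

dataset = TensorDataset(torch.arange(1000))

# One process per replica group in this sketch, so the local group rank is 0
# and the local group world size is 1.
sampler = DistributedSampler(
    dataset,
    replica_rank=int(os.environ.get("REPLICA_GROUP_ID", "0")),
    num_replica_groups=int(os.environ.get("NUM_REPLICA_GROUPS", "2")),
    group_rank=0,
    num_replicas=1,
)
loader = DataLoader(dataset, sampler=sampler, batch_size=32)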
torchft/data_test.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from unittest import TestCase
+
+ from torch.utils.data import Dataset
+
+ from torchft.data import DistributedSampler
+
+
+ class DummyDataset(Dataset):
+     def __init__(self, length: int) -> None:
+         self.length = length
+
+     def __len__(self) -> int:
+         return self.length
+
+     def __getitem__(self, idx: int) -> int:
+         return idx
+
+
+ class TestData(TestCase):
+     def test_distributed_sampler(self) -> None:
+         dataset = DummyDataset(1000)
+         sampler = DistributedSampler(
+             dataset,
+             replica_rank=1,
+             num_replica_groups=2,
+             group_rank=3,
+             num_replicas=4,
+         )
+         self.assertEqual(sampler.global_rank, 3 + 1 * 4)
+         self.assertEqual(sampler.global_world_size, 2 * 4)
+
+         sampler_iter = iter(sampler)
+         self.assertEqual(next(sampler_iter), 500)
torchft/ddp.py ADDED
@@ -0,0 +1,105 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Distributed Data Parallel
+ ==========================
+
+ This module implements a DistributedDataParallel wrapper that works with the
+ Manager to provide fault tolerance.
+ """
+
+ import os
+ import sys
+ from typing import cast, Optional, TYPE_CHECKING
+ from unittest.mock import patch
+
+ import torch
+ import torch.distributed as dist
+ from torch import nn
+ from torch.distributed.algorithms.join import Joinable
+ from torch.nn import parallel
+
+ from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo
+
+ if TYPE_CHECKING:
+     from torchft.manager import _ManagedFuture, Manager
+
+
+ class DistributedDataParallel(parallel.DistributedDataParallel):
+     """
+     This is a patched DistributedDataParallel implementation that makes it
+     compatible with torchft.
+
+     Important notes:
+
+     * This requires states to be synced on step 0 using an external mechanism
+       rather than an internal broadcast (torchft.Manager will do this).
+     * Using non-basic features of the DDP may cause your model to catch fire as
+       they haven't been tested with torchft.
+     * This doesn't do any sanity checks such as verifying parameter sizes are the
+       same across workers.
+     """
+
+     def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
+         # use a dummy PG to soak up the init all reduce; actual comms will go
+         # through the comm_hook.
+         pg = ProcessGroupDummy(0, 1)
+
+         super().__init__(
+             module,
+             process_group=pg,
+             # HACK: This forces the reducer to never rebuild buckets.
+             # The reducer normally rebuilds the buckets after the first training
+             # step which can improve performance but is incompatible with
+             # torchft as it will cause the buckets to diverge for recovering
+             # replicas.
+             find_unused_parameters=True,
+             # pyre-fixme[6]: got object
+             **kwargs,
+         )
+
+         self.register_comm_hook(manager, self._comm_hook)
+
+     @staticmethod
+     def _comm_hook(
+         state: "Manager", bucket: dist.GradBucket
+     ) -> torch.futures.Future[torch.Tensor]:
+         work = state.allreduce(bucket.buffer())
+         work.wait()
+         fut = work.get_future()
+
+         # We need to return the underlying future here otherwise
+         # this can hang
+         fut = cast("_ManagedFuture[torch.Tensor]", fut)
+         assert fut._fut
+         return fut._fut
+
+
+ class PureDistributedDataParallel(nn.Module):
+     """
+     A pure Python reimplementation of the DDP wrapper.
+
+     We recommend using DistributedDataParallel instead of this class.
+
+     This calls one allreduce per gradient tensor and doesn't use a reducer. This
+     may be very slow for real models.
+     """
+
+     def __init__(self, manager: "Manager", module: nn.Module) -> None:
+         super().__init__()
+
+         self.module = module
+
+         def post_grad_hook(p: torch.Tensor) -> None:
+             if p.grad is not None:
+                 manager.allreduce(p.grad)
+
+         for p in module.parameters():
+             p.register_post_accumulate_grad_hook(post_grad_hook)
+
+     def forward(self, *args: object) -> object:
+         return self.module(*args)
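For context on how the comm hook above is meant to be wired up, a hedged end-to-end sketch in the style of the torchft README (the `Manager` and `Optimizer` arguments shown are assumptions taken from that README, not from this diff; it also assumes a Lighthouse is running and the usual torchft environment variables such as TORCHFT_LIGHTHOUSE are set by the launcher):

import torch
from torch import nn, optim

from torchft import DistributedDataParallel, Manager, Optimizer, ProcessGroupGloo

m = nn.Linear(3, 4)
opt = optim.AdamW(m.parameters())

# The Manager calls these to save/restore state when a replica recovers.
def state_dict() -> dict:
    return {"model": m.state_dict(), "optim": opt.state_dict()}

def load_state_dict(sd: dict) -> None:
    m.load_state_dict(sd["model"])
    opt.load_state_dict(sd["optim"])

manager = Manager(
    pg=ProcessGroupGloo(),
    min_replica_size=1,
    load_state_dict=load_state_dict,
    state_dict=state_dict,
)

m = DistributedDataParallel(manager, m)  # gradient buckets flow through manager.allreduce
opt = Optimizer(manager, opt)            # wraps step() with quorum/commit logic

for batch in [torch.rand(8, 3) for _ in range(3)]:
    opt.zero_grad()
    m(batch).mean().backward()
    opt.step()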
torchft/ddp_test.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from unittest import TestCase
+ from unittest.mock import create_autospec
+
+ import torch
+ import torch.distributed as dist
+ from torch import nn
+ from torch.distributed.distributed_c10d import Work
+ from torch.futures import Future
+
+ from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
+ from torchft.manager import _ManagedWork, Manager
+ from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+ from torchft.work import _DummyWork
+
+
+ class TestDDP(TestCase):
+     def test_pure_ddp(self) -> None:
+         manager = create_autospec(Manager)
+
+         m = nn.Linear(3, 4)
+         m = PureDistributedDataParallel(manager, m)
+
+         inp = torch.rand(2, 3)
+         out = m(inp)
+         loss = out.mean()
+         loss.backward()
+
+         for p in m.parameters():
+             self.assertIsNotNone(p.grad)
+
+         self.assertEqual(manager.allreduce.call_count, len(list(m.parameters())))
+
+     def test_ddp(self) -> None:
+         manager = create_autospec(Manager)
+
+         call_count = 0
+
+         # pyre-ignore[53]: Captured variable `manager` is not annotated.
+         def allreduce(
+             tensor: torch.Tensor,
+         ) -> Work:
+             nonlocal call_count
+
+             call_count += 1
+
+             work = _DummyWork(tensor)
+             return _ManagedWork(manager, work, tensor)
+
+         manager.allreduce = allreduce
+
+         m = nn.Linear(3, 4)
+         m = DistributedDataParallel(manager, m)
+
+         inp = torch.rand(2, 3)
+         out = m(inp)
+         loss = out.mean()
+         loss.backward()
+
+         for p in m.parameters():
+             self.assertIsNotNone(p.grad)
+
+         self.assertGreaterEqual(call_count, 1)