torchft_nightly-2026.1.3-cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/quantization_test.py ADDED
@@ -0,0 +1,131 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from unittest import skipUnless, TestCase
+
+ import torch
+ from parameterized import parameterized
+ from torch.distributed import ReduceOp
+
+ from torchft import _test_utils
+
+ torch.set_printoptions(precision=4, sci_mode=False)
+
+ DEVICE = "cuda"
+
+ try:
+     # pyre-fixme[21]: Could not find a module corresponding to import `triton`
+     import triton
+ except ImportError:
+     pass
+ else:
+     from torchft.quantization import (
+         fused_dequantize_from_fp8,
+         fused_quantize_into_fp8,
+         fused_reduce_fp8,
+     )
+
+     @skipUnless(
+         torch.cuda.is_available(),
+         "CUDA is required for this test",
+     )
+     class QuantizationTest(TestCase):
+         def run_test(
+             self,
+             world_size: int,
+             tensors_num: int,
+             tensor_size: int,
+             multiplier: float,
+             tolerance: float,
+             reduce_op: ReduceOp,
+             type: torch.dtype,
+         ) -> None:
+             inp = (
+                 torch.rand(
+                     tensors_num * tensor_size,
+                     dtype=type,
+                     device="cuda",
+                 )
+                 * multiplier
+             )
+
+             for split in _test_utils.gen_splits(inp, tensor_size):
+                 inputs = inp.clone()
+                 outputs = torch.empty_like(inputs)
+
+                 reshaped_inputs = []
+                 reshaped_outputs = []
+                 for s, i, o in zip(
+                     split,
+                     torch.split(inputs, tensor_size),
+                     torch.split(outputs, tensor_size),
+                 ):
+                     reshaped_inputs.append(i.view(*s))
+                     reshaped_outputs.append(o.view(*s))
+
+                 quant = fused_quantize_into_fp8(reshaped_inputs, world_size)
+                 quant_slices = torch.split(quant, quant.numel() // world_size)
+
+                 quant_final = torch.empty_like(quant)
+                 quant_final_slices = torch.split(
+                     quant_final, quant_final.numel() // world_size
+                 )
+
+                 for rank in range(world_size):
+                     r = rank % world_size
+                     quant_copy = torch.empty_like(quant)
+                     quant_copy_slices = torch.split(
+                         quant_copy, quant_copy.numel() // world_size
+                     )
+                     for other in range(world_size):
+                         quant_copy_slices[other].copy_(quant_slices[r])
+
+                     fused_reduce_fp8(
+                         reshaped_inputs, quant_copy, world_size, r, reduce_op
+                     )
+
+                     quant_final_slices[r].copy_(quant_copy_slices[r])
+
+                 fused_dequantize_from_fp8(reshaped_outputs, quant_final, world_size)
+
+                 self.assertFalse(_test_utils.any_nan(reshaped_outputs))
+
+                 if reduce_op == ReduceOp.SUM:
+                     inputs.mul_(world_size)
+
+                 diff = torch.abs(
+                     (inputs - outputs).div(inputs.to(torch.float32) + 0.0000001)
+                 )
+                 mean_diff = diff.mean().item()
+                 self.assertLessEqual(
+                     mean_diff, tolerance, f"Results not within tolerance {tolerance}"
+                 )
+
+         END_TO_END_CONFIGS: list[tuple[int, float, ReduceOp, torch.dtype]] = [
+             (ts, m, o, t)
+             for ts in [128, 512, 4096]
+             for m in [1.0, 100.0, 1000.0]
+             for o in [ReduceOp.AVG, ReduceOp.SUM]
+             for t in [torch.float32, torch.float16, torch.bfloat16]
+         ]
+
+         @parameterized.expand(END_TO_END_CONFIGS)
+         def test_end_to_end(
+             self,
+             tensor_size: int,
+             multiplier: float,
+             reduce_op: ReduceOp,
+             type: torch.dtype,
+         ) -> None:
+             self.run_test(
+                 world_size=2,
+                 tensors_num=3,
+                 tensor_size=tensor_size,
+                 multiplier=multiplier,
+                 tolerance=0.05,
+                 reduce_op=reduce_op,
+                 type=type,
+             )
torchft/torchx.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ This is a file for TorchX components used for testing torchft.
+ """
+
+ import os
+ from typing import Dict, Optional
+
+ import torchx.specs as specs
+
+
+ def hsdp(
+     *script_args: str,
+     replicas: int = 2,
+     workers_per_replica: int = 1,
+     max_restarts: int = 10,
+     script: str = "train_ddp.py",
+     env: Optional[Dict[str, str]] = None,
+     image: str = "",
+     h: Optional[str] = None,
+     cpu: int = 2,
+     gpu: int = 0,
+     memMB: int = 1024,
+ ) -> specs.AppDef:
+     assert replicas > 0, "replicas must be > 0"
+     assert workers_per_replica > 0, "workers_per_replica must be > 0"
+
+     env = env or {}
+
+     # Enable logging for PyTorch, torchelastic and Rust.
+     env.setdefault("TORCH_CPP_LOG_LEVEL", "INFO")
+     env.setdefault("LOGLEVEL", "INFO")
+     env.setdefault("RUST_BACKTRACE", "1")
+
+     # Enable colored logging for the torchft Rust logger.
+     env.setdefault("CLICOLOR_FORCE", "1")
+
+     # Set the lighthouse address for the replicas.
+     # The lighthouse itself must be started externally.
+     env.setdefault(
+         "TORCHFT_LIGHTHOUSE",
+         os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510"),
+     )
+
+     # Disable CUDA for CPU-only jobs
+     env.setdefault("CUDA_VISIBLE_DEVICES", "")
+
+     # Disable XPU for CPU-only jobs
+     env.setdefault("XPU_VISIBLE_DEVICES", "")
+
+     roles = []
+     for replica_id in range(replicas):
+         cmd = [
+             f"--master_port={29600 + replica_id}",
+             "--nnodes=1",
+             f"--nproc_per_node={workers_per_replica}",
+             f"--max_restarts={max_restarts}",
+         ]
+         if script:
+             cmd += [script]
+         cmd += list(script_args)
+
+         roles.append(
+             specs.Role(
+                 name=f"replica_{replica_id}",
+                 image=image,
+                 min_replicas=workers_per_replica,
+                 num_replicas=workers_per_replica,
+                 resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
+                 max_retries=0,
+                 env={
+                     "REPLICA_GROUP_ID": str(replica_id),
+                     "NUM_REPLICA_GROUPS": str(replicas),
+                     **env,
+                 },
+                 entrypoint="torchrun",
+                 args=cmd,
+             )
+         )
+
+     return specs.AppDef(
+         name="torchft",
+         roles=roles,
+     )
torchft/utils.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Utility functions for TorchFT.
+ """
+
+ from contextlib import nullcontext
+ from typing import Any, Optional, Union
+
+ import torch
+
+
+ def get_stream_context(
+     stream: Optional[torch.Stream],
+ ) -> Union[torch.cuda.StreamContext, torch.xpu.StreamContext, nullcontext[None]]:
+     """
+     Get the appropriate stream context for the given stream.
+
+     This function provides a unified way to handle stream contexts across different
+     accelerator types (CUDA, XPU).
+
+     Args:
+         stream: The stream to create a context for. If None, returns nullcontext.
+
+     Returns:
+         The appropriate stream context for the accelerator type, or nullcontext
+         if stream is None or no accelerator is available.
+     """
+     if stream is not None:
+         if torch.cuda.is_available():
+             # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
+             return torch.cuda.stream(stream)
+         elif torch.xpu.is_available():
+             # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
+             return torch.xpu.stream(stream)
+         else:
+             return nullcontext()
+     else:
+         return nullcontext()
+
+
+ def record_event() -> None:
+     """
+     Record an event in the current stream.
+
+     This function provides a unified way to record events across different
+     accelerator types (CUDA, XPU).
+     """
+     if torch.xpu.is_available():
+         torch.xpu.current_stream().record_event(torch.xpu.Event())
+     else:
+         torch.cuda.current_stream().record_event(torch.cuda.Event(interprocess=True))
+
+
+ def synchronize() -> None:
+     """
+     This function provides a unified way to synchronize the current stream across
+     different accelerator types (CUDA, XPU).
+     """
+     if torch.cuda.is_available():
+         torch.cuda.current_stream().synchronize()
+     elif torch.xpu.is_available():
+         torch.xpu.current_stream().synchronize()
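`get_stream_context` lets callers scope work to a side stream without branching on the accelerator type. A minimal sketch, assuming a CUDA machine (with `stream=None` the helper degrades to a no-op `nullcontext`):

```py
import torch

from torchft.utils import get_stream_context, record_event, synchronize

# Hypothetical setup: a side stream on CUDA, or None on CPU-only machines.
stream = torch.cuda.Stream() if torch.cuda.is_available() else None

with get_stream_context(stream):
    # On CUDA, ops issued here are enqueued on `stream`; with stream=None
    # this block is a no-op context and work runs on the default stream.
    x = torch.ones(1024, device="cuda" if torch.cuda.is_available() else "cpu")
    y = x * 2
    if torch.cuda.is_available():
        record_event()  # records an event on the current (side) stream
        synchronize()   # blocks until the current stream's work finishes
```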
torchft/work.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from contextlib import nullcontext
+ from datetime import timedelta
+ from typing import Optional
+
+ import torch
+ import torch.distributed as dist
+
+
+ class _DummyWork(dist._Work):
+     def __init__(self, result: object) -> None:
+         super().__init__()
+         self.result_ = result
+         self.future_: torch.futures.Future[object] = torch.futures.Future()
+         self.future_.set_result(result)
+
+     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+         return True
+
+     def get_future(self) -> torch.futures.Future[object]:
+         return self.future_
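`_DummyWork` gives call sites that expect a `torch.distributed` `Work` handle an already-completed stand-in when the result is available locally. A small illustration of the contract (hypothetical usage, not from the package):

```py
import torch

from torchft.work import _DummyWork

result = [torch.zeros(4)]
work = _DummyWork(result)

assert work.wait()            # already complete, returns immediately
fut = work.get_future()
assert fut.done()             # the future is resolved at construction
assert fut.value() is result  # and carries the original result object
```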
torchft_nightly-2026.1.3.dist-info/METADATA ADDED
@@ -0,0 +1,308 @@
+ Metadata-Version: 2.4
+ Name: torchft-nightly
+ Version: 2026.1.3
+ Classifier: Programming Language :: Rust
+ Classifier: Programming Language :: Python :: Implementation :: CPython
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
+ Requires-Dist: torch>=2.7
+ Requires-Dist: opentelemetry-exporter-otlp-proto-http>=1.39.0
+ Requires-Dist: opentelemetry-sdk>=1.39.0
+ Requires-Dist: opentelemetry-api>=1.39.0
+ Requires-Dist: pytest==8.3.4 ; extra == 'dev'
+ Requires-Dist: pytest-timeout ; extra == 'dev'
+ Requires-Dist: parameterized ; extra == 'dev'
+ Requires-Dist: expecttest ; extra == 'dev'
+ Requires-Dist: numpy ; extra == 'dev'
+ Requires-Dist: torchx-nightly ; extra == 'dev'
+ Requires-Dist: lintrunner ; extra == 'dev'
+ Requires-Dist: lintrunner-adapters ; extra == 'dev'
+ Provides-Extra: dev
+ License-File: LICENSE
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
+ Project-URL: Documentation, https://docs.pytorch.org/torchft
+ Project-URL: Issues, https://github.com/pytorch/torchft/issues
+ Project-URL: Repository, https://github.com/pytorch/torchft
+
+ <p align="center">
+   <picture>
+     <source media="(prefers-color-scheme: dark)" srcset="./media/torchft_logo_dark.svg">
+     <img width="55%" src="./media/torchft_logo.svg" alt="torchft">
+   </picture>
+ </p>
+
+ <h3 align="center">
+ Easy Per Step Fault Tolerance for PyTorch
+ </h3>
+
+ <p align="center">
+ | <a href="https://pytorch.org/torchft/"><b>Documentation</b></a>
+ | <a href="https://github.com/pytorch/torchft/blob/main/media/fault_tolerance_poster.pdf"><b>Poster</b></a>
+ | <a href="https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit"><b>Design Doc</b></a>
+ |
+ </p>
+ <p align="center">
+ <a href="https://pypi.org/project/torchft-nightly/"><img alt="PyPI - Version" src="https://img.shields.io/pypi/v/torchft-nightly"></a>
+ </p>
+
+ ---
+
+ This repository implements techniques for per-step fault tolerance so you can
+ keep training even if errors occur, without interrupting the entire training
+ job.
+
+ [This is based on the large scale training techniques presented at PyTorch
+ Conference 2024.](./media/fault_tolerance_poster.pdf)
+
+ ## Overview
+
+ torchft is designed to provide the primitives required to implement fault
+ tolerance in any application or training script, as well as the primitives
+ needed to implement custom fault tolerance strategies.
+
+ Out of the box, torchft provides the following algorithms:
+
+ * Fault Tolerant DDP
+ * Fault Tolerant HSDP: fault tolerance across the replicated dimension with any mix of FSDP/TP/etc. across the other dimensions.
+ * LocalSGD
+ * DiLoCo
+
+ To implement these, torchft provides some key reusable components:
+
+ 1. Coordination primitives that can determine which workers are healthy via
+ heartbeating on a per-step basis
+ 2. Fault tolerant ProcessGroup implementations that report errors sanely and
+ can be reinitialized gracefully.
+ 3. Checkpoint transports that can be used to do live recovery from a healthy
+ peer during scale-up operations.
+
+ The following component diagram shows the high level components and how they
+ relate to each other:
+
+ ![Component Diagram](./media/overview.mmd.svg)
+
+ See [torchft's documentation](https://pytorch.org/torchft) for more details.
+
+ ## Examples
+
+ ### torchtitan (Fault Tolerant HSDP)
+
+ torchtitan provides an out-of-the-box fault tolerant HSDP training loop built
+ on top of torchft that can be used to train models such as Llama 3 70B.
+
+ It also serves as a good example of how you can integrate torchft into your own training script for use with HSDP.
+
+ See [torchtitan's documentation for end to end usage](https://github.com/pytorch/torchtitan/blob/main/docs/torchft.md).
+
+ ### Fault Tolerant DDP
+
+ We have a minimal DDP train loop that highlights all of the key components in torchft.
+
+ See [train_ddp.py](./train_ddp.py) for more info.
+
+
+ ### DiLoCo
+
+ LocalSGD and DiLoCo are currently experimental.
+
+ See
+ [the diloco_train_loop/local_sgd_train_loop tests](./torchft/local_sgd_integ_test.py)
+ for an example of how to integrate these algorithms into your training loop.
+
+
+ ## Design
+
+ torchft is designed to allow for fault tolerance when training with replicated weights, such as in DDP or HSDP (FSDP with DDP).
+
+ See the [design doc](https://docs.google.com/document/d/1OZsOsz34gRDSxYXiKkj4WqcD9x0lP9TcsfBeu_SsOY4/edit) for the most detailed explanation.
+
+ ### Lighthouse
+
+ torchft implements a lighthouse server that coordinates across the different
+ replica groups, plus a per-replica-group manager and fault tolerance library
+ that can be used in a standard PyTorch training loop.
+
+ This allows for membership changes at the training step granularity, which can
+ greatly improve efficiency by avoiding stop-the-world restarts on errors.
+
+ ![Lighthouse Diagram](./media/torchft-overview.png)
+
+ ### Fault Tolerant HSDP Algorithm
+
+ torchft provides an implementation of a fault tolerant HSDP/DDP algorithm. The
+ following diagram shows the high level operations that need to happen in the
+ train loop to ensure everything stays consistent during a healing operation.
+
+ ![HSDP Diagram](./media/hsdp_train_loop.png)
+
+ See the design doc linked above for more details.
+
+ ## Installing from PyPI
+
+ We have nightly builds available at https://pypi.org/project/torchft-nightly/
+
+ To install torchft with minimal dependencies you can run:
+
+ ```sh
+ pip install torchft-nightly
+ ```
+
+ If you want all development dependencies you can install:
+
+ ```sh
+ pip install torchft-nightly[dev]
+ ```
+
+ ## Installing from Source
+
+ ### Prerequisites
+
+ Before proceeding, ensure you have the following installed:
+
+ - Rust (with necessary dependencies)
+ - `protobuf-compiler` and the corresponding development package for Protobuf.
+ - PyTorch 2.7 RC+ or Nightly
+
+ Note that the Rust versions available in many conda environments may be outdated. To install the latest version of Rust, we recommend downloading it directly from the official website, as shown in the command below:
+ ```sh
+ curl --proto '=https' --tlsv1.2 https://sh.rustup.rs -sSf | sh
+ ```
+
+ To install the required packages on a Debian-based system (such as Ubuntu) using apt, run:
+
+ ```sh
+ sudo apt install protobuf-compiler libprotobuf-dev
+ ```
+
+ or for a Red Hat-based system, run:
+
+ ```sh
+ sudo dnf install protobuf-compiler protobuf-devel
+ ```
+
+ ### Installation
+
+ ```sh
+ pip install .
+ ```
+
+ This uses pyo3+maturin to build the package; you'll need maturin installed.
+
+ If the installation command fails to invoke `cargo update` due to an inability to fetch the manifest, it may be caused by the `proxy`, `proxySSLCert`, and `proxySSLKey` settings in your `.gitconfig` file affecting the `cargo` command. To resolve this issue, try temporarily removing these fields from your `.gitconfig` before running the installation command.
+
+ To install in editable mode with the Rust extensions and development dependencies, you can use the normal pip install command:
+
+ ```sh
+ pip install -e '.[dev]'
+ ```
+
+ ## Usage
+
+ ### Lighthouse
+
+ The lighthouse is used for fault tolerance across replicated workers (DDP/FSDP)
+ when using synchronous training.
+
+ You can start a lighthouse server by running:
+
+ ```sh
+ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
+ ```
+
+ ### Example Training Loop (DDP)
+
+ See [train_ddp.py](./train_ddp.py) for the full example.
+
+ Invoke with:
+
+ ```sh
+ TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 train_ddp.py
+ ```
+
+ train.py:
+
+ ```py
+ import torch
+ from torch import nn, optim
+
+ from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ manager = Manager(
+     pg=ProcessGroupGloo(),
+     load_state_dict=...,
+     state_dict=...,
+ )
+
+ m = nn.Linear(2, 3)
+ m = DistributedDataParallel(manager, m)
+ optimizer = Optimizer(manager, optim.AdamW(m.parameters()))
+
+ for i in range(1000):
+     batch = torch.rand(2, 2, device=device)
+
+     optimizer.zero_grad()
+
+     out = m(batch)
+     loss = out.sum()
+
+     loss.backward()
+
+     optimizer.step()
+ ```
+
+ ### Running DDP
+
+ After starting the lighthouse server by running:
+
+ ```sh
+ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
+ ```
+
+ a test DDP script can be launched with torchX with:
+
+ ```sh
+ torchx run
+ ```
+
+ or DiLoCo with:
+
+ ```sh
+ USE_STREAMING=True torchx run ./torchft/torchx.py:hsdp --script='train_diloco.py'
+ ```
+
+ See [.torchxconfig](.torchxconfig), [torchx.py](./torchft/torchx.py) and the [torchX documentation](https://pytorch.org/torchx/latest/) to understand how DDP is being run.
+
+ `torchx.py` can also launch HSDP jobs when `workers_per_replica` is set to > 1, provided the training script supports it. For an example HSDP training implementation with torchft enabled, see [torchtitan](https://github.com/pytorch/torchtitan).
+
+ Alternatively, to test on a node with two GPUs, you can launch two replica groups running [train_ddp.py](./train_ddp.py) as follows:
+
+ On shell 1 (one replica group starts initial training):
+ ```sh
+ export REPLICA_GROUP_ID=0
+ export NUM_REPLICA_GROUPS=2
+
+ CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp.py
+ ```
+
+ On shell 2 (a second replica group joins):
+ ```sh
+ export REPLICA_GROUP_ID=1
+ export NUM_REPLICA_GROUPS=2
+
+ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp.py
+ ```
+
+ By observing the outputs from both shells, you should see process group reconfiguration and live checkpoint recovery.
+
+ ### Example Parameter Server
+
+ torchft has a fault tolerant parameter server implementation built on its
+ reconfigurable ProcessGroups. It does not require or use a Lighthouse server.
+
+ See [parameter_server_test.py](./torchft/parameter_server_test.py) for an example.
+
+ ## Contributing
+
+ We welcome PRs! See the [CONTRIBUTING](./CONTRIBUTING.md) file.
+
+ ## License
+
+ torchft is BSD 3-Clause licensed. See [LICENSE](./LICENSE) for more details.
+
torchft_nightly-2026.1.3.dist-info/RECORD ADDED
@@ -0,0 +1,61 @@
+ torchft/__init__.py,sha256=OGvxyNeEBOM---9FsnapJcyqH3KB0GnJ9mGhbr-jsjw,915
+ torchft/_test/diloco_trainer.py,sha256=r4_LJRzkluRipies39NyyjMb3sEG4JO8hmFJlrdR7ds,9648
+ torchft/_test/managed_work_test.py,sha256=28t-b5EaQ4DSFFvuzipq_ox1tmkTrVOWKNADWFf3AbU,11711
+ torchft/_test_utils.py,sha256=Hlf_g2rUbFxCI_VskDDZgn9MdvqzPj6VFsNZXWQ07Z4,3260
+ torchft/_torchft.cpython-310-x86_64-linux-gnu.so,sha256=X0DHEKq_bCBvmROrOudqRZ6tiqix98zcNEU8_2KKuy8,6832056
+ torchft/_torchft.pyi,sha256=WqJlm5dsMT1G8h7iofyZNhvRttmDZpKBuw5GrbeRO-0,2953
+ torchft/checkpointing/__init__.py,sha256=-Bottyor1P4u9Fd-yYGpvWKrPMc7V0wuF25-EUjdS04,528
+ torchft/checkpointing/_rwlock.py,sha256=yVJGo3pXRTwFb3O5WoaIiz6_gJyrW-I5J5TtSeK53n8,4533
+ torchft/checkpointing/_serialization.py,sha256=MU9aYF0xAN1hDXqjZNdUDrnwgDkq3XGFhPtkvzy-CR0,1068
+ torchft/checkpointing/http_transport.py,sha256=mtmxi9-1rVRHom_9RoZyrHY5Nv2AAW9F2KGwG_GkfJE,10282
+ torchft/checkpointing/http_transport_bench.py,sha256=Uo-GTndb1hDz6vzIoxWPNO0aH82uZFaBdpdLsQKCDPY,1892
+ torchft/checkpointing/http_transport_test.py,sha256=c35Hy7gCggg6p9Ml2gf1cEXCG4lsVi8hl7gI7KON2xs,4606
+ torchft/checkpointing/pg_transport.py,sha256=hIESpx6lEhZSQLYYPg5CHuMvSDezBxLGY-M3HQplP-s,9742
+ torchft/checkpointing/pg_transport_bench.py,sha256=VTlnDUmcl4FWNPFAk8O0XqbRK7BnUwIBbPGus_6eDpw,3074
+ torchft/checkpointing/pg_transport_test.py,sha256=imFArrK285gICgLJbm9136zmYSO8-XVrhWI6Nqf13Rs,3678
+ torchft/checkpointing/rwlock_test.py,sha256=E79ah4dQMMMWBoTOzBhlzru8tRcSTYKQrI4HOJ6TRGQ,1155
+ torchft/checkpointing/transport.py,sha256=2mUf3_xtL-IeE90ZrRWAvvFmJSctPUXPnUYOi47PqeQ,2013
+ torchft/checkpointing/transport_test.py,sha256=BTKYMNKqkP3NO1PkapxyNVQP0HmgRMze-SVOSslE-Hs,5029
+ torchft/collectives.py,sha256=-UXCWzHE2k6TKsj1D34on-HcxgAP-Tl9tGc2rFM05vk,14528
+ torchft/collectives_test.py,sha256=QpqosipNB-Z0LlgUfsoKY6ZBCZcZZ11eRsyLdtUJ46Y,6829
+ torchft/coordination.py,sha256=E2nkxPMILaQneDM9VhwvU4GYWlYHEEt81HBVLyWDBUU,964
+ torchft/coordination_test.py,sha256=yFkXxL8PVrf73viGEHQpv8c-N1RmfyHen_ArokpclfA,770
+ torchft/data.py,sha256=cCtlKPcM-djU73sNq6wld9jBjkuvyLUjgdYrn8UT_jo,2528
+ torchft/data_test.py,sha256=ATiE8XPNjNbFDCCsmKCireUU-EnUnqn9vRm6td8w86s,1051
+ torchft/ddp.py,sha256=bE46JK2-yrjM7fnjBCXXi6v0smmuG_uRX2sOsNh91Zw,3389
+ torchft/ddp_test.py,sha256=SQZFWkPqB1ejw-DJpwaEiHKxUj0Afx_QNLX7kjBjAgc,1867
+ torchft/diloco_regression_test.py,sha256=t1gLcc4m55ujM3M-PTWGuYLSwY9BCgojafUcWY80mC8,25664
+ torchft/examples/slurm/README.md,sha256=l8cRSpy_kCUWInKYSx78-rYuMjUCxXKX5oFbI81ksKc,891
+ torchft/examples/slurm/punisher.py,sha256=Cp4oWrh8Zp4ZAg_MuT1e2BGJLdczlLm4XkQQtQQbMI4,2655
+ torchft/examples/slurm/runner.py,sha256=rsYv1bBsC-uvxTT-1Xu6dOV6rKevXbPXFYe7lYtOrCo,6005
+ torchft/fsdp_test.py,sha256=mxIa6-y0Xs1JKph8zkireITd1VMj_vbn3vYT-x3Ns-k,3543
+ torchft/futures.py,sha256=5VpK99LthFJw0Kiwd99M9gebpp5xGTtmUg25Lmos9MM,11341
+ torchft/futures_test.py,sha256=gN5Ud68v0s62Lbs2h3q-VOUsNAvZ2Lu6yDYMS4nwvR4,4582
+ torchft/http.py,sha256=SmoBdLMAL8JatoTQut7JsNK9x7QTIS8mR5PQQCV9d-w,407
+ torchft/lighthouse_test.py,sha256=d686G7J5Ax8BU37ArONNHOi7b4udGRUhAj4pIhoznzA,5474
+ torchft/local_sgd.py,sha256=giYvBZT_MkU-Q0u_OhxF5p6yavWGDAVrQYsM_rteCpk,29440
+ torchft/local_sgd_integ_test.py,sha256=mCbNUpBXEeYhhOYhzaBCje1W-AYs7Ifz3vGA_4b8UA8,20230
+ torchft/local_sgd_test.py,sha256=SQpl2fanta0MDzAsKbWVw7SlsgTNJdUui7b7wao5YPk,11352
+ torchft/manager.py,sha256=KxSZvob05UsiyzkDnIiP9uHlLwZptDmlGr3yfScZatU,50900
+ torchft/manager_integ_test.py,sha256=JMzOk80p8X1-O3t2yEb_DH18Hrv2yK4QFbHJCJylj0I,20586
+ torchft/manager_test.py,sha256=d42XKVS5bZFHLszNI1tGfHIxL3tN3ytOYqUwZJ7RfKs,33818
+ torchft/multiprocessing.py,sha256=AcaakPRd7LSIBP_MvEdtcU3BxIa_LOplEw-vApGNRxI,1116
+ torchft/multiprocessing_dummy_context.py,sha256=DXi5iRZWD_N2PETEj0JNte10znyEljCi6sLci9MQyyw,3893
+ torchft/multiprocessing_test.py,sha256=JnHcO48NBYOuPwCa37KzIZg6CgBcmRRAhZsA2WqapK0,1518
+ torchft/optim.py,sha256=iBJhywVjG9-0490may2PZbEdhvYrbVjsyNOkrkdMPYg,1948
+ torchft/optim_test.py,sha256=oA9t83t1beMAWml2Zs0MkmdKJArpSYK7OQ14gkVCpYU,1495
+ torchft/otel.py,sha256=gyaSxmNppiPeZSfjIfe42JILYWj0Uw_1Ir-zEI2AF4c,4241
+ torchft/parameter_server.py,sha256=fQ4pg8dDGSgkFGsOu36k1-zps0s71EIF5B-wSgXgGRE,6000
+ torchft/parameter_server_test.py,sha256=0Yk_6vG2sf3dC8g8OFSSpcuwmJa6yNDwh7xRweUHFzc,1305
+ torchft/process_group.py,sha256=sLbmHShafhz4711IunfLTUWP00E1QbbXVfEKumrJACU,69988
+ torchft/process_group_test.py,sha256=qHijH_pzca35VClth5aXT_tCgY3ZIxwoFH_UHY2qk5Q,34506
+ torchft/quantization.py,sha256=EAIxzdP6AX-c8mEFfyAju11K3tUde1Y4urTrnNYu2n8,24074
+ torchft/quantization_test.py,sha256=sRHV4oJJAR8ce9v2nDN2s6LkTr2IQE5lZisHOd2UOPo,4275
+ torchft/torchx.py,sha256=dSe0yT5DheWpjl-no6hSFhla_GdB12PaF5uoYBtmsZ8,2493
+ torchft/utils.py,sha256=xRPUsm5bIPrWEJeC7_rT0IXr7OO4z864ckLYXDgUpZY,2105
+ torchft/work.py,sha256=IlNdLONzoHf_Wu1ujb3tmKn9KytxmGJZH3WE9BBiQMw,776
+ torchft_nightly-2026.1.3.dist-info/METADATA,sha256=VtFfDcAo04ZFIh8AgLqvu3XBHEVMh739Dxne5ffX6dg,9888
+ torchft_nightly-2026.1.3.dist-info/WHEEL,sha256=pb4kzB25058gQK4V2ufO4qOywKBWVlask3GAY9yVdn4,109
+ torchft_nightly-2026.1.3.dist-info/entry_points.txt,sha256=xTOOuXiCVovuQMX7rEX77rnYtRwIwDb1rYiMSWygtqM,70
+ torchft_nightly-2026.1.3.dist-info/licenses/LICENSE,sha256=yXcfhd3XpfByVfZX2WlDmfjdj8lytYNDpjowg2qzFuc,1641
+ torchft_nightly-2026.1.3.dist-info/RECORD,,
torchft_nightly-2026.1.3.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: maturin (1.11.0)
+ Root-Is-Purelib: false
+ Tag: cp310-cp310-manylinux_2_24_x86_64
torchft_nightly-2026.1.3.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ torchft_lighthouse=torchft._torchft:lighthouse_main