torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/examples/slurm/punisher.py ADDED
@@ -0,0 +1,95 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ import random
+ import time
+
+ from torchx import specs
+ from torchx.runner import get_runner, Runner
+
+ logging.basicConfig(level=logging.INFO)
+ logger: logging.Logger = logging.getLogger(__name__)
+
+ _SCHEDULER = "slurm"
+
+
+ def kill_all(runner: Runner) -> None:
+     jobs = runner.list(_SCHEDULER)
+     jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]
+     for job in jobs:
+         if "ft_" not in job.name:
+             continue
+         print(f"killing {job.app_handle}")
+         runner.cancel(job.app_handle)
+
+
+ def kill_one(runner: Runner) -> None:
+     jobs = runner.list(_SCHEDULER)
+     jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]
+     candidates = []
+     for job in jobs:
+         if "ft_" not in job.name:
+             continue
+         if "ft_0" in job.name:
+             continue
+         candidates.append(job.app_handle)
+     choice = random.choice(candidates)
+     print(f"killing {choice=} {candidates=}")
+     runner.cancel(choice)
+
+
+ def kill_loop(runner: Runner, args: argparse.Namespace) -> None:
+     for _ in range(args.num_failures):
+         kill_one(runner)
+         dur = random.random() * (2 * args.mtbf_secs)
+         print(f"sleeping for {dur=} {args.mtbf_secs=}")
+         time.sleep(dur)
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="CLI tool to inject failures on slurm")
+     subparsers = parser.add_subparsers(dest="command", help="Available commands")
+
+     # kill_loop subcommand
+     kill_loop_parser = subparsers.add_parser("kill_loop", help="Kill jobs in a loop")
+     kill_loop_parser.add_argument(
+         "--mtbf-secs",
+         type=float,
+         default=5,
+         help="Mean time between failures",
+     )
+     kill_loop_parser.add_argument(
+         "--num-failures",
+         type=int,
+         default=1,
+         help="Number of failures to inject",
+     )
+
+     # kill_one subcommand
+     subparsers.add_parser("kill_one", help="Kill a single job")
+
+     # kill_all subcommand
+     subparsers.add_parser("kill_all", help="Kill all jobs")
+
+     args = parser.parse_args()
+
+     if args.command is None:
+         parser.print_help()
+         return
+
+     with get_runner() as runner:
+         if args.command == "kill_loop":
+             kill_loop(runner, args)
+         elif args.command == "kill_one":
+             kill_one(runner)
+         elif args.command == "kill_all":
+             kill_all(runner)
+
+
+ if __name__ == "__main__":
+     main()
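
punisher.py is a thin torchx-based failure injector: kill_all cancels every running Slurm job whose name contains "ft_", kill_one cancels one random replica other than ft_0, and kill_loop repeats that with a randomized interval around --mtbf-secs. Below is a minimal sketch of driving it programmatically instead of through its CLI; the import path `punisher` is hypothetical and a Slurm scheduler reachable via torchx is assumed.

import argparse

from torchx.runner import get_runner

from punisher import kill_loop, kill_one  # hypothetical import path for the file above

with get_runner() as runner:
    # Cancel one random non-ft_0 replica right away.
    kill_one(runner)
    # Or inject three failures with a ~60 s mean time between failures.
    kill_loop(runner, argparse.Namespace(num_failures=3, mtbf_secs=60.0))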
torchft/examples/slurm/runner.py ADDED
@@ -0,0 +1,221 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ import os
+ import time
+
+ from torchx import specs
+ from torchx.components.dist import ddp
+ from torchx.runner import get_runner, Runner
+
+ logging.basicConfig(level=logging.INFO)
+ logger: logging.Logger = logging.getLogger(__name__)
+
+ _SCHEDULER = "slurm"
+
+
+ def _make_app(replica_id: int, cli_args: argparse.Namespace) -> specs.AppDef:
+     args = [
+         "--comm.trace_buf_size=0",
+         "--comm.train_timeout_seconds=60",
+         "--metrics.log_freq=1",
+         "--profiling.enable_profiling",
+         "--experimental.custom_args_module=torchtitan.components.ft.config",
+         "--job.config_file=./torchtitan/models/llama3/train_configs/llama3_8b.toml",
+         "--model.name=llama3_ft",
+         "--training.dataset=c4",
+         "--training.steps=10000",
+         "--training.local_batch_size=2",
+         f"--parallelism.data_parallel_shard_degree={cli_args.nodes * cli_args.nproc_per_node}",
+         "--fault_tolerance.enable",
+         f"--fault_tolerance.replica_id={replica_id}",
+         f"--fault_tolerance.group_size={cli_args.replica_count}",
+         f"--fault_tolerance.process_group={cli_args.process_group}",
+         f"--fault_tolerance.process_group_timeout_ms={600 * 1000}",
+     ]
+
+     if cli_args.enable_semi_sync:
+         args += [
+             f"--fault_tolerance.semi_sync_method={cli_args.semi_sync_method}",
+         ]
+
+         if cli_args.semi_sync_method == "diloco":
+             args += [
+                 "--fault_tolerance.sync_steps=20",
+                 "--fault_tolerance.fragment_sync_delay=1",
+                 f"--fault_tolerance.num_fragments={cli_args.num_fragments}",
+             ]
+
+     if replica_id == 0:
+         args += [
+             "--metrics.enable-wandb",
+             "--checkpoint.interval=100",
+         ]
+
+     env = {}
+
+     # use agent store in torchelastic to avoid TCPStore init race condition
+     env["TORCH_SHARE_RDZV_TCP_STORE"] = "1"
+     env["TORCH_CPP_LOG_LEVEL"] = "INFO"
+
+     env["TORCH_CUDA_SANITIZER"] = "1"
+
+     # NCCL envs for debugging
+     env["NCCL_DEBUG"] = "INFO"
+     env["NCCL_DEBUG_SUBSYS"] = "ALL"
+     env["NCCL_PROTO"] = "Simple"
+
+     # gloo
+     if os.environ.get("GLOO_SOCKET_IFNAME") is not None:
+         env["GLOO_SOCKET_IFNAME"] = os.environ.get("GLOO_SOCKET_IFNAME")
+
+     # application log levels
+     env["LOGLEVEL"] = "INFO"
+     env["RUST_LOGS"] = "INFO"
+     env["TORCH_CPP_LOG_LEVEL"] = "INFO"
+
+     # application timeouts
+     env["TORCHFT_QUORUM_TIMEOUT_SEC"] = "900"
+     env["TORCHFT_TIMEOUT_SEC"] = "600"
+     env["TORCHFT_QUORUM_RETRIES"] = "0"
+
+     env["TORCHFT_LIGHTHOUSE"] = os.environ.get(
+         "TORCHFT_LIGHTHOUSE", "http://slurm-head-node-0:29510"
+     )
+
+     env["WANDB_PROJECT"] = "torchft"
+
+     app = ddp(
+         *args,
+         name=f"ft_{replica_id}",
+         env=env,
+         script="./torchtitan/train.py",
+         gpu=cli_args.nproc_per_node,
+         j=f"{cli_args.nodes}x{cli_args.nproc_per_node}",
+     )
+     app.roles[0].name = app.name
+     return app
+
+
+ def start_replica(
+     runner: Runner, replica_id: int, args: argparse.Namespace
+ ) -> specs.AppHandle:
+     app = _make_app(replica_id, args)
+
+     app_handle = runner.run(
+         app,
+         scheduler=_SCHEDULER,
+     )
+
+     return app_handle
+
+
+ def monitor(runner: Runner, args: argparse.Namespace) -> None:
+     jobs = runner.list(_SCHEDULER)
+     jobs = [job for job in jobs if job.state == specs.AppState.RUNNING]
+
+     active_replicas = {}
+
+     for job in jobs:
+         if "ft_" not in job.name:
+             continue
+         name, _, _ = job.name.partition("-")
+         _, _, replica_id_str = name.partition("_")
+         replica_id = int(replica_id_str)
+         active_replicas[replica_id] = job
+
+     to_launch = set()
+     for replica_id in range(args.replica_count):
+         alive = replica_id in active_replicas
+
+         if alive:
+             job = active_replicas[replica_id]
+             print(f" - {replica_id=:2d}: ALIVE {job.app_handle}")
+         else:
+             print(f" - {replica_id=:2d}: DEAD")
+             to_launch.add(replica_id)
+
+     for replica_id in to_launch:
+         app_handle = start_replica(
+             runner,
+             replica_id,
+             args,
+         )
+         print(f"launched {replica_id=}: {app_handle=}")
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(
+         description="CLI tool to launch data parallel replicas on slurm"
+     )
+
+     parser.add_argument(
+         "--workspace-dir", type=str, help="Location of torchtitan folder"
+     )
+
+     parser.add_argument(
+         "--nodes",
+         type=int,
+         default=10,
+         help="Number of nodes per replica",
+     )
+
+     parser.add_argument(
+         "--nproc-per-node",
+         type=int,
+         default=10,
+         help="Number of ranks per node",
+     )
+
+     parser.add_argument(
+         "--replica-count",
+         type=int,
+         default=10,
+         help="Number of data parallel replicas",
+     )
+
+     parser.add_argument(
+         "--process-group",
+         type=str,
+         default="gloo",
+         help="The process group to use for data parallel",
+     )
+
+     parser.add_argument(
+         "--enable-semi-sync",
+         type=bool,
+         default=True,
+         help="Whether to enable semi-sync method for data parallel",
+     )
+
+     parser.add_argument(
+         "--semi-sync-method",
+         type=str,
+         default="diloco",
+         help="The semi-sync method to use for data parallel. Options: diloco, local_sgd",
+     )
+
+     parser.add_argument(
+         "--num-fragments",
+         type=int,
+         default=2,
+         help="The number of fragments to use for data parallel. Only used for diloco semi-sync method",
+     )
+
+     args = parser.parse_args()
+
+     os.chdir(args.workspace_dir)
+
+     with get_runner() as runner:
+         while True:
+             monitor(runner, args)
+             time.sleep(10)
+
+
+ if __name__ == "__main__":
+     main()
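
runner.py builds one torchx ddp AppDef per data-parallel replica (torchtitan llama3 flags plus torchft fault-tolerance settings), submits it to Slurm, and then loops in monitor() relaunching any replica that is no longer RUNNING. The following is a hedged sketch of using it programmatically; the import path `runner` is hypothetical and every concrete value (paths, sizes) is a placeholder, not something shipped with the package.

import argparse
import os

from torchx.runner import get_runner

from runner import monitor  # hypothetical import path for the file above

args = argparse.Namespace(
    workspace_dir="/path/to/torchtitan",  # placeholder checkout location
    nodes=2,
    nproc_per_node=8,
    replica_count=2,
    process_group="gloo",
    enable_semi_sync=True,
    semi_sync_method="diloco",
    num_fragments=2,
)

os.chdir(args.workspace_dir)
with get_runner() as torchx_runner:
    # A single pass launches every replica that is not currently RUNNING;
    # the CLI wraps this in a `while True` loop with a 10 s sleep.
    monitor(torchx_runner, args)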
torchft/fsdp_test.py ADDED
@@ -0,0 +1,102 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import multiprocessing
+ import os
+ import unittest
+ from concurrent.futures import ProcessPoolExecutor
+ from unittest.mock import Mock
+
+ import torch
+ import torch.distributed as dist
+ from torch import nn
+ from torch._C._distributed_c10d import ReduceOp
+ from torch.distributed._composable.fsdp import FSDPModule, fully_shard
+ from torch.distributed.tensor import init_device_mesh
+ from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+
+ from torchft.manager import Manager
+ from torchft.process_group import ProcessGroupGloo
+
+
+ class FSDPTest(unittest.TestCase):
+     @staticmethod
+     def _test_fsdp(
+         world_size: int,
+         rank: int,
+         dp_replicate: int = 2,
+         dp_shard: int = 2,
+         tp: int = 1,
+     ) -> None:
+         torch.cuda.set_device(rank)
+
+         group_size = world_size // dp_replicate
+         group = rank // group_size
+         group_rank = rank % group_size
+
+         os.environ["MASTER_ADDR"] = "127.0.0.1"
+         os.environ["MASTER_PORT"] = str(12346 + group)
+         os.environ["RANK"] = str(group_rank)
+         os.environ["WORLD_SIZE"] = str(group_size)
+
+         manager = Mock(spec=Manager)
+         pg: ProcessGroupGloo = Mock(spec=ProcessGroupGloo)
+         device_mesh = init_device_mesh(
+             device_type="cuda",
+             mesh_shape=(dp_shard, tp),
+             mesh_dim_names=("dp_shard", "tp"),
+         )
+         manager.num_participants.return_value = 1
+         model = nn.Linear(128, 128).cuda()
+         batch = torch.randn(4, 128).cuda()
+
+         fsdp_mesh = device_mesh["dp_shard"]
+
+         def all_reduce_hook(output: torch.Tensor) -> None:
+             dist.all_reduce(output, group=pg, op=ReduceOp.AVG)
+
+         def apply_set_all_reduce_hook(m: nn.Module) -> None:
+             assert isinstance(m, FSDPModule)
+             m.set_all_reduce_hook(all_reduce_hook)
+
+         if tp > 1:
+             tp_mesh = device_mesh["tp"]
+             model = parallelize_module(
+                 model,
+                 tp_mesh,
+                 ColwiseParallel(),
+             )
+         shard_model = fully_shard(model, mesh=fsdp_mesh)
+         shard_model.apply(apply_set_all_reduce_hook)
+         shard_model(batch).mean().backward()
+
+     # pyre-ignore[56]: Pyre was not able to infer the type of argument
+     @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
+     def test_fsdp(self) -> None:
+         context = multiprocessing.get_context("spawn")
+         with ProcessPoolExecutor(max_workers=4, mp_context=context) as executor:
+             futures = []
+             for i in range(4):
+                 future = executor.submit(self._test_fsdp, 4, i)
+                 futures.append(future)
+
+             for fut in futures:
+                 fut.result()
+
+     # pyre-ignore[56]: Pyre was not able to infer the type of argument
+     @unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
+     def test_fsdp_tp(self) -> None:
+         context = multiprocessing.get_context("spawn")
+         with ProcessPoolExecutor(max_workers=4, mp_context=context) as executor:
+             futures = []
+             for i in range(4):
+                 future = executor.submit(
+                     self._test_fsdp, 4, i, dp_replicate=1, dp_shard=2, tp=2
+                 )
+                 futures.append(future)
+
+             for fut in futures:
+                 fut.result()
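
The pattern this test exercises is FSDP2's per-module all-reduce hook: after fully_shard, each module's reduce-scattered gradient shard is handed to a user hook, which torchft uses to average gradients across replicas over a separately managed process group. A minimal sketch of that wiring outside the test harness follows; it assumes a process group for the cross-replica dimension already exists, and the helper name install_cross_replica_hook is illustrative, not a torchft API.

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import FSDPModule


def install_cross_replica_hook(model: nn.Module, group: dist.ProcessGroup) -> None:
    # Register the hook on every FSDP-wrapped submodule of an already-sharded model.
    def all_reduce_hook(output: torch.Tensor) -> None:
        # Average the reduce-scattered gradient shard across data-parallel replicas.
        dist.all_reduce(output, group=group, op=dist.ReduceOp.AVG)

    def apply_hook(m: nn.Module) -> None:
        if isinstance(m, FSDPModule):
            m.set_all_reduce_hook(all_reduce_hook)

    model.apply(apply_hook)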