torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
@@ -0,0 +1,299 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import socket
9
+ import threading
10
+ import time
11
+ import urllib.request
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ from contextlib import contextmanager, nullcontext
14
+ from datetime import timedelta
15
+ from http.server import BaseHTTPRequestHandler
16
+ from typing import cast, Generator, List, Optional, TypeVar
17
+
18
+ import torch
19
+ from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
20
+
21
+ from torchft.checkpointing._rwlock import RWLock
22
+ from torchft.checkpointing._serialization import _streaming_load, _streaming_save
23
+ from torchft.checkpointing.transport import CheckpointTransport
24
+ from torchft.http import _IPv6HTTPServer
25
+
26
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Type variable standing in for the state_dict type a transport carries.
T = TypeVar("T")
29
+
30
+
31
@contextmanager
def _time(desc: str) -> Generator[None, None, None]:
    """
    Context manager that logs how long the wrapped block took to run.

    The duration is measured with ``time.perf_counter`` and logged at INFO
    level after the block completes. If the block raises, the exception
    propagates and nothing is logged.

    Args:
        desc: label placed at the start of the log message
    """
    began = time.perf_counter()
    yield
    elapsed = time.perf_counter() - began
    logger.info(f"{desc} took {elapsed}s")
37
+
38
+
39
class HTTPTransport(CheckpointTransport[T]):
    """
    This is an HTTP server that can be used to transfer checkpoints
    between workers.

    This allows for fast recovery of workers by fetching the current weights
    from an existing worker.

    Checkpoints are served under ``/checkpoint/<step>/<key>`` where ``key`` is
    ``full`` (the whole state_dict), ``metadata`` (the pytree spec) or the
    integer index of a chunk.

    Args:
        timeout: the timeout for HTTP requests
        num_chunks: the number of chunks to split the checkpoint into (0 for no chunking)
    """

    def __init__(self, timeout: timedelta, num_chunks: int) -> None:
        self._checkpoint_lock = RWLock(timeout=timeout.total_seconds())
        self._disallowed = False
        self._step = -1
        self._timeout = timeout
        self._state_dict: Optional[T] = None
        self._num_chunks = num_chunks
        self._stream: Optional[torch.cuda.Stream] = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

        # staged checkpoint information
        self._spec: Optional[TreeSpec] = None
        self._chunks: Optional[List[List[object]]] = None

        # We don't allow checkpoints until the first send_checkpoint to avoid
        # serving the default step=-1 invalid checkpoint.
        self.disallow_checkpoint()

        # Inside RequestHandler ``self`` is the handler, so keep a separate
        # reference to the transport for the handler methods to close over.
        ckpt_server = self

        class RequestHandler(BaseHTTPRequestHandler):
            # set request socket timeout to avoid hanging forever
            timeout = self._timeout.total_seconds()

            def do_GET(self):
                try:
                    # validate socket timeout is actually set
                    assert self.connection.gettimeout() == self.timeout

                    # The read lock is held for the whole request so the
                    # staged checkpoint is not released for writing while it
                    # is being streamed out.
                    with ckpt_server._checkpoint_lock.r_lock():
                        parts = self.path.split("/")
                        # Reject malformed paths with an explicit 400 rather
                        # than an assert, which would surface as a 500 with
                        # an empty message and vanish under ``python -O``.
                        if len(parts) != 4 or parts[1] != "checkpoint":
                            self.send_error(
                                400,
                                f"invalid url format, expected /checkpoint/step/key but got {self.path}",
                            )
                            return

                        step = int(parts[2])
                        if step != ckpt_server._step:
                            self.send_error(
                                400,
                                f"invalid checkpoint requested, serving {ckpt_server._step} but got {step=}",
                            )
                            return

                        # Resolve the payload before sending any headers so a
                        # bad chunk key still yields a clean 500 response.
                        key = parts[3]
                        if key == "full":
                            payload = ckpt_server._state_dict
                        elif key == "metadata":
                            payload = ckpt_server._spec
                        else:
                            payload = ckpt_server._chunks[int(key)]

                        self.send_response(200)
                        self.send_header("Content-type", "application/octet-stream")
                        self.end_headers()

                        _streaming_save(payload, self.wfile)
                except Exception as e:
                    logger.exception(
                        f"Exception in checkpoint server when handling {self.path=}: {e}",
                    )
                    self.send_error(500, str(e))

        # Port 0 asks the OS for a free port; address() reads it back.
        server_address = ("", 0)
        self._server = _IPv6HTTPServer(server_address, RequestHandler)
        logger.info(f"Started CheckpointServer on {self.address()}...")

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
        )
        self._thread.start()

    @classmethod
    def _load_from_address(cls, address: str, timeout: timedelta) -> object:
        """
        Loads a checkpoint from the given address.

        Args:
            address: the HTTP address to load the checkpoint from
            timeout: the timeout passed to the HTTP request

        Returns:
            the object deserialized from the response body
        """
        msg = f"fetching checkpoint from {address}"
        logger.info(msg)

        with _time(msg), urllib.request.urlopen(
            address, timeout=timeout.total_seconds()
        ) as f:
            # We have to set weights_only to False as there are some non-tensor
            # states like lr_scheduler.
            # NOTE(review): weights_only=False presumably allows arbitrary
            # objects to be unpickled, so only fetch from trusted peers --
            # confirm against _streaming_load.
            # pyre-fixme[16]: needs torch>=2.7
            return cast(T, _streaming_load(f, weights_only=False))

    def address(self) -> str:
        """
        Returns the HTTP address to fetch a checkpoint from this server. Step must be appended to the end of the address.

        Format: http://host:port/checkpoint/1234

        Returns:
            an HTTP address
        """
        port = self._server.socket.getsockname()[1]
        return f"http://{socket.gethostname()}:{port}/checkpoint/"

    def _serve(self) -> None:
        """
        Thread target: runs the HTTP serve loop and logs any exception.
        """
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in checkpoint server")

    def disallow_checkpoint(self) -> None:
        """
        Disallows serving the checkpoint.

        All requests will block until allow_checkpoint is called.
        """
        # The flag makes repeated calls idempotent so the write lock is only
        # acquired once.
        if not self._disallowed:
            self._disallowed = True
            self._checkpoint_lock.w_acquire()

    def allow_checkpoint(self, step: int) -> None:
        """
        Allows serving the checkpoint with the specified step number.

        Args:
            step: the step number to serve
        """
        self._step = step

        # Only release the write lock if disallow_checkpoint acquired it.
        if self._disallowed:
            self._disallowed = False
            self._checkpoint_lock.w_release()

    def shutdown(self, wait: bool = True) -> None:
        """
        Shutdown the server.

        Args:
            wait: if True, block until the serve loop has stopped and the
                server thread has been joined; if False, only flag the serve
                loop to stop and return immediately.
        """
        if wait:
            self._server.shutdown()
            self._thread.join()
        else:
            # hack for nonblocking shutdown of socketserver threads
            # The stop flag is a private attribute of socketserver.BaseServer
            # (assumes _IPv6HTTPServer derives from it), so it has to be
            # addressed by its mangled name. Writing
            # ``self._server.__shutdown_request`` inside this class would be
            # mangled to ``_HTTPTransport__shutdown_request`` and the serve
            # loop would never see it.
            # pyre-fixme[16]: no attribute `_BaseServer__shutdown_request`.
            self._server._BaseServer__shutdown_request = True

    def metadata(self) -> str:
        """
        Returns the base address peers use to fetch checkpoints from this server.
        """
        return self.address()

    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        """
        Stages ``state_dict`` on CPU and allows it to be served for ``step``.

        Args:
            dst_ranks: not used by this transport
            step: the step number the staged checkpoint is served under
            state_dict: the state to stage; CUDA tensors are copied to CPU
            timeout: not used by this method
        """
        values, spec = tree_flatten(state_dict)

        with (
            torch.cuda.stream(self._stream)
            if self._stream is not None
            else nullcontext()
        ):
            with _time("transferring state_dict to CPU"):
                values = _to_cpu(values, pin_memory=False)
                if self._stream is not None:
                    self._stream.synchronize()

        # Unflatten so non-chunked transfer uses CPU tensors
        self._state_dict = tree_unflatten(values, spec)

        # Save spec for chunked
        self._spec = spec
        self._chunks = _split_chunks(values, self._num_chunks)

        self.allow_checkpoint(step)

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        """
        Fetches the checkpoint for ``step`` from the server at ``metadata``.

        With ``num_chunks == 0`` a single ``/full`` request is made. Otherwise
        the spec and every chunk are fetched concurrently and reassembled.

        Args:
            src_rank: not used by this transport
            metadata: base address returned by the sender's ``metadata()``
            step: the step number to request
            timeout: the timeout for each HTTP request

        Returns:
            the received state_dict
        """
        base_url = f"{metadata}{step}"
        if self._num_chunks == 0:
            return cast(T, self._load_from_address(f"{base_url}/full", timeout))

        urls = [f"{base_url}/metadata"] + [
            f"{base_url}/{i}" for i in range(self._num_chunks)
        ]

        # One worker per URL so all requests are in flight at once.
        with ThreadPoolExecutor(max_workers=len(urls)) as executor:
            futures = [
                executor.submit(self._load_from_address, url, timeout)
                for url in urls
            ]

            # The first URL is the metadata (spec); the rest are the chunks.
            spec, *chunks = [future.result() for future in futures]

        spec = cast(TreeSpec, spec)
        chunks = cast(List[List[object]], chunks)

        values = _merge_chunks(chunks, self._num_chunks)

        return tree_unflatten(values, spec)
268
+
269
+
270
def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
    """
    Returns a new list in which every CUDA tensor has been copied to CPU.

    Non-tensor values and tensors on any other device are passed through
    unchanged (same object).

    Args:
        values: flat list of values to convert
        pin_memory: if True, copy each CUDA tensor into a freshly allocated
            pinned CPU tensor using a non-blocking copy; the caller must
            synchronize before reading the results. If False, use
            ``Tensor.cpu()``.

    Returns:
        a list of the same length and order as ``values``
    """
    out = []
    for v in values:
        if not (isinstance(v, torch.Tensor) and v.device.type == "cuda"):
            out.append(v)
        elif pin_memory:
            # Pass the size as a single argument: unpacking it would call
            # torch.empty() with no size at all for 0-dim tensors.
            cpu = torch.empty(v.size(), dtype=v.dtype, pin_memory=True)
            cpu.copy_(v, non_blocking=True)
            out.append(cpu)
        else:
            out.append(v.cpu())
    return out
286
+
287
+
288
def _split_chunks(values: List[T], num_chunks: int) -> List[List[T]]:
    """
    Deals ``values`` round-robin into ``num_chunks`` lists.

    Chunk ``k`` holds ``values[k]``, ``values[k + num_chunks]``, and so on.
    When ``num_chunks`` is 0 the result is an empty list.

    Args:
        values: flat list of values to distribute
        num_chunks: how many chunks to produce

    Returns:
        a list containing ``num_chunks`` lists
    """
    chunks = []
    for offset in range(num_chunks):
        chunks.append(values[offset::num_chunks])
    return chunks
290
+
291
+
292
def _merge_chunks(chunks: List[List[T]], num_chunks: int) -> List[T]:
    """
    Interleaves round-robin chunks back into a single flat list.

    Takes element 0 of every chunk in order, then element 1 of every chunk,
    and so on, skipping chunks that have run out. This is the inverse of
    ``values[i::n]`` style splitting.

    Args:
        chunks: the chunks to merge, in chunk-index order
        num_chunks: unused; kept for interface compatibility

    Returns:
        the merged flat list, or an empty list when ``chunks`` is empty
    """
    # max() raises ValueError on an empty sequence, so handle it up front.
    if not chunks:
        return []

    longest = max(len(chunk) for chunk in chunks)
    return [
        chunk[i]
        for i in range(longest)
        for chunk in chunks
        if i < len(chunk)
    ]
@@ -0,0 +1,61 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import sys
9
+ from datetime import timedelta
10
+ from typing import List
11
+
12
+ import torch
13
+
14
+ from torchft.checkpointing.http_transport import _time, HTTPTransport
15
+
16
# Module-level logger used for the benchmark's progress messages.
logger: logging.Logger = logging.getLogger(__name__)
17
+
18
+
19
def main(argv: List[str]) -> None:
    """
    Benchmarks one send/recv round trip through an in-process HTTPTransport.

    Builds a state_dict of zero-filled float32 tensors, stages it with
    ``send_checkpoint`` and then times fetching it back with
    ``recv_checkpoint`` from the same transport.

    Args:
        argv: command line arguments (without the program name). Supports
            ``--num-chunks``, ``--device``, ``--chunk-size`` and
            ``--total-size``; both sizes are in bytes.
    """
    import argparse

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--num-chunks", type=int, default=0)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--chunk-size", type=int, default=3_000_000)  # 3MB
    parser.add_argument("--total-size", type=int, default=12_000_000_000)  # 12GB
    args = parser.parse_args(argv)

    device = torch.device(args.device)
    num_chunks: int = args.num_chunks
    # These names appear verbatim in the log output via the ``{X=}`` f-string
    # form below, so they are kept as-is.
    CHUNK_SIZE = args.chunk_size
    TOTAL_SIZE = args.total_size

    transport = HTTPTransport(timedelta(seconds=60), num_chunks=num_chunks)
    try:
        metadata = transport.metadata()

        logger.info(f"creating state_dict... {CHUNK_SIZE=} {TOTAL_SIZE=}")

        with _time("create state_dict"):
            # One tensor per CHUNK_SIZE bytes; float32 is 4 bytes per element.
            state_dict = {
                f"chunk/{i}": torch.zeros(
                    CHUNK_SIZE // 4, dtype=torch.float32, device=device
                )
                for i in range(0, TOTAL_SIZE, CHUNK_SIZE)
            }

        logger.info(
            f"fetching from {metadata=} {device=} {num_chunks=} {len(state_dict)=}"
        )

        transport.send_checkpoint(
            dst_ranks=[0], step=1, state_dict=state_dict, timeout=timedelta(seconds=60)
        )

        with _time("fetching checkpoint"):
            transport.recv_checkpoint(
                src_rank=1, metadata=metadata, step=1, timeout=timedelta(seconds=60)
            )
    finally:
        # Stop the transport's HTTP server thread so callers that invoke
        # main() in-process do not accumulate running servers.
        transport.shutdown()


if __name__ == "__main__":
    main(sys.argv[1:])
@@ -0,0 +1,146 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import urllib.error
8
+ from datetime import timedelta
9
+ from typing import Dict
10
+ from unittest import skipUnless, TestCase
11
+ from unittest.mock import MagicMock
12
+
13
+ import torch
14
+ from parameterized import parameterized
15
+
16
+ from torchft.checkpointing.http_transport import HTTPTransport
17
+ from torchft.checkpointing.http_transport_bench import main as bench_main
18
+ from torchft.checkpointing.transport import CheckpointTransport
19
+ from torchft.checkpointing.transport_test import (
20
+ assertStateDictEqual,
21
+ run_multi_recovery_test,
22
+ )
23
+
24
+
25
class TestHTTPTransport(TestCase):
    """
    Tests for HTTPTransport that serve and fetch checkpoints in-process.
    """

    def _run_multi_recovery(self, device: torch.device) -> None:
        """
        Hands run_multi_recovery_test a factory that builds unchunked
        HTTPTransport instances, for the given device.
        """

        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
            return HTTPTransport(
                timeout=timedelta(seconds=10),
                num_chunks=0,
            )

        run_multi_recovery_test(self, init, device=device)

    @parameterized.expand(
        [
            ("no chunks", 0),
            ("chunked", 3),
        ]
    )
    def test_checkpoint_server(self, name: str, num_chunks: int) -> None:
        """
        Round-trips a state_dict, then checks the timeout and step-mismatch
        error paths.
        """
        expected: Dict[str, object] = {
            "state": "dict",
            "tensor": torch.rand(5, 2),
            "cuda": torch.rand(
                2, 3, device="cuda" if torch.cuda.is_available() else "cpu"
            ),
        }
        server = HTTPTransport(
            timeout=timedelta(seconds=10),
            num_chunks=num_chunks,
        )
        # Registered up front so the server is shut down even when an
        # assertion below fails.
        self.addCleanup(server.shutdown)

        server.send_checkpoint(
            dst_ranks=[],
            step=1234,
            state_dict=expected,
            timeout=timedelta(seconds=10),
        )

        metadata = server.metadata()

        out = server.recv_checkpoint(
            src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
        )
        assertStateDictEqual(self, out, expected)

        # test timeout
        with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"):
            server.recv_checkpoint(
                src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=0.0)
            )

        # test mismatch case
        server.send_checkpoint(
            dst_ranks=[],
            step=2345,
            state_dict=expected,
            timeout=timedelta(seconds=10),
        )

        with self.assertRaisesRegex(
            urllib.error.HTTPError, r"Error 400.*serving 2345 but got step=1234"
        ):
            server.recv_checkpoint(
                src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
            )

    def test_checkpoint_server_locking(self) -> None:
        """
        Checks the lock/flag/step state across allow/disallow transitions.
        """
        server = HTTPTransport(
            timeout=timedelta(seconds=10),
            num_chunks=0,
        )
        self.addCleanup(server.shutdown)

        # server should start up in a disallowed state this will block incoming
        # requests until allow_checkpoint is called
        self.assertTrue(server._checkpoint_lock.w_locked())
        self.assertTrue(server._disallowed)
        self.assertEqual(server._step, -1)

        # allow requests
        server.allow_checkpoint(1)

        self.assertFalse(server._checkpoint_lock.w_locked())
        self.assertFalse(server._disallowed)
        self.assertEqual(server._step, 1)

        # duplicate allow/disallow is fine
        server.allow_checkpoint(2)
        self.assertEqual(server._step, 2)

        server.disallow_checkpoint()
        server.disallow_checkpoint()
        self.assertTrue(server._checkpoint_lock.w_locked())
        self.assertTrue(server._disallowed)

    def test_multi_http_transport_cpu(self) -> None:
        self._run_multi_recovery(torch.device("cpu"))

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    @skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_multi_http_transport_cuda(self) -> None:
        self._run_multi_recovery(torch.device("cuda"))

    def test_benchmark(self) -> None:
        """
        Smoke-tests the benchmark entry point with a tiny CPU workload.
        """
        bench_main(
            [
                "--chunk-size=10",
                "--num-chunks=0",
                "--total-size=100",
                "--device=cpu",
            ]
        )