torchft-nightly 2026.1.3 (cp310-cp310-manylinux_2_24_x86_64.whl)

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/checkpointing/pg_transport.py
@@ -0,0 +1,306 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import pickle
+import time
+from contextlib import contextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from typing import Callable, cast, Generator, Optional, TypeVar, Union
+
+import torch
+from torch.distributed import Work
+from torch.distributed.tensor import _DTensorSpec, DTensor
+from torch.utils._pytree import (
+    KeyPath,
+    tree_flatten_with_path,
+    tree_unflatten,
+    TreeSpec,
+)
+
+from torchft.checkpointing.transport import CheckpointTransport
+from torchft.process_group import ProcessGroup
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class _TensorMeta:
+    """
+    This is the metadata for a tensor that is used to transfer checkpoints.
+    It contains the shape, the dtype, the storage offset and the stride of the
+    tensor.
+
+    This must be pickleable so that it can be sent over the wire.
+    """
+
+    shape: torch.Size
+    dtype: torch.dtype
+    storage_offset: int
+    stride: tuple[int, ...]
+    nbytes: int
+
+
+@dataclass
+class _DTensorMeta:
+    """
+    This is the metadata for a DTensor that is used to transfer checkpoints.
+    It contains the metadata for the local tensor and the spec of the DTensor.
+
+    This must be pickleable so that it can be sent over the wire.
+    """
+
+    local: _TensorMeta
+    spec: _DTensorSpec
+
+
+@dataclass
+class _StateDictMeta:
+    """
+    This is the metadata for a state dict that is used to transfer checkpoints.
+    It contains the step, the pytree spec of the state dict and the metadata for
+    each tensor in the state dict.
+
+    This must be pickleable so that it can be sent over the wire.
+
+    Args:
+        step: the step of the checkpoint to verify consistency
+        treespec: the pytree spec of the state dict
+        paths: the path of each leaf in the state dict
+        non_tensor_leaves: the metadata for each tensor in the state dict and any
+            non-tensor leaves in the state dict
+    """
+
+    step: int
+    treespec: TreeSpec
+    paths: list[KeyPath]
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]]
+
+
+@contextmanager
+def _timeit(name: str) -> Generator[None, None, None]:
+    start = time.perf_counter()
+    yield
+    dur = time.perf_counter() - start
+    logger.info(f"{name} took {dur}s")
+
+
+def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]:
+    return (
+        _cast_tensor(tensor, torch.uint8),
+        _TensorMeta(
+            shape=tensor.shape,
+            dtype=tensor.dtype,
+            storage_offset=cast(int, tensor.storage_offset()),
+            stride=tensor.stride(),
+            nbytes=tensor.untyped_storage().nbytes(),
+        ),
+    )
+
+
+def _prepare_state_dict(
+    state_dict: object,
+    step: int,
+    device: torch.device,
+) -> tuple[_StateDictMeta, list[torch.Tensor]]:
+    leaves: list[tuple[KeyPath, object]]
+    leaves, treespec = tree_flatten_with_path(state_dict)
+
+    paths: list[KeyPath] = []
+    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]] = []
+    tensors: list[torch.Tensor] = []
+    for key_path, v in leaves:
+        paths.append(key_path)
+
+        if isinstance(v, DTensor):
+            tensor, tensor_meta = _prepare_tensor(v._local_tensor)
+
+            tensors.append(tensor)
+
+            non_tensor_leaves.append(
+                _DTensorMeta(
+                    local=tensor_meta,
+                    spec=v._spec,
+                )
+            )
+        elif isinstance(v, torch.Tensor):
+            tensor, tensor_meta = _prepare_tensor(v)
+            tensors.append(tensor)
+            non_tensor_leaves.append(tensor_meta)
+        else:
+            non_tensor_leaves.append(v)
+
+    return (
+        _StateDictMeta(
+            step=step,
+            treespec=treespec,
+            paths=paths,
+            non_tensor_leaves=non_tensor_leaves,
+        ),
+        tensors,
+    )
+
+
+def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+    """
+    Casts the underlying storage to a tensor of the given dtype.
+
+    The returned tensor will be of size ``storage.nbytes``.
+
+    This works for all datatypes and supports strided/offset tensors with the
+    caveat that the cast tensor may be larger than the original tensor due to
+    the differences in striding.
+    """
+    assert (
+        type(tensor) is torch.Tensor
+    ), f"can only cast standard tensors not {type(tensor)}"
+    storage = tensor.untyped_storage()
+    ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
+    assert ret.untyped_storage() is storage, "storage should be the same"
+    return ret
+
+
+class PGTransport(CheckpointTransport[T]):
+    """
+    This is a checkpoint transport that uses the process group to transfer checkpoints.
+    This allows for fast recovery of workers by fetching the current weights
+    from an existing worker.
+
+    Args:
+        pg: the process group to use for communication
+        timeout: the timeout for communication
+        device: the device to use for tensors
+        state_dict: if specified this function will be called to do an inplace
+            receive into the returned state_dict. This is much faster than
+            having to allocate new tensors and transferring them to the CPU.
+    """
+
+    def __init__(
+        self,
+        pg: ProcessGroup,
+        timeout: timedelta,
+        device: torch.device,
+        state_dict: Optional[Callable[[], object]] = None,
+    ) -> None:
+        self._work: list[Work] = []
+        self._pg = pg
+        self._timeout = timeout
+        self._device = device
+        self._state_dict = state_dict
+
+    def metadata(self) -> str:
+        return "<n/a>"
+
+    def disallow_checkpoint(self) -> None:
+        pass
+
+    def send_checkpoint(
+        self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
+    ) -> None:
+        with _timeit("preparing state_dict"):
+            meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)
+
+        work = []
+
+        with _timeit("send pickle"):
+            buf = pickle.dumps(meta)
+            len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device)
+            buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device)
+            for dst_rank in dst_ranks:
+                work.append(self._pg.send([len_t], dst_rank, tag=1))
+                work.append(self._pg.send([buf_t], dst_rank, tag=2))
+
+        with _timeit("send tensors"):
+            for i, t in enumerate(tensors):
+                original_device = t.device
+                t = t.to(self._device)
+                for dst_rank in dst_ranks:
+                    work.append(self._pg.send([t], dst_rank, tag=3 + i))
+
+                # if we did a copy we should wait for the work to complete so we
+                # can free the memory to avoid OOMs
+                if original_device == torch.device("cpu"):
+                    for w in work:
+                        w.wait(timeout)
+                    work = []
+
+            for w in work:
+                w.wait(timeout)
+
+    def recv_checkpoint(
+        self, src_rank: int, metadata: str, step: int, timeout: timedelta
+    ) -> T:
+        state_dict = self._state_dict() if self._state_dict else {}
+        state_dict_leaves, _ = tree_flatten_with_path(state_dict)
+
+        dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves)
+
+        len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
+        self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
+        length = cast(int, len_t.item())
+
+        assert length > 0, f"invalid metadata length {length=}"
+
+        buf = torch.empty(length, dtype=torch.uint8, device=self._device)
+        self._pg.recv([buf], src_rank, tag=2).wait(timeout)
+
+        meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes())
+        assert meta.step == step
+
+        i: int = 0
+        works: list[Work] = []
+
+        def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
+            nonlocal i
+
+            inplace = dst_tensors.get(path)
+            if (
+                isinstance(inplace, torch.Tensor)
+                and inplace.device.type == self._device.type
+            ):
+                if isinstance(inplace, DTensor):
+                    inplace = inplace._local_tensor
+                t = _cast_tensor(inplace, torch.uint8)
+                assert (
+                    t.nbytes == v.nbytes
+                ), "inplace tensor storage must be the same size"
+            else:
+                t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
+
+            work = self._pg.recv([t], src_rank, tag=3 + i)
+            i += 1
+
+            if inplace is None:
+                # if not inplace we need to copy it to CPU to avoid OOMing
+                work.wait(timeout)
+                t = t.cpu()
+            else:
+                works.append(work)
+
+            return torch.as_strided(
+                t.view(v.dtype),
+                size=v.shape,
+                stride=v.stride,
+                storage_offset=v.storage_offset,
+            )
+
+        values = []
+        for path, v in zip(meta.paths, meta.non_tensor_leaves):
+            if isinstance(v, _TensorMeta):
+                values.append(recv(path, v))
+            elif isinstance(v, _DTensorMeta):
+                tensor = recv(path, v.local)
+                values.append(DTensor(tensor, v.spec, requires_grad=False))
+            else:
+                values.append(v)
+
+        for work in works:
+            work.wait(timeout)
+
+        return tree_unflatten(values, meta.treespec)
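
In pg_transport.py above, send_checkpoint flattens the state dict with pytree, pickles only the metadata, and streams every tensor as a raw uint8 view over tagged point-to-point sends; recv_checkpoint reverses the process and, when a state_dict callable is supplied, receives directly into the existing tensors instead of allocating new ones. A minimal sketch of driving the class by hand, assuming `pg` is an already-configured torchft ProcessGroup and that `model`, `rank`, and `step` come from the surrounding training loop (none of these are defined in this file):

    # Sketch only: `pg`, `model`, `rank`, and `step` are assumed to exist in the
    # caller's training loop; both replicas must use the same configured
    # torchft ProcessGroup and matching tags arrive in the same order.
    from datetime import timedelta

    import torch

    from torchft.checkpointing.pg_transport import PGTransport

    timeout = timedelta(seconds=30)
    transport = PGTransport(
        pg=pg,
        timeout=timeout,
        device=torch.device("cuda"),
        # Passing a state_dict callable turns on the in-place receive path,
        # which writes into the model's existing tensors rather than staging
        # fresh copies on the CPU.
        state_dict=lambda: model.state_dict(),
    )

    if rank == 0:
        # healthy replica: push the current weights to the recovering rank
        transport.send_checkpoint(
            dst_ranks=[1], step=step, state_dict=model.state_dict(), timeout=timeout
        )
    else:
        # recovering replica: pull the weights for the same step
        recovered = transport.recv_checkpoint(
            src_rank=0, metadata=transport.metadata(), step=step, timeout=timeout
        )
        model.load_state_dict(recovered)

In torchft this exchange is normally driven by the manager layer rather than called directly, so treat the snippet as an illustration of the call order, not a recommended entry point.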
torchft/checkpointing/pg_transport_bench.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import sys
+from concurrent.futures import ThreadPoolExecutor
+from datetime import timedelta
+
+import torch
+import torch.distributed as dist
+
+from torchft.checkpointing.pg_transport import _timeit, PGTransport
+from torchft.process_group import ProcessGroupBabyNCCL
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+def main(argv: list[str]) -> None:
+    import argparse
+
+    logging.basicConfig(level=logging.INFO)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--inplace", action="store_true")
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--chunk-size", type=int, default=3_000_000)  # 3MB
+    parser.add_argument("--total-size", type=int, default=12_000_000_000)  # 12GB
+    args = parser.parse_args(argv)
+
+    CHUNK_SIZE: int = args.chunk_size
+    TOTAL_SIZE: int = args.total_size
+    INPLACE: bool = args.inplace
+    DEVICE: str = args.device
+
+    timeout: timedelta = timedelta(seconds=10)
+
+    store = dist.TCPStore(
+        "localhost",
+        0,
+        is_master=True,
+        timeout=timeout,
+        wait_for_workers=False,
+    )
+    store_addr: str = f"localhost:{store.port}"
+
+    def run(rank: int) -> None:
+        torch.cuda.set_device(rank)
+
+        device = torch.device(DEVICE)
+
+        with _timeit("init_pg"):
+            pg = ProcessGroupBabyNCCL(timeout=timeout)
+            pg.configure(store_addr=store_addr, replica_id="0", rank=rank, world_size=2)
+
+            t = torch.zeros(10, device=device, dtype=torch.float32)
+            pg.allreduce([t], dist.ReduceOp.SUM).wait(timeout=timeout)
+
+        with _timeit("create state_dict"):
+            state_dict: dict[str, torch.Tensor] = {}
+            for i in range(0, TOTAL_SIZE, CHUNK_SIZE):
+                state_dict[f"chunk/{i}"] = torch.zeros(
+                    CHUNK_SIZE // 4, dtype=torch.float32, device=device
+                )
+
+        def get_state_dict() -> object:
+            return state_dict
+
+        transport = PGTransport(
+            pg=pg,
+            timeout=timeout,
+            device=device,
+            state_dict=get_state_dict if INPLACE else None,
+        )
+        metadata = transport.metadata()
+
+        if rank == 0:
+            with _timeit("send_checkpoint"):
+                transport.send_checkpoint(
+                    dst_ranks=[1],
+                    step=1,
+                    state_dict=state_dict,
+                    timeout=timedelta(seconds=60),
+                )
+        elif rank == 1:
+            with _timeit("recv_checkpoint"):
+                transport.recv_checkpoint(
+                    src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=60)
+                )

+    with ThreadPoolExecutor(max_workers=2) as executor:
+        results = executor.map(run, range(2))
+        list(results)
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
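
The benchmark above runs both ranks as threads in a single process and times one send/recv of a synthetic 12 GB state dict over ProcessGroupBabyNCCL, so it needs at least two CUDA devices. A plausible way to drive it, assuming the module path from the file list above, is to call its main() directly; `--inplace` switches the receiver to the in-place path and `--device cuda` keeps the chunks on GPU:

    # Sketch: equivalent to running
    #   python -m torchft.checkpointing.pg_transport_bench --device cuda --inplace
    # (module path assumed from the package layout shown above); the smaller
    # total size below is only to keep a trial run quick.
    from torchft.checkpointing.pg_transport_bench import main

    main(["--device", "cuda", "--inplace", "--total-size", "1000000000"])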
torchft/checkpointing/pg_transport_test.py
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from datetime import timedelta
+from unittest import skipIf, skipUnless, TestCase
+
+import torch
+from torch.distributed import TCPStore
+
+from torchft.checkpointing.pg_transport import PGTransport
+from torchft.checkpointing.transport import CheckpointTransport
+from torchft.checkpointing.transport_test import (
+    make_state_dict,
+    run_multi_recovery_test,
+)
+from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
+
+
+class PGTransportTest(TestCase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipIf(sys.platform == "darwin", "not passing on mac")
+    def test_pg_transport_gloo(self) -> None:
+        store: TCPStore = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device: torch.device = torch.device("cpu")
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
+            pg = ProcessGroupGloo()
+            pg.configure(
+                store_addr=f"localhost:{store.port}/prefix",
+                replica_id="0",
+                rank=rank,
+                world_size=world_size,
+            )
+
+            return PGTransport[dict[str, object]](
+                pg, timeout=timedelta(seconds=10), device=device
+            )
+
+        run_multi_recovery_test(self, init, device=device)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
+    def test_pg_transport_baby_nccl(self) -> None:
+        store: TCPStore = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device: torch.device = torch.device("cuda")
+        timeout: timedelta = timedelta(seconds=10)
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
+            torch.cuda.set_device(rank)
+
+            pg = ProcessGroupBabyNCCL(timeout=timeout)
+            pg.configure(
+                store_addr=f"localhost:{store.port}/prefix",
+                replica_id="0",
+                rank=rank,
+                world_size=world_size,
+            )
+
+            return PGTransport[dict[str, object]](pg, timeout=timeout, device=device)
+
+        run_multi_recovery_test(self, init, device=device)
+
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
+    def test_pg_transport_baby_nccl_inplace(self) -> None:
+        store: TCPStore = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device: torch.device = torch.device("cuda")
+        timeout: timedelta = timedelta(seconds=10)
+
+        def state_dict() -> dict[str, object]:
+            return make_state_dict(device)
+
+        def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
+            torch.cuda.set_device(rank)
+
+            pg = ProcessGroupBabyNCCL(timeout=timeout)
+            pg.configure(
+                store_addr=f"localhost:{store.port}/prefix",
+                replica_id="0",
+                rank=rank,
+                world_size=world_size,
+            )
+
+            return PGTransport[dict[str, object]](
+                pg,
+                timeout=timeout,
+                device=device,
+                state_dict=state_dict,
+            )
+
+        run_multi_recovery_test(self, init, device=device)
torchft/checkpointing/rwlock_test.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+from torchft.checkpointing._rwlock import RWLock
+
+
+def test_w_locked() -> None:
+    lock = RWLock()
+
+    with lock.w_lock():
+        assert lock.w_locked()
+    assert not lock.w_locked()
+
+
+def test_w_lock_timeout() -> None:
+    lock = RWLock(timeout=0.01)
+
+    lock.r_acquire()
+    lock.r_acquire()
+
+    with pytest.raises(TimeoutError):
+        lock.w_acquire()
+
+    with pytest.raises(TimeoutError):
+        with lock.w_lock():
+            pass
+
+    lock.r_release()
+    with pytest.raises(TimeoutError):
+        lock.w_acquire()
+
+    lock.r_release()
+    with lock.w_lock():
+        pass
+    lock.w_acquire()
+
+
+def test_r_lock_timeout() -> None:
+    lock = RWLock(timeout=0.01)
+
+    lock.w_acquire()
+
+    with pytest.raises(TimeoutError):
+        lock.r_acquire()
+
+    with pytest.raises(TimeoutError):
+        with lock.r_lock():
+            pass
+
+    lock.w_release()
+    with lock.r_lock():
+        pass
+    lock.r_acquire()
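
The tests above pin down the RWLock contract used by the checkpointing code: any number of readers may hold the lock at once, a writer needs exclusive access, and acquisition raises TimeoutError once the configured timeout elapses. A small sketch of that usage pattern, assuming only the API the tests exercise; the helpers named below are hypothetical placeholders, not torchft functions:

    # Sketch built only from the API exercised in rwlock_test.py above.
    from torchft.checkpointing._rwlock import RWLock

    lock = RWLock(timeout=5.0)

    def handle_checkpoint_request() -> bytes:
        # readers can overlap with each other, but never with a writer
        with lock.r_lock():
            return serve_checkpoint()  # hypothetical helper

    def optimizer_step() -> None:
        # the writer blocks until all readers release, or raises TimeoutError
        with lock.w_lock():
            apply_update()  # hypothetical helper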
torchft/checkpointing/transport.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+from datetime import timedelta
+from typing import Generic, List, TypeVar
+
+T = TypeVar("T")
+
+
+class CheckpointTransport(Generic[T], ABC):
+    @abstractmethod
+    def metadata(self) -> str:
+        """
+        Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
+        """
+        ...
+
+    @abstractmethod
+    def send_checkpoint(
+        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
+    ) -> None:
+        """
+        Sends the checkpoint, only called when there is a rank that is behind.
+
+        This may be async.
+
+        Args:
+            dst_ranks: the ranks to send to
+            step: the step number to send
+            state_dict: the state dict to send
+            timeout: the timeout to wait for the checkpoint to be sent
+        """
+        ...
+
+    def disallow_checkpoint(self) -> None:
+        """
+        Called after send_checkpoint to wait for the checkpoint to be sent.
+
+        Once this returns, the state_dict may be mutated so no further data should be sent.
+        """
+        ...
+
+    @abstractmethod
+    def recv_checkpoint(
+        self, src_rank: int, metadata: str, step: int, timeout: timedelta
+    ) -> T:
+        """
+        Receives the checkpoint from the given rank.
+
+        Args:
+            src_rank: the rank to receive the checkpoint from
+            metadata: the metadata returned by the remote CheckpointTransport
+            step: the step number to receive
+            timeout: the timeout to wait for the checkpoint
+        """
+        ...
+
+    def shutdown(self, wait: bool = True) -> None:
+        """
+        Called to shutdown the checkpoint transport.
+
+        Args:
+            wait: whether to wait for the transport to shutdown
+        """