torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/_test_utils.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+
+ def any_nan(ts: list[torch.Tensor]) -> bool:
+     """
+     Check if any tensor in the list contains NaN values.
+
+     Args:
+         ts: List of tensors to check for NaN values
+
+     Returns:
+         True if any tensor contains NaN values, False otherwise
+     """
+     for t in ts:
+         if torch.isnan(t).any():
+             return True
+     return False
+
+
+ def combine_views(
+     views: list[list[tuple[int, ...]]],
+     combinations: list[list[tuple[int, ...]]],
+     tmp: list[tuple[int, ...]],
+     i: int,
+ ) -> None:
+     """
+     Recursively generate all possible combinations of views from a list of
+     lists of views.
+
+     This function uses backtracking to generate all possible combinations by
+     selecting one view from each list in the input. The results are stored in
+     the combinations list.
+
+     Args:
+         views: A list of lists, where each inner list contains possible view
+             shapes (tuples)
+         combinations: Output list where all combinations will be stored
+         tmp: Temporary list to build the current combination
+         i: Current index in the views list being processed
+
+     Returns:
+         None. Results are stored in the combinations list passed as
+         an argument.
+     """
+     if i == len(views):
+         combinations.append(tmp.copy())
+         return
+
+     for j in range(len(views[i])):
+         tmp.append(views[i][j])
+         combine_views(views, combinations, tmp, i + 1)
+         tmp.pop()
+
+
+ def gen_views(inp: torch.Tensor) -> list[tuple[int, ...]]:
+     """
+     Generate all possible 2D views (shapes) for a tensor with a given number
+     of elements.
+
+     This function finds all pairs of integers (m, n) such that m * n equals the
+     total number of elements in the input tensor. These pairs represent possible
+     2D shapes that the tensor can be reshaped into.
+
+     Args:
+         inp: Input tensor
+
+     Returns:
+         A list of tuples, where each tuple (m, n) represents a possible 2D shape
+         such that m * n equals the total number of elements in the input tensor
+     """
+     size = inp.numel()
+
+     views = []
+     for m in range(1 if size % 2 == 0 else 2, size):
+         if size % m == 0:
+             views.append((m, size // m))
+
+     return views
+
+
+ def gen_splits(inp: torch.Tensor, split_size: int) -> list[list[tuple[int, ...]]]:
+     """
+     Split a tensor into chunks and generate all possible combinations of views.
+
+     This function first splits the input tensor into chunks of the specified size,
+     then generates all possible 2D views for each chunk, and finally computes all
+     possible combinations of these views across all chunks.
+
+     Args:
+         inp: Input tensor to be split
+         split_size: Size of each chunk
+
+     Returns:
+         A list of lists, where each inner list contains a combination of view
+         shapes, one for each chunk of the input tensor
+     """
+     views = []
+
+     for split in torch.split(inp, split_size):
+         views.append(gen_views(split))
+
+     combinations = []
+     combine_views(views, combinations, [], 0)
+
+     return combinations
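A quick sketch of how these helpers compose (illustrative only; the tensor size and split_size here are arbitrary choices, not values torchft's tests use):

    import torch

    from torchft._test_utils import any_nan, gen_splits

    # A 12-element tensor split into two chunks of 6 elements each.
    inp = torch.arange(12, dtype=torch.float32)

    # Each 6-element chunk has 2D views (1, 6), (2, 3) and (3, 2); the loop
    # in gen_views stops before m == size, so (6, 1) is excluded. Two chunks
    # therefore yield 3 * 3 = 9 view combinations.
    for combo in gen_splits(inp, split_size=6):
        print(combo)  # e.g. [(1, 6), (2, 3)]

    assert not any_nan([inp])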
torchft/_torchft.pyi ADDED
@@ -0,0 +1,116 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from dataclasses import dataclass
+ from datetime import timedelta
+ from typing import Hashable, List, Optional
+
+ class ManagerClient:
+     def __init__(self, addr: str, connect_timeout: timedelta) -> None: ...
+     def _quorum(
+         self,
+         group_rank: int,
+         step: int,
+         checkpoint_metadata: str,
+         shrink_only: bool,
+         timeout: timedelta,
+         commit_failures: int,
+         init_sync: bool = True,
+     ) -> QuorumResult: ...
+     def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
+     def should_commit(
+         self,
+         group_rank: int,
+         step: int,
+         should_commit: bool,
+         timeout: timedelta,
+     ) -> bool: ...
+
+ class QuorumResult:
+     quorum_id: int
+     replica_rank: int
+     replica_world_size: int
+     recover_src_manager_address: str
+     recover_src_replica_rank: Optional[int]
+     recover_dst_replica_ranks: List[int]
+     store_address: str
+     max_step: int
+     max_replica_rank: Optional[int]
+     max_world_size: int
+     heal: bool
+     commit_failures: int
+     replica_ids: list[str]
+
+ class ManagerServer:
+     def __init__(
+         self,
+         replica_id: str,
+         lighthouse_addr: str,
+         hostname: str,
+         bind: str,
+         store_addr: str,
+         world_size: int,
+         heartbeat_interval: timedelta,
+         connect_timeout: timedelta,
+         quorum_retries: int,
+     ) -> None: ...
+     def address(self) -> str: ...
+     def shutdown(self) -> None: ...
+
+ class LighthouseServer:
+     def __init__(
+         self,
+         bind: str,
+         min_replicas: int,
+         join_timeout_ms: Optional[int] = None,
+         quorum_tick_ms: Optional[int] = None,
+         heartbeat_timeout_ms: Optional[int] = None,
+     ) -> None: ...
+     def address(self) -> str: ...
+     def shutdown(self) -> None: ...
+
+ @dataclass
+ class QuorumMember:
+     replica_id: str
+     address: str
+     store_address: str
+     step: int
+     world_size: int
+     shrink_only: bool
+     data: Optional[dict[Hashable, object]] = None
+
+ @dataclass
+ class Timestamp:
+     seconds: int
+     nanos: int
+
+ @dataclass
+ class Quorum:
+     quorum_id: str
+     participants: List[QuorumMember]
+     created: Timestamp
+
+ @dataclass
+ class LighthouseClient:
+     addr: str
+     connect_timeout: timedelta
+
+     def quorum(
+         self,
+         replica_id: str,
+         timeout: timedelta,
+         address: Optional[str] = None,
+         store_address: Optional[str] = None,
+         step: Optional[int] = None,
+         world_size: Optional[int] = None,
+         shrink_only: Optional[bool] = None,
+         data: Optional[dict[Hashable, object]] = None,
+     ) -> Quorum: ...
+     def heartbeat(
+         self,
+         replica_id: str,
+         timeout: timedelta = timedelta(seconds=5),
+     ) -> None: ...
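To make the stub concrete, here is a minimal sketch of standing up a LighthouseServer and talking to it through a LighthouseClient, using only the signatures declared above; the bind address, min_replicas value, and replica id are illustrative assumptions, not prescribed values:

    from datetime import timedelta

    from torchft._torchft import LighthouseClient, LighthouseServer

    # Bind to an ephemeral port; min_replicas gates quorum formation.
    lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)

    client = LighthouseClient(
        addr=lighthouse.address(),
        connect_timeout=timedelta(seconds=10),
    )
    client.heartbeat(replica_id="replica_0")

    quorum = client.quorum(
        replica_id="replica_0",
        timeout=timedelta(seconds=10),
    )
    print(quorum.quorum_id, [m.replica_id for m in quorum.participants])

    lighthouse.shutdown()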
torchft/checkpointing/__init__.py ADDED
@@ -0,0 +1,20 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Checkpointing
+ ==============
+
+ This module implements methods for checkpointing and resuming training from a checkpoint.
+ """
+
+ from torchft.checkpointing.http_transport import HTTPTransport
+ from torchft.checkpointing.transport import CheckpointTransport
+
+ __all__ = [
+     "HTTPTransport",
+     "CheckpointTransport",
+ ]
torchft/checkpointing/_rwlock.py ADDED
@@ -0,0 +1,136 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """rwlock.py
+
+ Adapted from: https://github.com/tylerneylon/rwlock/blob/main/rwlock.py
+
+ A class to implement read-write locks on top of the standard threading
+ library.
+
+ This is implemented with two mutexes (threading.Lock instances) as per this
+ wikipedia pseudocode:
+
+ https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_two_mutexes
+
+ __________________________
+ License info (MIT):
+
+ *******
+
+ Copyright 2023 Tyler Neylon and contributors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+ documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+ rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+ persons to whom the Software is furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+ WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ *******
+ """
+
+ from contextlib import contextmanager
+ from threading import Lock
+ from typing import Generator
+
+
+ class RWLock(object):
+     """RWLock class; this is meant to allow an object to be read from by
+     multiple threads, but only written to by a single thread at a time. See:
+     https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock
+
+     All operations are timed and will throw TimeoutError if the timeout is
+     exceeded.
+
+     Usage:
+
+         from rwlock import RWLock
+
+         my_obj_rwlock = RWLock(timeout=60.0)
+
+         # When reading from my_obj:
+         with my_obj_rwlock.r_lock():
+             do_read_only_things_with(my_obj)
+
+         # When writing to my_obj:
+         with my_obj_rwlock.w_lock():
+             mutate(my_obj)
+     """
+
+     def __init__(self, timeout: float = -1) -> None:
+         self.timeout = timeout
+
+         self._w_lock = Lock()
+         self._num_r_lock = Lock()
+         self._num_r = 0
+
+     # ___________________________________________________________________
+     # Reading methods.
+
+     def r_acquire(self) -> None:
+         if not self._num_r_lock.acquire(timeout=self.timeout):
+             raise TimeoutError(
+                 f"Timed out waiting for rlock after {self.timeout} seconds"
+             )
+
+         self._num_r += 1
+         if self._num_r == 1:
+             if not self._w_lock.acquire(timeout=self.timeout):
+                 self._num_r -= 1
+                 self._num_r_lock.release()
+                 raise TimeoutError(
+                     f"Timed out waiting for wlock after {self.timeout} seconds"
+                 )
+
+         self._num_r_lock.release()
+
+     def r_release(self) -> None:
+         assert self._num_r > 0
+         self._num_r_lock.acquire()
+         self._num_r -= 1
+         if self._num_r == 0:
+             self._w_lock.release()
+         self._num_r_lock.release()
+
+     @contextmanager
+     def r_lock(self) -> Generator[None, None, None]:
+         """This method is designed to be used via the `with` statement."""
+         self.r_acquire()
+         try:
+             yield
+         finally:
+             self.r_release()
+
+     # ___________________________________________________________________
+     # Writing methods.
+
+     def w_acquire(self) -> None:
+         if not self._w_lock.acquire(timeout=self.timeout):
+             raise TimeoutError(
+                 f"Timed out waiting for wlock after {self.timeout} seconds"
+             )
+
+     def w_release(self) -> None:
+         self._w_lock.release()
+
+     @contextmanager
+     def w_lock(self) -> Generator[None, None, None]:
+         """This method is designed to be used via the `with` statement."""
+         self.w_acquire()
+         try:
+             yield
+         finally:
+             self.w_release()
+
+     def w_locked(self) -> bool:
+         """Returns True if the lock is currently locked for writing."""
+         return self._w_lock.locked()
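A short sketch of the lock under contention, assuming nothing beyond the class above: readers overlap freely, a writer blocks until the reader count drops to zero, and a non-negative timeout turns an indefinite wait into a TimeoutError:

    import threading
    import time

    lock = RWLock(timeout=5.0)
    shared = {"step": 0}

    def reader() -> None:
        with lock.r_lock():  # any number of readers may hold this at once
            time.sleep(0.1)
            print("read", shared["step"])

    def writer() -> None:
        with lock.w_lock():  # acquired only after all readers release
            shared["step"] += 1

    threads = [threading.Thread(target=reader) for _ in range(4)]
    threads.append(threading.Thread(target=writer))
    for t in threads:
        t.start()
    for t in threads:
        t.join()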
torchft/checkpointing/_serialization.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import io
+ import warnings
+ from typing import IO
+
+ import torch
+
+
+ def _fallback_save(obj: object, f: IO[bytes]) -> None:
+     warnings.warn(
+         "using slow fallback torch.save implementation, please upgrade to PT 2.7+ for fast streaming saves"
+     )
+
+     torch.save(obj, f)
+
+
+ def _fallback_load(f: IO[bytes], weights_only: bool = True) -> object:
+     warnings.warn(
+         "using slow fallback torch.load implementation, please upgrade to PT 2.7+ for fast streaming loads"
+     )
+
+     # torch.load requires a seekable file object
+     buf = f.read()
+     reader = io.BytesIO(buf)
+
+     return torch.load(reader, weights_only=weights_only)
+
+
+ try:
+     # upgrade to PT 2.7 once released
+     from torch.distributed._serialization import _streaming_load, _streaming_save
+ except ImportError:
+     _streaming_load = _fallback_load
+     _streaming_save = _fallback_save
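As a usage sketch (these are private helpers, so the call pattern below is an assumption based on the signatures above, not documented API): a state dict round-trips through any binary stream, with the pre-2.7 fallback emitting the warnings defined earlier:

    import io

    import torch

    state = {"weight": torch.randn(4, 4)}

    buf = io.BytesIO()
    _streaming_save(state, buf)

    buf.seek(0)
    restored = _streaming_load(buf, weights_only=True)
    assert torch.equal(state["weight"], restored["weight"])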