torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/_test_utils.py
ADDED
@@ -0,0 +1,111 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def any_nan(ts: list[torch.Tensor]) -> bool:
    """
    Check if any tensor in the list contains NaN values.

    Args:
        ts: List of tensors to check for NaN values

    Returns:
        True if any tensor contains NaN values, False otherwise
    """
    for t in ts:
        if torch.isnan(t).any():
            return True
    return False


def combine_views(
    views: list[list[tuple[int, ...]]],
    combinations: list[list[tuple[int, ...]]],
    tmp: list[tuple[int, ...]],
    i: int,
) -> None:
    """
    Recursively generate all possible combinations of views from a list of
    lists of views.

    This function uses backtracking to generate all possible combinations by
    selecting one view from each list in the input. The results are stored in
    the combinations list.

    Args:
        views: A list of lists, where each inner list contains possible view
            shapes (tuples)
        combinations: Output list where all combinations will be stored
        tmp: Temporary list to build the current combination
        i: Current index in the views list being processed

    Returns:
        None. Results are stored in the combinations list passed as
        an argument.
    """
    if i == len(views):
        combinations.append(tmp.copy())
        return

    for j in range(len(views[i])):
        tmp.append(views[i][j])
        combine_views(views, combinations, tmp, i + 1)
        tmp.pop()


def gen_views(inp: torch.Tensor) -> list[tuple[int, ...]]:
    """
    Generate all possible 2D views (shapes) for a tensor with a given number
    of elements.

    This function finds all pairs of integers (m, n) such that m * n equals the
    total number of elements in the input tensor. These pairs represent possible
    2D shapes that the tensor can be reshaped into.

    Args:
        inp: Input tensor

    Returns:
        A list of tuples, where each tuple (m, n) represents a possible 2D shape
        such that m * n equals the total number of elements in the input tensor
    """
    size = inp.numel()

    views = []
    for m in range(1 if size % 2 == 0 else 2, size):
        if size % m == 0:
            views.append((m, size // m))

    return views


def gen_splits(inp: torch.Tensor, split_size: int) -> list[list[tuple[int, ...]]]:
    """
    Split a tensor into chunks and generate all possible combinations of views.

    This function first splits the input tensor into chunks of the specified size,
    then generates all possible 2D views for each chunk, and finally computes all
    possible combinations of these views across all chunks.

    Args:
        inp: Input tensor to be split
        split_size: Size of each chunk

    Returns:
        A list of lists, where each inner list contains a combination of view
        shapes, one for each chunk of the input tensor
    """
    views = []

    for split in torch.split(inp, split_size):
        views.append(gen_views(split))

    combinations = []
    combine_views(views, combinations, [], 0)

    return combinations
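Editor's note, not part of the packaged file: a minimal usage sketch of the helpers above, assuming a CPU tensor and the import path torchft._test_utils shown in the file list. gen_views enumerates factor pairs of a chunk's element count, and gen_splits takes the cross product of those factorizations across chunks.

import torch

from torchft._test_utils import any_nan, gen_splits

# A 12-element tensor split into two chunks of 6 elements each.
inp = torch.arange(12, dtype=torch.float32)

combos = gen_splits(inp, split_size=6)

# Each 6-element chunk admits the 2D views (1, 6), (2, 3) and (3, 2),
# so the cross product over the two chunks yields 3 * 3 = 9 combinations.
assert len(combos) == 9
for combo in combos:
    assert len(combo) == 2  # one shape per chunk

assert not any_nan([inp])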
torchft/_torchft.cpython-310-x86_64-linux-gnu.so
Binary file
torchft/_torchft.pyi
ADDED
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from datetime import timedelta
from typing import Hashable, List, Optional

class ManagerClient:
    def __init__(self, addr: str, connect_timeout: timedelta) -> None: ...
    def _quorum(
        self,
        group_rank: int,
        step: int,
        checkpoint_metadata: str,
        shrink_only: bool,
        timeout: timedelta,
        commit_failures: int,
        init_sync: bool = True,
    ) -> QuorumResult: ...
    def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
    def should_commit(
        self,
        group_rank: int,
        step: int,
        should_commit: bool,
        timeout: timedelta,
    ) -> bool: ...

class QuorumResult:
    quorum_id: int
    replica_rank: int
    replica_world_size: int
    recover_src_manager_address: str
    recover_src_replica_rank: Optional[int]
    recover_dst_replica_ranks: List[int]
    store_address: str
    max_step: int
    max_replica_rank: Optional[int]
    max_world_size: int
    heal: bool
    commit_failures: int
    replica_ids: list[str]

class ManagerServer:
    def __init__(
        self,
        replica_id: str,
        lighthouse_addr: str,
        hostname: str,
        bind: str,
        store_addr: str,
        world_size: int,
        heartbeat_interval: timedelta,
        connect_timeout: timedelta,
        quorum_retries: int,
    ) -> None: ...
    def address(self) -> str: ...
    def shutdown(self) -> None: ...

class LighthouseServer:
    def __init__(
        self,
        bind: str,
        min_replicas: int,
        join_timeout_ms: Optional[int] = None,
        quorum_tick_ms: Optional[int] = None,
        heartbeat_timeout_ms: Optional[int] = None,
    ) -> None: ...
    def address(self) -> str: ...
    def shutdown(self) -> None: ...

@dataclass
class QuorumMember:
    replica_id: str
    address: str
    store_address: str
    step: int
    world_size: int
    shrink_only: bool
    data: Optional[dict[Hashable, object]] = None

@dataclass
class Timestamp:
    seconds: int
    nanos: int

@dataclass
class Quorum:
    quorum_id: str
    participants: List[QuorumMember]
    created: Timestamp

@dataclass
class LighthouseClient:
    addr: str
    connect_timeout: timedelta

    def quorum(
        self,
        replica_id: str,
        timeout: timedelta,
        address: Optional[str] = None,
        store_address: Optional[str] = None,
        step: Optional[int] = None,
        world_size: Optional[int] = None,
        shrink_only: Optional[bool] = None,
        data: Optional[dict[Hashable, object]] = None,
    ) -> Quorum: ...
    def heartbeat(
        self,
        replica_id: str,
        timeout: timedelta = timedelta(seconds=5),
    ) -> None: ...
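Editor's note, not part of the packaged file: the stubs above describe the Rust-backed coordination bindings, so a hedged sketch of wiring them together might look as follows. The bind addresses, ports, timeouts, and replica IDs are illustrative placeholders, not values taken from the package.

from datetime import timedelta

from torchft._torchft import LighthouseServer, ManagerServer

# Start a lighthouse that forms a quorum once a single replica joins.
lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1)

# Start a manager for one replica group; every concrete value here is a placeholder.
manager = ManagerServer(
    replica_id="replica_0",
    lighthouse_addr=lighthouse.address(),
    hostname="localhost",
    bind="[::]:0",
    store_addr="localhost:29500",
    world_size=1,
    heartbeat_interval=timedelta(milliseconds=100),
    connect_timeout=timedelta(seconds=10),
    quorum_retries=0,
)
print(manager.address())

manager.shutdown()
lighthouse.shutdown()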
torchft/checkpointing/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Checkpointing
==============

This module implements methods for checkpointing and resuming training from a checkpoint.
"""

from torchft.checkpointing.http_transport import HTTPTransport
from torchft.checkpointing.transport import CheckpointTransport

__all__ = [
    "HTTPTransport",
    "CheckpointTransport",
]
torchft/checkpointing/_rwlock.py
ADDED
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""rwlock.py

Adapted from: https://github.com/tylerneylon/rwlock/blob/main/rwlock.py

A class to implement read-write locks on top of the standard threading
library.

This is implemented with two mutexes (threading.Lock instances) as per this
wikipedia pseudocode:

https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_two_mutexes

__________________________
License info (MIT):

*******

Copyright 2023 Tyler Neylon and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

*******
"""

from contextlib import contextmanager
from threading import Lock
from typing import Generator


class RWLock(object):
    """RWLock class; this is meant to allow an object to be read from by
    multiple threads, but only written to by a single thread at a time. See:
    https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock

    All operations are timed and will throw TimeoutError if the timeout is
    exceeded.

    Usage:

        from rwlock import RWLock

        my_obj_rwlock = RWLock(timeout=60.0)

        # When reading from my_obj:
        with my_obj_rwlock.r_lock():
            do_read_only_things_with(my_obj)

        # When writing to my_obj:
        with my_obj_rwlock.w_lock():
            mutate(my_obj)
    """

    def __init__(self, timeout: float = -1) -> None:
        self.timeout = timeout

        self._w_lock = Lock()
        self._num_r_lock = Lock()
        self._num_r = 0

    # ___________________________________________________________________
    # Reading methods.

    def r_acquire(self) -> None:
        if not self._num_r_lock.acquire(timeout=self.timeout):
            raise TimeoutError(
                f"Timed out waiting for rlock after {self.timeout} seconds"
            )

        self._num_r += 1
        if self._num_r == 1:
            if not self._w_lock.acquire(timeout=self.timeout):
                self._num_r -= 1
                self._num_r_lock.release()
                raise TimeoutError(
                    f"Timed out waiting for wlock after {self.timeout} seconds"
                )

        self._num_r_lock.release()

    def r_release(self) -> None:
        assert self._num_r > 0
        self._num_r_lock.acquire()
        self._num_r -= 1
        if self._num_r == 0:
            self._w_lock.release()
        self._num_r_lock.release()

    @contextmanager
    def r_lock(self) -> Generator[None, None, None]:
        """This method is designed to be used via the `with` statement."""
        self.r_acquire()
        try:
            yield
        finally:
            self.r_release()

    # ___________________________________________________________________
    # Writing methods.

    def w_acquire(self) -> None:
        if not self._w_lock.acquire(timeout=self.timeout):
            raise TimeoutError(
                f"Timed out waiting for wlock after {self.timeout} seconds"
            )

    def w_release(self) -> None:
        self._w_lock.release()

    @contextmanager
    def w_lock(self) -> Generator[None, None, None]:
        """This method is designed to be used via the `with` statement."""
        self.w_acquire()
        try:
            yield
        finally:
            self.w_release()

    def w_locked(self) -> bool:
        """Returns True if the lock is currently locked for writing."""
        return self._w_lock.locked()
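Editor's note, not part of the packaged file: a small sketch of the reader-writer semantics above — several readers may hold the lock at once while a writer gets exclusive access — assuming the module path torchft.checkpointing._rwlock from the file list.

import threading

from torchft.checkpointing._rwlock import RWLock

lock = RWLock(timeout=5.0)
shared = {"step": 0}

def reader() -> None:
    # Multiple readers may hold r_lock concurrently.
    with lock.r_lock():
        _ = shared["step"]

def writer() -> None:
    # The writer waits for all readers to drain, then holds exclusive access.
    with lock.w_lock():
        shared["step"] += 1

threads = [threading.Thread(target=reader) for _ in range(4)]
threads.append(threading.Thread(target=writer))
for t in threads:
    t.start()
for t in threads:
    t.join()

assert shared["step"] == 1
assert not lock.w_locked()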
torchft/checkpointing/_serialization.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import io
import warnings
from typing import IO

import torch


def _fallback_save(obj: object, f: IO[bytes]) -> None:
    warnings.warn(
        "using slow fallback torch.save implementation, please upgrade to PT 2.7+ for fast streaming saves"
    )

    torch.save(obj, f)


def _fallback_load(f: IO[bytes], weights_only: bool = True) -> object:
    warnings.warn(
        "using slow fallback torch.load implementation, please upgrade to PT 2.7+ for fast streaming loads"
    )

    # torch.load requires a seekable file object
    buf = f.read()
    reader = io.BytesIO(buf)

    return torch.load(reader, weights_only=weights_only)


try:
    # upgrade to PT 2.7 once released
    from torch.distributed._serialization import _streaming_load, _streaming_save
except ImportError:
    _streaming_load = _fallback_load
    _streaming_save = _fallback_save
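Editor's note, not part of the packaged file: a hedged round-trip sketch using the names defined above, assuming an in-memory buffer is acceptable to both code paths. _streaming_save/_streaming_load resolve to the PyTorch 2.7+ streaming implementations when available and to the torch.save/torch.load fallbacks otherwise.

import io

import torch

from torchft.checkpointing._serialization import _streaming_load, _streaming_save

state = {"weight": torch.randn(4, 4)}

# Save into a seekable in-memory buffer, then rewind and load it back.
buf = io.BytesIO()
_streaming_save(state, buf)
buf.seek(0)

restored = _streaming_load(buf)
assert torch.equal(restored["weight"], state["weight"])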