torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import pickle
|
|
9
|
+
import time
|
|
10
|
+
from contextlib import contextmanager
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from datetime import timedelta
|
|
13
|
+
from typing import Callable, cast, Generator, Optional, TypeVar, Union
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from torch.distributed import Work
|
|
17
|
+
from torch.distributed.tensor import _DTensorSpec, DTensor
|
|
18
|
+
from torch.utils._pytree import (
|
|
19
|
+
KeyPath,
|
|
20
|
+
tree_flatten_with_path,
|
|
21
|
+
tree_unflatten,
|
|
22
|
+
TreeSpec,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from torchft.checkpointing.transport import CheckpointTransport
|
|
26
|
+
from torchft.process_group import ProcessGroup
|
|
27
|
+
|
|
28
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
T = TypeVar("T")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class _TensorMeta:
    """
    This is the metadata for a tensor that is used to transfer checkpoints.
    It contains the shape, the dtype, the storage offset and the stride of the
    tensor, plus the size in bytes of the tensor's whole underlying storage.

    The sender transmits the entire storage as raw bytes. The receiver allocates
    (or validates) ``nbytes`` bytes and rebuilds the original view over them
    with ``torch.as_strided(size=shape, stride=stride,
    storage_offset=storage_offset)``.

    This must be pickleable so that it can be sent over the wire.
    """

    # Logical shape of the tensor view.
    shape: torch.Size
    # Element dtype; the received uint8 bytes are reinterpreted as this dtype.
    dtype: torch.dtype
    # Offset of the view into its storage, as returned by ``storage_offset()``.
    storage_offset: int
    # Per-dimension stride of the view, as returned by ``stride()``.
    stride: tuple[int, ...]
    # Size in bytes of the entire underlying storage (not just the view). For
    # strided or offset views this can exceed the bytes covered by ``shape``.
    nbytes: int
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class _DTensorMeta:
    """
    This is the metadata for a DTensor that is used to transfer checkpoints.
    It contains the metadata for the local tensor and the spec of the DTensor.

    Only the rank-local shard's storage is sent. On receive, the local tensor
    is rebuilt from ``local`` and wrapped back into a ``DTensor`` using
    ``spec``.

    This must be pickleable so that it can be sent over the wire.
    """

    # Metadata for the rank-local shard (taken from ``DTensor._local_tensor``).
    local: _TensorMeta
    # The DTensor's sharding spec (taken from ``DTensor._spec``).
    spec: _DTensorSpec
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class _StateDictMeta:
    """
    This is the metadata for a state dict that is used to transfer checkpoints.
    It contains the step, the pytree spec of the state dict and the metadata for
    each tensor in the state dict.

    This must be pickleable so that it can be sent over the wire.

    Args:
        step: the step of the checkpoint to verify consistency
        treespec: the pytree spec of the state dict
        paths: the key path of each leaf in the state dict, in pytree flatten
            order
        non_tensor_leaves: one entry per leaf, parallel to ``paths``. Each entry
            is a ``_TensorMeta`` for a plain tensor, a ``_DTensorMeta`` for a
            DTensor, or the leaf value itself for any non-tensor leaf.
    """

    # Step the sender used; the receiver compares it with the step it expects.
    step: int
    # Pytree structure used to unflatten the received leaves.
    treespec: TreeSpec
    # Key path of every leaf, in flatten order.
    paths: list[KeyPath]
    # Per-leaf metadata or raw value, aligned index-for-index with ``paths``.
    non_tensor_leaves: list[Union[object, _TensorMeta, _DTensorMeta]]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@contextmanager
def _timeit(name: str) -> Generator[None, None, None]:
    """
    Log, at INFO level, how many wall-clock seconds the ``with`` body took.

    The message has the form ``"<name> took <seconds>s"``. Nothing is logged
    if the body raises, because the exception propagates out of the ``yield``.
    """
    begin = time.perf_counter()
    yield
    elapsed = time.perf_counter() - begin
    logger.info("%s took %ss", name, elapsed)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _prepare_tensor(tensor: torch.Tensor) -> tuple[torch.Tensor, _TensorMeta]:
    """
    Pair a raw byte view of ``tensor``'s storage with the metadata to rebuild it.

    Returns:
        ``(raw, meta)``. ``raw`` is a uint8 tensor spanning ``tensor``'s whole
        underlying storage (produced by ``_cast_tensor``). ``meta`` records the
        shape, dtype, storage offset, stride and storage size in bytes.
    """
    # Cast first so a non-plain tensor fails the cast check before anything
    # else touches it.
    raw = _cast_tensor(tensor, torch.uint8)
    meta = _TensorMeta(
        shape=tensor.shape,
        dtype=tensor.dtype,
        storage_offset=cast(int, tensor.storage_offset()),
        stride=tensor.stride(),
        nbytes=tensor.untyped_storage().nbytes(),
    )
    return raw, meta
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _prepare_state_dict(
    state_dict: object,
    step: int,
    device: torch.device,
) -> tuple[_StateDictMeta, list[torch.Tensor]]:
    """
    Flatten ``state_dict`` into pickleable metadata plus the tensors to send.

    Every leaf's key path is recorded.

    - A DTensor leaf is described by a ``_DTensorMeta`` built from its local
      shard and spec.
    - A plain tensor leaf is described by a ``_TensorMeta``.
    - For both kinds of tensor leaf, the uint8 storage view is appended to the
      returned tensor list in leaf order.
    - Any other leaf is placed in the metadata unchanged, so it must itself be
      pickleable.

    ``device`` is accepted for API symmetry but is not used here.

    Returns:
        ``(meta, tensors)``, where ``tensors[i]`` is the payload for the i-th
        tensor leaf in flatten order.
    """
    flat: list[tuple[KeyPath, object]]
    flat, spec = tree_flatten_with_path(state_dict)

    key_paths: list[KeyPath] = [key_path for key_path, _ in flat]
    leaf_meta: list[Union[object, _TensorMeta, _DTensorMeta]] = []
    payload: list[torch.Tensor] = []

    for _, leaf in flat:
        # DTensor subclasses torch.Tensor, so it must be checked first.
        if isinstance(leaf, DTensor):
            raw, local_meta = _prepare_tensor(leaf._local_tensor)
            payload.append(raw)
            leaf_meta.append(_DTensorMeta(local=local_meta, spec=leaf._spec))
        elif isinstance(leaf, torch.Tensor):
            raw, plain_meta = _prepare_tensor(leaf)
            payload.append(raw)
            leaf_meta.append(plain_meta)
        else:
            leaf_meta.append(leaf)

    meta = _StateDictMeta(
        step=step,
        treespec=spec,
        paths=key_paths,
        non_tensor_leaves=leaf_meta,
    )
    return meta, payload
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
|
151
|
+
"""
|
|
152
|
+
Casts the underlying storage to a tensor of the given dtype.
|
|
153
|
+
|
|
154
|
+
The returned tensor will be of size ``storage.nbytes``.
|
|
155
|
+
|
|
156
|
+
This works for all datatypes and supports strided/offset tensors with the
|
|
157
|
+
caveat that the cast tensor may be larger than the original tensor due to
|
|
158
|
+
the differences in striding.
|
|
159
|
+
"""
|
|
160
|
+
assert (
|
|
161
|
+
type(tensor) is torch.Tensor
|
|
162
|
+
), f"can only cast standard tensors not {type(tensor)}"
|
|
163
|
+
storage = tensor.untyped_storage()
|
|
164
|
+
ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
|
|
165
|
+
assert ret.untyped_storage() is storage, "storage should be the same"
|
|
166
|
+
return ret
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class PGTransport(CheckpointTransport[T]):
    """
    This is a checkpoint transport that uses the process group to transfer checkpoints.
    This allows for fast recovery of workers by fetching the current weights
    from an existing worker.

    Wire protocol, per destination rank (see ``send_checkpoint`` and
    ``recv_checkpoint``):

    * tag ``1``: a 1-element int64 tensor holding the pickled metadata length
    * tag ``2``: the pickled ``_StateDictMeta`` as a uint8 tensor
    * tag ``3 + i``: the raw storage bytes (uint8) of the i-th tensor leaf, in
      pytree flatten order

    Args:
        pg: the process group to use for communication
        timeout: the timeout for communication. It is stored on the instance;
            the send/recv methods wait using their own ``timeout`` argument.
        device: the device to use for tensors
        state_dict: if specified this function will be called to do an inplace
            receive into the returned state_dict. This is much faster than
            having to allocate new tensors and transferring them to the CPU.
    """

    def __init__(
        self,
        pg: ProcessGroup,
        timeout: timedelta,
        device: torch.device,
        state_dict: Optional[Callable[[], object]] = None,
    ) -> None:
        self._work: list[Work] = []
        self._pg = pg
        self._timeout = timeout
        self._device = device
        self._state_dict = state_dict

    def metadata(self) -> str:
        # Everything the receiver needs is sent over the process group, so
        # there is no out-of-band address to advertise.
        return "<n/a>"

    def disallow_checkpoint(self) -> None:
        # send_checkpoint waits on all of its sends before returning, so there
        # is never outstanding work to wait for here.
        pass

    def send_checkpoint(
        self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        """
        Send ``state_dict`` to every rank in ``dst_ranks`` and wait for completion.

        Args:
            dst_ranks: the ranks to send the checkpoint to
            step: the checkpoint step; the receiver checks it matches
            state_dict: the state dict (any pytree) to send
            timeout: the timeout for each wait on outstanding sends
        """
        with _timeit("preparing state_dict"):
            meta, tensors = _prepare_state_dict(state_dict, step, device=self._device)

        work: list[Work] = []

        with _timeit("send pickle"):
            buf = pickle.dumps(meta)
            len_t = torch.tensor([len(buf)], dtype=torch.int64, device=self._device)
            buf_t = torch.frombuffer(buf, dtype=torch.uint8).to(self._device)
            for dst_rank in dst_ranks:
                work.append(self._pg.send([len_t], dst_rank, tag=1))
                work.append(self._pg.send([buf_t], dst_rank, tag=2))

        with _timeit("send tensors"):
            for i, t in enumerate(tensors):
                t_dev = t.to(self._device)
                for dst_rank in dst_ranks:
                    work.append(self._pg.send([t_dev], dst_rank, tag=3 + i))

                # ``Tensor.to`` returns the tensor itself when no conversion is
                # needed, so a different object means a temporary copy was made
                # on ``self._device``. Wait for the outstanding sends so that
                # copy can be freed before the next one is made, avoiding OOMs.
                if t_dev is not t:
                    for w in work:
                        w.wait(timeout)
                    work = []

            for w in work:
                w.wait(timeout)

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        """
        Receive a checkpoint sent by ``send_checkpoint`` on ``src_rank``.

        If a ``state_dict`` callable was passed to the constructor, a tensor leaf
        whose key path maps to a tensor on the same device type as
        ``self._device`` is received in place into that tensor's storage.
        Every other tensor leaf goes into a freshly allocated buffer on
        ``self._device``; that receive is waited on immediately and the buffer
        is moved to CPU, so device memory does not accumulate.

        Args:
            src_rank: the rank to receive the checkpoint from
            metadata: unused by this transport (``metadata()`` returns "<n/a>")
            step: the expected step; must equal the step used by the sender
            timeout: the timeout for each wait on a receive

        Returns:
            The received state dict, rebuilt with the sender's pytree structure.
        """
        state_dict = self._state_dict() if self._state_dict else {}
        state_dict_leaves, _ = tree_flatten_with_path(state_dict)

        # Candidate in-place destinations, keyed by pytree key path.
        dst_tensors: dict[KeyPath, object] = dict(state_dict_leaves)

        len_t = torch.zeros(1, dtype=torch.int64, device=self._device)
        self._pg.recv([len_t], src_rank, tag=1).wait(timeout)
        length = cast(int, len_t.item())

        assert length > 0, f"invalid metadata length {length=}"

        buf = torch.empty(length, dtype=torch.uint8, device=self._device)
        self._pg.recv([buf], src_rank, tag=2).wait(timeout)

        # NOTE(review): pickle.loads can execute arbitrary code; this assumes
        # every peer in the process group is trusted.
        meta: _StateDictMeta = pickle.loads(buf.cpu().numpy().tobytes())
        assert meta.step == step, f"expected step {step}, received step {meta.step}"

        i: int = 0
        works: list[Work] = []

        def recv(path: KeyPath, v: _TensorMeta) -> torch.Tensor:
            nonlocal i

            inplace = dst_tensors.get(path)
            is_inplace: bool
            if (
                isinstance(inplace, torch.Tensor)
                and inplace.device.type == self._device.type
            ):
                if isinstance(inplace, DTensor):
                    inplace = inplace._local_tensor
                t = _cast_tensor(inplace, torch.uint8)
                assert (
                    t.nbytes == v.nbytes
                ), "inplace tensor storage must be the same size"
                is_inplace = True
            else:
                t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)
                is_inplace = False

            # Tags mirror the sender: tensor leaves are numbered from 3 in
            # flatten order.
            work = self._pg.recv([t], src_rank, tag=3 + i)
            i += 1

            if is_inplace:
                # The destination already exists, so defer the wait and keep
                # posting the remaining receives.
                works.append(work)
            else:
                # Freshly allocated buffer (missing leaf, non-tensor leaf, or a
                # tensor on another device type). Finish the receive and move it
                # to CPU so device memory does not accumulate and OOM.
                work.wait(timeout)
                t = t.cpu()

            return torch.as_strided(
                t.view(v.dtype),
                size=v.shape,
                stride=v.stride,
                storage_offset=v.storage_offset,
            )

        values = []
        for path, v in zip(meta.paths, meta.non_tensor_leaves):
            if isinstance(v, _TensorMeta):
                values.append(recv(path, v))
            elif isinstance(v, _DTensorMeta):
                tensor = recv(path, v.local)
                values.append(DTensor(tensor, v.spec, requires_grad=False))
            else:
                values.append(v)

        for work in works:
            work.wait(timeout)

        return tree_unflatten(values, meta.treespec)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import sys
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
+
from datetime import timedelta
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.distributed as dist
|
|
14
|
+
|
|
15
|
+
from torchft.checkpointing.pg_transport import _timeit, PGTransport
|
|
16
|
+
from torchft.process_group import ProcessGroupBabyNCCL
|
|
17
|
+
|
|
18
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def main(argv: list[str]) -> None:
    """
    Benchmark a ``PGTransport`` checkpoint transfer between two ranks.

    Both ranks run as threads of this process and share a local ``TCPStore``.
    They communicate over ``ProcessGroupBabyNCCL``. Rank 0 sends a synthetic
    state_dict of zero-filled float32 chunks and rank 1 receives it. The
    durations of setup, state_dict creation, send and receive are logged via
    ``_timeit``.

    Each rank calls ``torch.cuda.set_device(rank)``, so this presumably needs
    at least two CUDA devices (TODO confirm).

    Args:
        argv: command line arguments without the program name. Supported flags:

            * ``--inplace``: receive into a preallocated state_dict.
            * ``--device``: device for the state_dict and transport
              (default "cpu").
            * ``--chunk-size``: bytes per state_dict entry (default 3MB).
            * ``--total-size``: byte range stepped through by ``--chunk-size``
              to create entries (default 12GB).
    """
    import argparse

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--inplace", action="store_true")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--chunk-size", type=int, default=3_000_000)  # 3MB
    parser.add_argument("--total-size", type=int, default=12_000_000_000)  # 12GB
    opts = parser.parse_args(argv)

    chunk_bytes: int = opts.chunk_size
    total_bytes: int = opts.total_size
    recv_inplace: bool = opts.inplace
    device_str: str = opts.device

    timeout = timedelta(seconds=10)

    # Master store on an ephemeral port; both rank threads connect to it.
    store = dist.TCPStore(
        "localhost",
        0,
        is_master=True,
        timeout=timeout,
        wait_for_workers=False,
    )
    store_addr = f"localhost:{store.port}"

    def run(rank: int) -> None:
        torch.cuda.set_device(rank)
        device = torch.device(device_str)

        with _timeit("init_pg"):
            pg = ProcessGroupBabyNCCL(timeout=timeout)
            pg.configure(store_addr=store_addr, replica_id="0", rank=rank, world_size=2)

            # A tiny allreduce, presumably so that communicator setup is
            # measured here rather than in the first checkpoint transfer.
            warmup = torch.zeros(10, device=device, dtype=torch.float32)
            pg.allreduce([warmup], dist.ReduceOp.SUM).wait(timeout=timeout)

        with _timeit("create state_dict"):
            # float32 is 4 bytes, so each entry holds chunk_bytes // 4 elements.
            state_dict: dict[str, torch.Tensor] = {
                f"chunk/{offset}": torch.zeros(
                    chunk_bytes // 4, dtype=torch.float32, device=device
                )
                for offset in range(0, total_bytes, chunk_bytes)
            }

        transport = PGTransport(
            pg=pg,
            timeout=timeout,
            device=device,
            state_dict=(lambda: state_dict) if recv_inplace else None,
        )
        metadata = transport.metadata()

        if rank == 0:
            with _timeit("send_checkpoint"):
                transport.send_checkpoint(
                    dst_ranks=[1],
                    step=1,
                    state_dict=state_dict,
                    timeout=timedelta(seconds=60),
                )
        elif rank == 1:
            with _timeit("recv_checkpoint"):
                transport.recv_checkpoint(
                    src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=60)
                )

    with ThreadPoolExecutor(max_workers=2) as executor:
        # Materialise the results so an exception raised by either rank
        # propagates out of main.
        list(executor.map(run, range(2)))


if __name__ == "__main__":
    main(sys.argv[1:])
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import sys
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from unittest import skipIf, skipUnless, TestCase
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from torch.distributed import TCPStore
|
|
13
|
+
|
|
14
|
+
from torchft.checkpointing.pg_transport import PGTransport
|
|
15
|
+
from torchft.checkpointing.transport import CheckpointTransport
|
|
16
|
+
from torchft.checkpointing.transport_test import (
|
|
17
|
+
make_state_dict,
|
|
18
|
+
run_multi_recovery_test,
|
|
19
|
+
)
|
|
20
|
+
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PGTransportTest(TestCase):
    """Multi-rank recovery tests for ``PGTransport`` over Gloo and Baby NCCL."""

    @staticmethod
    def _make_store() -> TCPStore:
        # Master TCPStore on an ephemeral port; each rank's process group is
        # configured with an address built from this store's port.
        return TCPStore(
            host_name="localhost", port=0, is_master=True, wait_for_workers=False
        )

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipIf(sys.platform == "darwin", "not passing on mac")
    def test_pg_transport_gloo(self) -> None:
        store = self._make_store()
        device = torch.device("cpu")

        def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
            gloo = ProcessGroupGloo()
            gloo.configure(
                store_addr=f"localhost:{store.port}/prefix",
                replica_id="0",
                rank=rank,
                world_size=world_size,
            )
            return PGTransport[dict[str, object]](
                gloo, timeout=timedelta(seconds=10), device=device
            )

        run_multi_recovery_test(self, init, device=device)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
    def test_pg_transport_baby_nccl(self) -> None:
        store = self._make_store()
        device = torch.device("cuda")
        timeout = timedelta(seconds=10)

        def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
            # Each rank drives its own GPU.
            torch.cuda.set_device(rank)

            nccl = ProcessGroupBabyNCCL(timeout=timeout)
            nccl.configure(
                store_addr=f"localhost:{store.port}/prefix",
                replica_id="0",
                rank=rank,
                world_size=world_size,
            )
            return PGTransport[dict[str, object]](nccl, timeout=timeout, device=device)

        run_multi_recovery_test(self, init, device=device)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipUnless(torch.cuda.device_count() >= 3, "need three CUDA devices")
    def test_pg_transport_baby_nccl_inplace(self) -> None:
        store = self._make_store()
        device = torch.device("cuda")
        timeout = timedelta(seconds=10)

        def state_dict() -> dict[str, object]:
            # Destination tensors for the in-place receive path.
            return make_state_dict(device)

        def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
            torch.cuda.set_device(rank)

            nccl = ProcessGroupBabyNCCL(timeout=timeout)
            nccl.configure(
                store_addr=f"localhost:{store.port}/prefix",
                replica_id="0",
                rank=rank,
                world_size=world_size,
            )
            return PGTransport[dict[str, object]](
                nccl,
                timeout=timeout,
                device=device,
                state_dict=state_dict,
            )

        run_multi_recovery_test(self, init, device=device)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import pytest
|
|
8
|
+
|
|
9
|
+
from torchft.checkpointing._rwlock import RWLock
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def test_w_locked() -> None:
    """``w_locked`` is true inside the ``w_lock`` context and false after it exits."""
    rw = RWLock()

    with rw.w_lock():
        assert rw.w_locked()

    assert not rw.w_locked()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_w_lock_timeout() -> None:
    """Writers time out while any reader is held, then succeed once all are released."""
    rw = RWLock(timeout=0.01)

    # Two outstanding readers.
    for _ in range(2):
        rw.r_acquire()

    # Both the raw acquire and the context manager must time out.
    with pytest.raises(TimeoutError):
        rw.w_acquire()
    with pytest.raises(TimeoutError), rw.w_lock():
        pass

    # Releasing one reader is not enough; the other still blocks the writer.
    rw.r_release()
    with pytest.raises(TimeoutError):
        rw.w_acquire()

    # With no readers left, the write lock can be taken via the context
    # manager and then directly.
    rw.r_release()
    with rw.w_lock():
        pass
    rw.w_acquire()
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def test_r_lock_timeout() -> None:
    """Readers time out while a writer is held, then succeed once it is released."""
    rw = RWLock(timeout=0.01)

    rw.w_acquire()

    # Both the raw acquire and the context manager must time out.
    with pytest.raises(TimeoutError):
        rw.r_acquire()
    with pytest.raises(TimeoutError), rw.r_lock():
        pass

    # After the writer releases, reads succeed via the context manager and
    # then directly.
    rw.w_release()
    with rw.r_lock():
        pass
    rw.r_acquire()
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from typing import Generic, List, TypeVar
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T")


class CheckpointTransport(Generic[T], ABC):
    """
    Interface for moving a checkpoint (a state dict of type ``T``) from a rank
    that has it to ranks that are behind.

    Subclasses must implement ``metadata``, ``send_checkpoint`` and
    ``recv_checkpoint``. ``disallow_checkpoint`` and ``shutdown`` default to
    doing nothing and returning ``None``.
    """

    @abstractmethod
    def metadata(self) -> str:
        """
        Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
        """

    @abstractmethod
    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        """
        Sends the checkpoint, only called when there is a rank that is behind.

        Implementations may return before the data has actually been sent.

        Args:
            dst_ranks: the ranks to send to
            step: the step number to send
            state_dict: the state dict to send
            timeout: the timeout to wait for the checkpoint to be sent
        """

    def disallow_checkpoint(self) -> None:
        """
        Called after send_checkpoint to wait for the checkpoint to be sent.

        Once this returns, the state_dict may be mutated, so no further data
        should be sent.
        """

    @abstractmethod
    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        """
        Receives the checkpoint from the given rank.

        Args:
            src_rank: the rank to receive the checkpoint from
            metadata: the metadata returned by the remote CheckpointTransport
            step: the step number to receive
            timeout: the timeout to wait for the checkpoint
        """

    def shutdown(self, wait: bool = True) -> None:
        """
        Called to shutdown the checkpoint transport.

        Args:
            wait: whether to wait for the transport to shutdown
        """
|