torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import socket
|
|
9
|
+
import threading
|
|
10
|
+
import time
|
|
11
|
+
import urllib.request
|
|
12
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
+
from contextlib import contextmanager, nullcontext
|
|
14
|
+
from datetime import timedelta
|
|
15
|
+
from http.server import BaseHTTPRequestHandler
|
|
16
|
+
from typing import cast, Generator, List, Optional, TypeVar
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
|
|
20
|
+
|
|
21
|
+
from torchft.checkpointing._rwlock import RWLock
|
|
22
|
+
from torchft.checkpointing._serialization import _streaming_load, _streaming_save
|
|
23
|
+
from torchft.checkpointing.transport import CheckpointTransport
|
|
24
|
+
from torchft.http import _IPv6HTTPServer
|
|
25
|
+
|
|
26
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
T = TypeVar("T")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@contextmanager
def _time(desc: str) -> Generator[None, None, None]:
    """
    Context manager that logs how long the wrapped block took.

    The elapsed wall-clock time (measured with ``time.perf_counter``) is
    logged at INFO level when the block exits. The log line is emitted even
    if the block raises, so slow failures (e.g. timeouts) are still visible;
    the exception itself is propagated unchanged.

    Args:
        desc: human readable description used as the prefix of the log line
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        logger.info(f"{desc} took {end - start}s")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class HTTPTransport(CheckpointTransport[T]):
    """
    This is an HTTP server that can be used to transfer checkpoints
    between workers.

    This allows for fast recovery of workers by fetching the current weights
    from an existing worker.

    The server is started on a background daemon thread when the transport is
    constructed. It starts in the "disallowed" state (the checkpoint write
    lock is held) and only serves data once ``send_checkpoint`` or
    ``allow_checkpoint`` has been called.

    Served URLs have the form ``/checkpoint/<step>/<key>`` where ``<key>`` is
    ``full`` (whole state dict), ``metadata`` (the pytree spec) or an integer
    chunk index.

    Args:
        timeout: the timeout for HTTP requests
        num_chunks: the number of chunks to split the checkpoint into (0 for no chunking)
    """

    def __init__(self, timeout: timedelta, num_chunks: int) -> None:
        self._checkpoint_lock = RWLock(timeout=timeout.total_seconds())
        self._disallowed = False
        self._step = -1
        self._timeout = timeout
        self._state_dict: Optional[T] = None
        self._num_chunks = num_chunks
        # Side stream used for the device -> host copy in send_checkpoint.
        self._stream: Optional[torch.cuda.Stream] = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

        # staged checkpoint information
        self._spec: Optional[TreeSpec] = None
        self._chunks: Optional[List[List[object]]] = None

        # We don't allow checkpoints until the first send_checkpoint to avoid
        # serving the default step=-1 invalid checkpoint.
        self.disallow_checkpoint()

        # The handler's own ``self`` shadows the transport inside its
        # methods, so keep an explicit reference for the closure.
        ckpt_server = self

        class RequestHandler(BaseHTTPRequestHandler):
            # set request socket timeout to avoid hanging forever
            timeout = self._timeout.total_seconds()

            def do_GET(self):
                try:
                    # validate socket timeout is actually set
                    assert self.connection.gettimeout() == self.timeout

                    with ckpt_server._checkpoint_lock.r_lock():
                        parts = self.path.split("/")
                        # Malformed URLs are a client error; reject them
                        # explicitly instead of relying on an assert (which
                        # is stripped under -O and would surface as a 500).
                        if len(parts) != 4 or parts[1] != "checkpoint":
                            self.send_error(
                                400,
                                f"invalid url format, expected /checkpoint/step/key but got {self.path}",
                            )
                            return

                        step = int(parts[2])
                        if step != ckpt_server._step:
                            self.send_error(
                                400,
                                f"invalid checkpoint requested, serving {ckpt_server._step} but got {step=}",
                            )
                            return

                        # Resolve the payload before sending any headers so
                        # that a bad chunk key still yields a clean 500.
                        key = parts[3]
                        if key == "full":
                            payload = ckpt_server._state_dict
                        elif key == "metadata":
                            payload = ckpt_server._spec
                        else:
                            payload = ckpt_server._chunks[int(key)]

                        self.send_response(200)
                        self.send_header("Content-type", "application/octet-stream")
                        self.end_headers()

                        _streaming_save(payload, self.wfile)
                except Exception as e:
                    logger.exception(
                        f"Exception in checkpoint server when handling {self.path=}: {e}",
                    )
                    self.send_error(500, str(e))

        # Port 0 lets the OS pick a free port; see address() for the result.
        server_address = ("", 0)
        self._server = _IPv6HTTPServer(server_address, RequestHandler)
        logger.info(f"Started CheckpointServer on {self.address()}...")

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
        )
        self._thread.start()

    @classmethod
    def _load_from_address(cls, address: str, timeout: timedelta) -> object:
        """
        Loads a checkpoint from the given address.

        Args:
            address: the HTTP address to load the checkpoint from
            timeout: the timeout passed to the HTTP request

        Returns:
            the deserialized object served at ``address``
        """
        msg = f"fetching checkpoint from {address}"
        logger.info(msg)

        with (
            _time(msg),
            urllib.request.urlopen(address, timeout=timeout.total_seconds()) as f,
        ):
            # We have to set weights_only to False as there are some non-tensor
            # states like lr_scheduler.
            # pyre-fixme[16]: needs torch>=2.7
            return cast(T, _streaming_load(f, weights_only=False))

    def address(self) -> str:
        """
        Returns the HTTP address to fetch a checkpoint from this server. Step must be appended to the end of the address.

        Format: http://host:port/checkpoint/1234

        Returns:
            an HTTP address
        """
        port = self._server.socket.getsockname()[1]
        return f"http://{socket.gethostname()}:{port}/checkpoint/"

    def _serve(self) -> None:
        """Runs the HTTP server loop; target of the background thread."""
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in checkpoint server")

    def disallow_checkpoint(self) -> None:
        """
        Disallows serving the checkpoint.

        All requests will block until allow_checkpoint is called.
        """
        # The flag makes repeated calls idempotent so the write lock is only
        # acquired once.
        if not self._disallowed:
            self._disallowed = True
            self._checkpoint_lock.w_acquire()

    def allow_checkpoint(self, step: int) -> None:
        """
        Allows serving the checkpoint with the specified step number.

        Args:
            step: the step number to serve
        """
        self._step = step

        if self._disallowed:
            self._disallowed = False
            self._checkpoint_lock.w_release()

    def shutdown(self, wait: bool = True) -> None:
        """
        Shutdown the server.

        Args:
            wait: if True, block until the server loop has exited and the
                serving thread has been joined; if False, only request the
                shutdown and return immediately.
        """
        if wait:
            self._server.shutdown()
            self._thread.join()
        else:
            # hack for nonblocking shutdown of socketserver threads
            # Writing ``self._server.__shutdown_request`` here would be
            # name-mangled to ``_HTTPTransport__shutdown_request`` and never
            # be seen by the server loop, so the attribute is set under the
            # name socketserver.BaseServer actually reads.
            # NOTE(review): assumes _IPv6HTTPServer derives from
            # socketserver.BaseServer - confirm in torchft.http.
            setattr(self._server, "_BaseServer__shutdown_request", True)

    def metadata(self) -> str:
        """Returns the base address peers use to fetch from this server."""
        return self.address()

    def send_checkpoint(
        self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
    ) -> None:
        """
        Stages ``state_dict`` on the CPU and starts serving it for ``step``.

        Args:
            dst_ranks: unused by this transport (peers pull over HTTP)
            step: the step number the staged checkpoint belongs to
            state_dict: the state to serve
            timeout: unused by this method
        """
        values, spec = tree_flatten(state_dict)

        with (
            torch.cuda.stream(self._stream)
            if self._stream is not None
            else nullcontext()
        ):
            with _time("transferring state_dict to CPU"):
                values = _to_cpu(values, pin_memory=False)
                if self._stream is not None:
                    self._stream.synchronize()

        # Unflatten so non-chunked transfer uses CPU tensors
        self._state_dict = tree_unflatten(values, spec)

        # Save spec for chunked
        self._spec = spec
        self._chunks = _split_chunks(values, self._num_chunks)

        self.allow_checkpoint(step)

    def recv_checkpoint(
        self, src_rank: int, metadata: str, step: int, timeout: timedelta
    ) -> T:
        """
        Fetches the checkpoint for ``step`` from the server at ``metadata``.

        Args:
            src_rank: unused by this transport
            metadata: the base address returned by the sender's ``metadata()``
            step: the step number to request
            timeout: the timeout for each HTTP request

        Returns:
            the reconstructed state dict
        """
        base_url = f"{metadata}{step}"
        if self._num_chunks == 0:
            return cast(T, self._load_from_address(f"{base_url}/full", timeout))
        else:
            # First URL is the tree spec, the rest are the value chunks.
            urls = [f"{base_url}/metadata"] + [
                f"{base_url}/{i}" for i in range(self._num_chunks)
            ]

            # One worker per URL so all parts are downloaded concurrently.
            with ThreadPoolExecutor(max_workers=len(urls)) as executor:
                futures = [
                    executor.submit(self._load_from_address, url, timeout)
                    for url in urls
                ]

                spec, *chunks = [future.result() for future in futures]
                spec = cast(TreeSpec, spec)
                chunks = cast(List[List[object]], chunks)

            values = _merge_chunks(chunks, self._num_chunks)

            return tree_unflatten(values, spec)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
|
|
271
|
+
out = []
|
|
272
|
+
for v in values:
|
|
273
|
+
if isinstance(v, torch.Tensor):
|
|
274
|
+
if v.device.type == "cuda":
|
|
275
|
+
if pin_memory:
|
|
276
|
+
cpu = torch.empty(*tuple(v.size()), dtype=v.dtype, pin_memory=True)
|
|
277
|
+
cpu.copy_(v, non_blocking=True)
|
|
278
|
+
out.append(cpu)
|
|
279
|
+
else:
|
|
280
|
+
out.append(v.cpu())
|
|
281
|
+
else:
|
|
282
|
+
out.append(v)
|
|
283
|
+
else:
|
|
284
|
+
out.append(v)
|
|
285
|
+
return out
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def _split_chunks(values: List[T], num_chunks: int) -> List[List[T]]:
|
|
289
|
+
return [values[i::num_chunks] for i in range(num_chunks)]
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _merge_chunks(chunks: List[List[T]], num_chunks: int) -> List[T]:
|
|
293
|
+
max_len = max(len(lst) for lst in chunks)
|
|
294
|
+
output_list = []
|
|
295
|
+
for i in range(max_len):
|
|
296
|
+
for lst in chunks:
|
|
297
|
+
if i < len(lst):
|
|
298
|
+
output_list.append(lst[i])
|
|
299
|
+
return output_list
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
import sys
|
|
9
|
+
from datetime import timedelta
|
|
10
|
+
from typing import List
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from torchft.checkpointing.http_transport import _time, HTTPTransport
|
|
15
|
+
|
|
16
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def main(argv: List[str]) -> None:
    """
    Benchmarks an HTTPTransport round trip within a single process.

    Builds a state dict of float32 tensors, stages it with
    ``send_checkpoint`` and times fetching it back with ``recv_checkpoint``
    from the same transport.

    Args:
        argv: command line arguments (without the program name); supports
            ``--num-chunks``, ``--device``, ``--chunk-size`` and
            ``--total-size``
    """
    import argparse

    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--num-chunks", type=int, default=0)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--chunk-size", type=int, default=3_000_000)  # 3MB
    parser.add_argument("--total-size", type=int, default=12_000_000_000)  # 12GB
    args = parser.parse_args(argv)

    device = torch.device(args.device)
    num_chunks: int = args.num_chunks
    CHUNK_SIZE = args.chunk_size
    TOTAL_SIZE = args.total_size

    transport = HTTPTransport(timedelta(seconds=60), num_chunks=num_chunks)
    # The transport starts a server thread in its constructor; make sure it
    # is stopped again even if the benchmark fails part-way through.
    try:
        metadata = transport.metadata()

        logger.info(f"creating state_dict... {CHUNK_SIZE=} {TOTAL_SIZE=}")

        with _time("create state_dict"):
            state_dict = {}
            for i in range(0, TOTAL_SIZE, CHUNK_SIZE):
                # float32 is 4 bytes, so CHUNK_SIZE bytes is CHUNK_SIZE // 4
                # elements.
                state_dict[f"chunk/{i}"] = torch.zeros(
                    CHUNK_SIZE // 4, dtype=torch.float32, device=device
                )

        logger.info(
            f"fetching from {metadata=} {device=} {num_chunks=} {len(state_dict)=}"
        )

        transport.send_checkpoint(
            dst_ranks=[0], step=1, state_dict=state_dict, timeout=timedelta(seconds=60)
        )

        with _time("fetching checkpoint"):
            transport.recv_checkpoint(
                src_rank=1, metadata=metadata, step=1, timeout=timedelta(seconds=60)
            )
    finally:
        transport.shutdown()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
if __name__ == "__main__":
|
|
61
|
+
main(sys.argv[1:])
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import urllib.error
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from typing import Dict
|
|
10
|
+
from unittest import skipUnless, TestCase
|
|
11
|
+
from unittest.mock import MagicMock
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
from parameterized import parameterized
|
|
15
|
+
|
|
16
|
+
from torchft.checkpointing.http_transport import HTTPTransport
|
|
17
|
+
from torchft.checkpointing.http_transport_bench import main as bench_main
|
|
18
|
+
from torchft.checkpointing.transport import CheckpointTransport
|
|
19
|
+
from torchft.checkpointing.transport_test import (
|
|
20
|
+
assertStateDictEqual,
|
|
21
|
+
run_multi_recovery_test,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TestHTTPTransport(TestCase):
    """Tests for HTTPTransport serving, locking and the benchmark entry point."""

    @parameterized.expand(
        [
            ("no chunks", 0),
            ("chunked", 3),
        ]
    )
    def test_checkpoint_server(self, name: str, num_chunks: int) -> None:
        """Round trips a state dict and checks the timeout and step-mismatch errors."""
        expected: Dict[str, object] = {
            "state": "dict",
            "tensor": torch.rand(5, 2),
            "cuda": torch.rand(
                2, 3, device="cuda" if torch.cuda.is_available() else "cpu"
            ),
        }
        server = HTTPTransport(
            timeout=timedelta(seconds=10),
            num_chunks=num_chunks,
        )
        # Stop the server thread even if an assertion below fails.
        self.addCleanup(server.shutdown)

        server.send_checkpoint(
            dst_ranks=[],
            step=1234,
            state_dict=expected,
            timeout=timedelta(seconds=10),
        )

        metadata = server.metadata()

        out = server.recv_checkpoint(
            src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
        )
        assertStateDictEqual(self, out, expected)

        # test timeout
        with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"):
            server.recv_checkpoint(
                src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=0.0)
            )

        # test mismatch case
        server.send_checkpoint(
            dst_ranks=[],
            step=2345,
            state_dict=expected,
            timeout=timedelta(seconds=10),
        )

        with self.assertRaisesRegex(
            urllib.error.HTTPError, r"Error 400.*serving 2345 but got step=1234"
        ):
            server.recv_checkpoint(
                src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10)
            )

    def test_checkpoint_server_locking(self) -> None:
        """Checks the allow/disallow state machine and its write lock."""
        server = HTTPTransport(
            timeout=timedelta(seconds=10),
            num_chunks=0,
        )
        # Stop the server thread even if an assertion below fails.
        self.addCleanup(server.shutdown)

        # server should start up in a disallowed state this will block incoming
        # requests until allow_checkpoint is called
        self.assertTrue(server._checkpoint_lock.w_locked())
        self.assertTrue(server._disallowed)
        self.assertEqual(server._step, -1)

        # allow requests
        server.allow_checkpoint(1)

        self.assertFalse(server._checkpoint_lock.w_locked())
        self.assertFalse(server._disallowed)
        self.assertEqual(server._step, 1)

        # duplicate allow/disallow is fine
        server.allow_checkpoint(2)
        self.assertEqual(server._step, 2)

        server.disallow_checkpoint()
        server.disallow_checkpoint()
        self.assertTrue(server._checkpoint_lock.w_locked())
        self.assertTrue(server._disallowed)

    def test_multi_http_transport_cpu(self) -> None:
        """Runs the shared multi-transport recovery test on CPU."""
        device = torch.device("cpu")

        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
            return HTTPTransport(
                timeout=timedelta(seconds=10),
                num_chunks=0,
            )

        run_multi_recovery_test(self, init, device=device)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    @skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_multi_http_transport_cuda(self) -> None:
        """Runs the shared multi-transport recovery test on CUDA."""
        device = torch.device("cuda")

        def init(rank: int, world_size: int) -> CheckpointTransport[Dict[str, object]]:
            return HTTPTransport(
                timeout=timedelta(seconds=10),
                num_chunks=0,
            )

        run_multi_recovery_test(self, init, device=device)

    def test_benchmark(self) -> None:
        """Smoke tests the benchmark entry point with a tiny payload."""
        bench_main(
            [
                "--chunk-size=10",
                "--num-chunks=0",
                "--total-size=100",
                "--device=cpu",
            ]
        )
|