checkpoint-engine 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +3 -3
- checkpoint_engine/ps.py +357 -177
- checkpoint_engine/worker.py +12 -7
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.3.dist-info}/METADATA +11 -10
- checkpoint_engine-0.1.3.dist-info/RECORD +9 -0
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.3.dist-info}/licenses/LICENCE +1 -1
- checkpoint_engine-0.1.1.dist-info/RECORD +0 -9
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.3.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.3.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.1.1'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 1, 1)
|
|
31
|
+
__version__ = version = '0.1.3'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 1, 3)
|
|
33
33
|
|
|
34
|
-
__commit_id__ = commit_id =
|
|
34
|
+
__commit_id__ = commit_id = None
|
checkpoint_engine/ps.py
CHANGED
|
@@ -1,30 +1,28 @@
|
|
|
1
|
-
from __future__ import annotations
|
|
2
|
-
|
|
3
1
|
import argparse
|
|
4
2
|
import concurrent.futures
|
|
3
|
+
import ctypes
|
|
5
4
|
import os
|
|
6
5
|
import pickle
|
|
7
6
|
import random
|
|
8
7
|
import socket
|
|
9
|
-
import subprocess
|
|
10
8
|
import threading
|
|
11
9
|
import time
|
|
12
|
-
import uuid
|
|
13
10
|
from collections import defaultdict
|
|
11
|
+
from collections.abc import Callable
|
|
14
12
|
from datetime import timedelta
|
|
15
|
-
from functools import
|
|
16
|
-
from typing import
|
|
13
|
+
from functools import lru_cache
|
|
14
|
+
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
|
|
17
15
|
|
|
16
|
+
import httpx
|
|
18
17
|
import numpy as np
|
|
19
|
-
import requests
|
|
20
18
|
import torch
|
|
21
19
|
import torch.distributed as dist
|
|
22
20
|
import zmq
|
|
23
21
|
from loguru import logger
|
|
24
|
-
from pydantic import BaseModel,
|
|
22
|
+
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
|
|
25
23
|
from safetensors.torch import safe_open
|
|
26
24
|
from torch.multiprocessing.reductions import reduce_tensor
|
|
27
|
-
|
|
25
|
+
|
|
28
26
|
|
|
29
27
|
if TYPE_CHECKING:
|
|
30
28
|
from typing_extensions import TypedDict
|
|
@@ -37,16 +35,59 @@ if TYPE_CHECKING:
|
|
|
37
35
|
tp_concat_dim: int
|
|
38
36
|
|
|
39
37
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
38
|
+
def _dt_validate(value: Any) -> torch.dtype:
|
|
39
|
+
if isinstance(value, str):
|
|
40
|
+
if not value.startswith("torch."):
|
|
41
|
+
raise ValueError(f"dtype {value} should start with torch.")
|
|
42
|
+
try:
|
|
43
|
+
value = getattr(torch, value.split(".")[1])
|
|
44
|
+
except AttributeError as e:
|
|
45
|
+
raise ValueError(f"unknown dtype: {value}") from e
|
|
46
|
+
if not isinstance(value, torch.dtype):
|
|
47
|
+
raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
|
|
48
|
+
return value
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
_TorchDtype = Annotated[
|
|
52
|
+
torch.dtype,
|
|
53
|
+
PlainValidator(_dt_validate),
|
|
54
|
+
PlainSerializer(lambda x: str(x), return_type=str),
|
|
55
|
+
WithJsonSchema({"type": "string"}, mode="serialization"),
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _size_validate(value: Any) -> torch.Size:
|
|
60
|
+
if isinstance(value, list | tuple):
|
|
61
|
+
return torch.Size(value)
|
|
62
|
+
if not isinstance(value, torch.Size):
|
|
63
|
+
raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
|
|
64
|
+
return value
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
_TorchSize = Annotated[
|
|
68
|
+
torch.Size,
|
|
69
|
+
PlainValidator(_size_validate),
|
|
70
|
+
PlainSerializer(lambda x: tuple(x), return_type=tuple),
|
|
71
|
+
WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
|
|
72
|
+
]
|
|
73
|
+
|
|
46
74
|
|
|
75
|
+
def _tensor_validate(value: Any) -> torch.Tensor:
|
|
76
|
+
if isinstance(value, torch.Tensor):
|
|
77
|
+
return value
|
|
78
|
+
raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
_TorchTensor = Annotated[
|
|
82
|
+
torch.Tensor,
|
|
83
|
+
PlainValidator(_tensor_validate),
|
|
84
|
+
]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class ParameterMeta(BaseModel):
|
|
47
88
|
name: str
|
|
48
|
-
dtype:
|
|
49
|
-
shape:
|
|
89
|
+
dtype: _TorchDtype
|
|
90
|
+
shape: _TorchSize
|
|
50
91
|
|
|
51
92
|
|
|
52
93
|
class BucketRange(NamedTuple):
|
|
@@ -68,9 +109,7 @@ class MemoryBufferMetas(BaseModel):
|
|
|
68
109
|
|
|
69
110
|
|
|
70
111
|
class MemoryBuffer(BaseModel):
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
buffer: torch.Tensor
|
|
112
|
+
buffer: _TorchTensor
|
|
74
113
|
size: int
|
|
75
114
|
metas: list[ParameterMeta]
|
|
76
115
|
|
|
@@ -82,7 +121,7 @@ class MemoryBufferMetaList(BaseModel):
|
|
|
82
121
|
|
|
83
122
|
class DataToGather(MemoryBufferMetaList):
|
|
84
123
|
host_ip: str
|
|
85
|
-
|
|
124
|
+
device_uuid: str
|
|
86
125
|
|
|
87
126
|
|
|
88
127
|
# 256 bytes alignment when flatten torch tensors to uint8 buffer
|
|
@@ -93,7 +132,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
|
|
|
93
132
|
return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
|
|
94
133
|
|
|
95
134
|
|
|
96
|
-
def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
|
|
135
|
+
def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
|
|
97
136
|
ret = []
|
|
98
137
|
for meta in metas:
|
|
99
138
|
size = _align_size(meta.dtype, meta.shape)
|
|
@@ -109,11 +148,11 @@ def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
|
|
|
109
148
|
return ret
|
|
110
149
|
|
|
111
150
|
|
|
112
|
-
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
|
|
113
|
-
def _safetensors_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
|
|
151
|
+
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
|
|
152
|
+
def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
|
|
114
153
|
ret = {}
|
|
115
154
|
with safe_open(fn, framework="pt") as f:
|
|
116
|
-
for name in f.keys():
|
|
155
|
+
for name in f.keys(): # noqa: SIM118
|
|
117
156
|
weight = f.get_tensor(name)
|
|
118
157
|
meta = {
|
|
119
158
|
"key": name,
|
|
@@ -126,10 +165,10 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
|
|
|
126
165
|
return ret
|
|
127
166
|
|
|
128
167
|
# deprecated, will be removed in the future
|
|
129
|
-
def _fast_np_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
|
|
168
|
+
def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
|
|
130
169
|
"""load *.np file and return memmap and related tensor meta"""
|
|
131
170
|
|
|
132
|
-
def parse_npy_header(fin):
|
|
171
|
+
def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
|
|
133
172
|
start = fin.tell()
|
|
134
173
|
major, minor = np.lib.format.read_magic(fin)
|
|
135
174
|
if major == 1 and minor == 0:
|
|
@@ -137,7 +176,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
|
|
|
137
176
|
elif major == 2 and minor == 0:
|
|
138
177
|
read_header_fn = np.lib.format.read_array_header_2_0
|
|
139
178
|
else:
|
|
140
|
-
raise ValueError(
|
|
179
|
+
raise ValueError(
|
|
180
|
+
f"unknown version {major}.{minor} when parsing npy header from {fn}"
|
|
181
|
+
)
|
|
141
182
|
shape, is_fortran, dtype = read_header_fn(fin)
|
|
142
183
|
return {
|
|
143
184
|
"shape": shape,
|
|
@@ -193,7 +234,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
|
|
|
193
234
|
return tp_rank, ret
|
|
194
235
|
|
|
195
236
|
|
|
196
|
-
def _concat_tp_weights(
|
|
237
|
+
def _concat_tp_weights(
|
|
238
|
+
tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
|
|
239
|
+
) -> torch.Tensor:
|
|
197
240
|
"""Concat tp weights with meta info.
|
|
198
241
|
If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights.
|
|
199
242
|
Else we will cat weights in concat_dim.
|
|
@@ -206,39 +249,54 @@ def _concat_tp_weights(tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_si
|
|
|
206
249
|
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
|
|
207
250
|
|
|
208
251
|
|
|
209
|
-
def _get_physical_gpu_id(
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
for line in lines:
|
|
215
|
-
if f"GPU {rank}" in line:
|
|
216
|
-
uuid = line.split("UUID: ")[1].strip(")")
|
|
217
|
-
return uuid
|
|
218
|
-
raise ValueError(f"not found gpu{rank} uuid")
|
|
252
|
+
def _get_physical_gpu_id(device_index: int | None = None) -> str:
|
|
253
|
+
try:
|
|
254
|
+
return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
|
|
255
|
+
except AssertionError as e:
|
|
256
|
+
raise ValueError(f"fail to get physical gpu id {device_index}") from e
|
|
219
257
|
|
|
220
258
|
|
|
221
259
|
@lru_cache(maxsize=1)
|
|
222
|
-
def _get_ip():
|
|
260
|
+
def _get_ip() -> str:
|
|
223
261
|
try:
|
|
224
262
|
# try to get ip from network interface
|
|
225
263
|
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
|
226
264
|
s.connect(("8.8.8.8", 80))
|
|
227
265
|
return s.getsockname()[0]
|
|
228
|
-
except:
|
|
266
|
+
except Exception as e: # noqa: BLE001
|
|
229
267
|
# fallback to get ip from hostname
|
|
230
|
-
logger.warning(
|
|
268
|
+
logger.warning(
|
|
269
|
+
f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
|
|
270
|
+
)
|
|
231
271
|
return socket.gethostbyname(socket.gethostname())
|
|
232
272
|
|
|
233
273
|
|
|
274
|
+
def _ibv_get_device_list() -> list[str]:
|
|
275
|
+
lib = ctypes.CDLL("libibverbs.so.1")
|
|
276
|
+
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
|
|
277
|
+
lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **
|
|
278
|
+
|
|
279
|
+
lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
|
|
280
|
+
lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
|
|
281
|
+
lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *
|
|
282
|
+
|
|
283
|
+
num = ctypes.c_int()
|
|
284
|
+
dev_array = lib.ibv_get_device_list(ctypes.byref(num))
|
|
285
|
+
if not dev_array or num.value <= 0:
|
|
286
|
+
return []
|
|
287
|
+
|
|
288
|
+
devices = []
|
|
289
|
+
for i in range(num.value):
|
|
290
|
+
dev_ptr = dev_array[i] # struct ibv_device *
|
|
291
|
+
name = lib.ibv_get_device_name(dev_ptr) # const char *
|
|
292
|
+
devices.append(name.decode())
|
|
293
|
+
lib.ibv_free_device_list(dev_array)
|
|
294
|
+
return devices
|
|
295
|
+
|
|
296
|
+
|
|
234
297
|
def _get_rdma_devices() -> list[str]:
|
|
235
298
|
"""
|
|
236
|
-
use
|
|
237
|
-
```bash
|
|
238
|
-
pushd /sys/class/infiniband/ > /dev/null;
|
|
239
|
-
for i in mlx5_*; do cat "$i"/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo "$i" ; done;
|
|
240
|
-
popd > /dev/null;
|
|
241
|
-
```
|
|
299
|
+
use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
|
|
242
300
|
"""
|
|
243
301
|
devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
|
|
244
302
|
if devices_str:
|
|
@@ -246,41 +304,27 @@ def _get_rdma_devices() -> list[str]:
|
|
|
246
304
|
# if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
|
|
247
305
|
hca = os.getenv("NCCL_IB_HCA", None)
|
|
248
306
|
if hca:
|
|
249
|
-
|
|
250
|
-
if len(
|
|
307
|
+
hca_list = hca.split(",")
|
|
308
|
+
if len(hca_list) > 1:
|
|
251
309
|
# if NCCL_IB_HCA has multiple values, just return
|
|
252
|
-
return
|
|
310
|
+
return hca_list
|
|
253
311
|
else:
|
|
254
|
-
hca =
|
|
255
|
-
|
|
256
|
-
port_path = "ports/1/gid_attrs/types"
|
|
257
|
-
devices = []
|
|
258
|
-
for device in sorted(os.listdir(basepath)):
|
|
259
|
-
if hca is not None and hca not in device:
|
|
260
|
-
continue
|
|
261
|
-
path = os.path.join(basepath, device, port_path)
|
|
262
|
-
if not os.path.exists(path) or not os.path.isdir(path):
|
|
263
|
-
continue
|
|
264
|
-
for port in os.listdir(path):
|
|
265
|
-
try:
|
|
266
|
-
content = open(os.path.join(path, port)).read()
|
|
267
|
-
if "v" in content:
|
|
268
|
-
print(f"found rdma device {device} in port {port}: {content.strip()}")
|
|
269
|
-
devices.append(device)
|
|
270
|
-
break
|
|
271
|
-
except Exception:
|
|
272
|
-
pass
|
|
273
|
-
return devices
|
|
312
|
+
hca = hca_list[0]
|
|
313
|
+
return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
|
|
274
314
|
|
|
275
315
|
|
|
276
|
-
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]):
|
|
316
|
+
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
|
|
277
317
|
"""
|
|
278
318
|
implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
|
|
279
319
|
"""
|
|
280
320
|
if not devices:
|
|
281
321
|
raise RuntimeError("no rdma devices found")
|
|
282
|
-
assert len(devices) <= gpu_count,
|
|
283
|
-
|
|
322
|
+
assert len(devices) <= gpu_count, (
|
|
323
|
+
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
|
|
324
|
+
)
|
|
325
|
+
assert gpu_count % len(devices) == 0, (
|
|
326
|
+
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
|
|
327
|
+
)
|
|
284
328
|
return devices[local_rank // (gpu_count // len(devices))]
|
|
285
329
|
|
|
286
330
|
|
|
@@ -305,8 +349,12 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
|
|
|
305
349
|
size=1,
|
|
306
350
|
)
|
|
307
351
|
if parameter_name not in parameter_metas:
|
|
308
|
-
assert isinstance(meta["dtype"], torch.dtype),
|
|
309
|
-
|
|
352
|
+
assert isinstance(meta["dtype"], torch.dtype), (
|
|
353
|
+
f"meta {meta} dtype should be torch.dtype"
|
|
354
|
+
)
|
|
355
|
+
assert isinstance(meta["shape"], torch.Size), (
|
|
356
|
+
f"meta {meta} shape should be torch.Size"
|
|
357
|
+
)
|
|
310
358
|
parameter_metas[parameter_name] = ParameterMeta(
|
|
311
359
|
name=parameter_name,
|
|
312
360
|
shape=meta["shape"],
|
|
@@ -319,7 +367,9 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
|
|
|
319
367
|
if tp_meta.concat_dim != -1:
|
|
320
368
|
shape = list(parameter_metas[name].shape)
|
|
321
369
|
shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
|
|
322
|
-
parameter_metas[name] = ParameterMeta(
|
|
370
|
+
parameter_metas[name] = ParameterMeta(
|
|
371
|
+
name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
|
|
372
|
+
)
|
|
323
373
|
weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
|
|
324
374
|
# TODO: here concat is serial, which may be slow
|
|
325
375
|
# but since tp storage is not used in the future
|
|
@@ -338,17 +388,19 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
|
|
|
338
388
|
|
|
339
389
|
def _register_checkpoint(
|
|
340
390
|
*,
|
|
341
|
-
files: list[str]
|
|
342
|
-
named_tensors: dict[str, torch.Tensor]
|
|
391
|
+
files: list[str],
|
|
392
|
+
named_tensors: dict[str, torch.Tensor],
|
|
343
393
|
rank: int | None = None,
|
|
344
394
|
) -> list[MemoryBuffer]:
|
|
345
|
-
logger.info(
|
|
395
|
+
logger.info(
|
|
396
|
+
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
|
|
397
|
+
)
|
|
346
398
|
if not files and not named_tensors:
|
|
347
399
|
return []
|
|
348
400
|
parameters = _load_checkpoint(files)
|
|
349
401
|
if named_tensors:
|
|
350
402
|
parameters.update(named_tensors)
|
|
351
|
-
bucket_size = max(4 << 30, max(
|
|
403
|
+
bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
|
|
352
404
|
|
|
353
405
|
class MemoryBucket(BaseModel):
|
|
354
406
|
size: int
|
|
@@ -363,7 +415,10 @@ def _register_checkpoint(
|
|
|
363
415
|
buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
|
|
364
416
|
buckets[-1].size += size
|
|
365
417
|
|
|
366
|
-
memory_buffers = [
|
|
418
|
+
memory_buffers = [
|
|
419
|
+
MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
|
|
420
|
+
for bucket in buckets
|
|
421
|
+
]
|
|
367
422
|
|
|
368
423
|
def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
|
|
369
424
|
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
|
|
@@ -373,7 +428,10 @@ def _register_checkpoint(
|
|
|
373
428
|
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
|
|
374
429
|
|
|
375
430
|
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
|
|
376
|
-
futures = [
|
|
431
|
+
futures = [
|
|
432
|
+
executor.submit(register_pin_memory, idx, bucket.size)
|
|
433
|
+
for idx, bucket in enumerate(buckets)
|
|
434
|
+
]
|
|
377
435
|
new_futures = []
|
|
378
436
|
for future in concurrent.futures.as_completed(futures):
|
|
379
437
|
idx, buffer = future.result()
|
|
@@ -400,8 +458,26 @@ def _register_checkpoint(
|
|
|
400
458
|
return memory_buffers
|
|
401
459
|
|
|
402
460
|
|
|
403
|
-
def request_inference_to_update(
|
|
404
|
-
|
|
461
|
+
def request_inference_to_update(
|
|
462
|
+
url: str,
|
|
463
|
+
socket_paths: dict[str, str],
|
|
464
|
+
timeout: float = 300.0,
|
|
465
|
+
uds: str | None = None,
|
|
466
|
+
):
|
|
467
|
+
"""Send an inference update request to inference server via HTTP or Unix socket.
|
|
468
|
+
|
|
469
|
+
Args:
|
|
470
|
+
url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
|
|
471
|
+
socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
|
|
472
|
+
timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
|
|
473
|
+
uds (str, optional): Path to a Unix domain socket. If provided, the request
|
|
474
|
+
will be sent via the Unix socket instead of HTTP. Defaults to None.
|
|
475
|
+
|
|
476
|
+
Raises:
|
|
477
|
+
httpx.HTTPStatusError: If the response contains an HTTP error status.
|
|
478
|
+
httpx.RequestError: If there was an issue while making the request.
|
|
479
|
+
"""
|
|
480
|
+
resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
|
|
405
481
|
url,
|
|
406
482
|
json={
|
|
407
483
|
"method": "update_weights_from_ipc",
|
|
@@ -413,7 +489,9 @@ def request_inference_to_update(url: str, socket_paths: dict[str, str], timeout:
|
|
|
413
489
|
resp.raise_for_status()
|
|
414
490
|
|
|
415
491
|
|
|
416
|
-
def _gen_h2d_buckets(
|
|
492
|
+
def _gen_h2d_buckets(
|
|
493
|
+
global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
|
|
494
|
+
) -> list[tuple[int, H2DBucket]]:
|
|
417
495
|
buckets: list[tuple[int, H2DBucket]] = []
|
|
418
496
|
|
|
419
497
|
for owner_rank, items in global_metas.items():
|
|
@@ -424,14 +502,18 @@ def _gen_h2d_buckets(global_metas: dict[int, MemoryBufferMetaList], bucket_size:
|
|
|
424
502
|
s = _align_size(meta.dtype, meta.shape)
|
|
425
503
|
if buckets[-1][1].size + s > bucket_size:
|
|
426
504
|
if offset - start_offset > 0:
|
|
427
|
-
buckets[-1][1].ranges.append(
|
|
505
|
+
buckets[-1][1].ranges.append(
|
|
506
|
+
BucketRange(idx, start_offset, offset - start_offset)
|
|
507
|
+
)
|
|
428
508
|
start_offset = offset
|
|
429
509
|
buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
|
|
430
510
|
offset += s
|
|
431
511
|
buckets[-1][1].size += s
|
|
432
512
|
buckets[-1][1].items.append(meta)
|
|
433
513
|
buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
|
|
434
|
-
assert buckets[-1][1].size > 0,
|
|
514
|
+
assert buckets[-1][1].size > 0, (
|
|
515
|
+
f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
|
|
516
|
+
)
|
|
435
517
|
return buckets
|
|
436
518
|
|
|
437
519
|
|
|
@@ -470,7 +552,9 @@ class P2PStore:
|
|
|
470
552
|
raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
|
|
471
553
|
self.port = self.engine.get_rpc_port()
|
|
472
554
|
self.named_tensors: dict[str, torch.Tensor] = {}
|
|
473
|
-
logger.info(
|
|
555
|
+
logger.info(
|
|
556
|
+
f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
|
|
557
|
+
)
|
|
474
558
|
|
|
475
559
|
@property
|
|
476
560
|
def addr(self) -> str:
|
|
@@ -492,37 +576,60 @@ class P2PStore:
|
|
|
492
576
|
num_unregistered = 0
|
|
493
577
|
for i, name in enumerate(names):
|
|
494
578
|
del self.named_tensors[name]
|
|
495
|
-
logger.info(
|
|
579
|
+
logger.info(
|
|
580
|
+
f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
|
|
581
|
+
)
|
|
496
582
|
num_unregistered += 1
|
|
497
583
|
return num_unregistered
|
|
498
584
|
|
|
499
|
-
def batch_transfer_sync_read(
|
|
500
|
-
|
|
585
|
+
def batch_transfer_sync_read(
|
|
586
|
+
self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
|
|
587
|
+
):
|
|
588
|
+
assert (
|
|
589
|
+
self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
|
|
590
|
+
)
|
|
501
591
|
|
|
502
592
|
|
|
503
593
|
class ParameterServer:
|
|
504
|
-
def __init__(
|
|
594
|
+
def __init__(
|
|
595
|
+
self,
|
|
596
|
+
*,
|
|
597
|
+
rank: int | None = None,
|
|
598
|
+
world_size: int | None = None,
|
|
599
|
+
auto_pg: bool = False,
|
|
600
|
+
gpu_count: int | None = None,
|
|
601
|
+
mem_fraction: float | None = None,
|
|
602
|
+
):
|
|
505
603
|
"""
|
|
506
604
|
Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
|
|
507
605
|
|
|
508
606
|
Args:
|
|
509
607
|
auto_pg: Whether to automatically initialize the process group.
|
|
510
608
|
Notice that if auto_pg is True, will destroy the process group after update.
|
|
609
|
+
mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
|
|
511
610
|
"""
|
|
512
|
-
self._rank = int(os.environ.get("RANK", None))
|
|
513
|
-
self._world_size = int(os.environ.get("WORLD_SIZE", None))
|
|
514
|
-
self.
|
|
515
|
-
self._gpu_count = torch.cuda.device_count()
|
|
611
|
+
self._rank = rank or int(os.environ.get("RANK", None))
|
|
612
|
+
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
|
|
613
|
+
self._gpu_count = gpu_count or torch.cuda.device_count()
|
|
516
614
|
self._local_rank = self._rank % self._gpu_count
|
|
517
615
|
self._auto_pg = auto_pg
|
|
518
616
|
self._all_hosts = []
|
|
519
|
-
self.
|
|
617
|
+
self._global_device_uuids: list[str] = []
|
|
618
|
+
self._mem_fraction = mem_fraction or 0.9
|
|
520
619
|
|
|
521
620
|
assert self._rank is not None and self._rank >= 0, self._rank
|
|
522
621
|
assert self._world_size and self._world_size > 0, self._world_size
|
|
622
|
+
assert (
|
|
623
|
+
self._gpu_count is not None
|
|
624
|
+
and self._gpu_count > 0
|
|
625
|
+
and self._gpu_count <= torch.cuda.device_count()
|
|
626
|
+
), self._gpu_count
|
|
627
|
+
assert (
|
|
628
|
+
self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
|
|
629
|
+
), self._mem_fraction
|
|
523
630
|
|
|
524
|
-
self._device_uuid = _get_physical_gpu_id(self._local_rank)
|
|
525
631
|
self._zmq_ctx = zmq.Context()
|
|
632
|
+
self._zmq_addr_counter = 0
|
|
526
633
|
|
|
527
634
|
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
|
|
528
635
|
# dict key is owner_rank, value is a bucket metas list in owner_rank
|
|
@@ -533,19 +640,27 @@ class ParameterServer:
|
|
|
533
640
|
logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
|
|
534
641
|
self._p2p_store = None
|
|
535
642
|
|
|
536
|
-
|
|
643
|
+
device_index = self._local_rank
|
|
644
|
+
torch.cuda.set_device(device_index)
|
|
645
|
+
self._device_uuid = _get_physical_gpu_id(device_index)
|
|
537
646
|
|
|
538
|
-
def _logger_rank0(self, msg):
|
|
647
|
+
def _logger_rank0(self, msg: str):
|
|
539
648
|
if self._local_rank == 0:
|
|
540
649
|
logger.info(msg)
|
|
541
650
|
|
|
542
|
-
def get_metas(self):
|
|
651
|
+
def get_metas(self) -> dict[int, MemoryBufferMetaList]:
|
|
543
652
|
return self._current_global_parameter_metas
|
|
544
653
|
|
|
545
654
|
def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
|
|
546
655
|
self._current_global_parameter_metas = metas
|
|
547
656
|
|
|
548
|
-
def register_checkpoint(
|
|
657
|
+
def register_checkpoint(
|
|
658
|
+
self,
|
|
659
|
+
checkpoint_name: str,
|
|
660
|
+
*,
|
|
661
|
+
files: list[str] | None = None,
|
|
662
|
+
named_tensors: dict[str, torch.Tensor] | None = None,
|
|
663
|
+
) -> None:
|
|
549
664
|
"""
|
|
550
665
|
Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
|
|
551
666
|
|
|
@@ -555,12 +670,18 @@ class ParameterServer:
|
|
|
555
670
|
named_tensors: The named tensors to register.
|
|
556
671
|
"""
|
|
557
672
|
try:
|
|
558
|
-
assert checkpoint_name not in self._memory_pool,
|
|
559
|
-
|
|
673
|
+
assert checkpoint_name not in self._memory_pool, (
|
|
674
|
+
f"checkpoint {checkpoint_name} already registered"
|
|
675
|
+
)
|
|
676
|
+
self._memory_pool[checkpoint_name] = _register_checkpoint(
|
|
677
|
+
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
|
|
678
|
+
)
|
|
560
679
|
if self._p2p_store is not None:
|
|
561
680
|
self._register_parameters_to_p2p_store(checkpoint_name)
|
|
562
681
|
except Exception:
|
|
563
|
-
logger.exception(
|
|
682
|
+
logger.exception(
|
|
683
|
+
f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
|
|
684
|
+
)
|
|
564
685
|
if self._p2p_store is not None:
|
|
565
686
|
self._unregister_parameters_from_p2p_store(checkpoint_name)
|
|
566
687
|
self.unregister_checkpoint(checkpoint_name)
|
|
@@ -583,10 +704,6 @@ class ParameterServer:
|
|
|
583
704
|
# this works by using torch>=2.5.0
|
|
584
705
|
torch._C._host_emptyCache()
|
|
585
706
|
|
|
586
|
-
@cached_property
|
|
587
|
-
def _zmq_socket_path(self) -> str:
|
|
588
|
-
return f"ipc://@checkpoint-engine-{uuid.uuid4()}.sock"
|
|
589
|
-
|
|
590
707
|
def gather_metas(self, checkpoint_name: str):
|
|
591
708
|
"""
|
|
592
709
|
Gather the parameter metas from all ranks. This will gather memory_buffer, and other metadatas.
|
|
@@ -598,19 +715,17 @@ class ParameterServer:
|
|
|
598
715
|
assert dist.is_initialized(), "process group is not initialized"
|
|
599
716
|
metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore
|
|
600
717
|
metas = DataToGather(
|
|
601
|
-
memory_buffer_metas_list=
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
),
|
|
610
|
-
),
|
|
718
|
+
memory_buffer_metas_list=[
|
|
719
|
+
MemoryBufferMetas(
|
|
720
|
+
metas=x.metas,
|
|
721
|
+
ptr=x.buffer.data_ptr(),
|
|
722
|
+
size=x.size,
|
|
723
|
+
)
|
|
724
|
+
for x in self._memory_pool.get(checkpoint_name, [])
|
|
725
|
+
],
|
|
611
726
|
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
|
|
612
727
|
host_ip=_get_ip(),
|
|
613
|
-
|
|
728
|
+
device_uuid=self._device_uuid,
|
|
614
729
|
)
|
|
615
730
|
|
|
616
731
|
dist.all_gather_object(metas_lst, metas)
|
|
@@ -618,23 +733,31 @@ class ParameterServer:
|
|
|
618
733
|
self._current_global_parameter_metas = {}
|
|
619
734
|
num_parameters = 0
|
|
620
735
|
all_hosts: list[str] = []
|
|
621
|
-
|
|
736
|
+
global_device_uuids: list[str] = []
|
|
622
737
|
for i, metas_buckets in enumerate(metas_lst):
|
|
623
738
|
assert metas_buckets is not None, f"metas_buckets {i} should not be None"
|
|
624
739
|
if i % self._gpu_count == 0 and not self._all_hosts:
|
|
625
740
|
all_hosts.append(metas_buckets.host_ip)
|
|
626
|
-
if not self.
|
|
627
|
-
|
|
741
|
+
if not self._global_device_uuids:
|
|
742
|
+
global_device_uuids.append(metas_buckets.device_uuid)
|
|
628
743
|
if metas_buckets.memory_buffer_metas_list:
|
|
629
744
|
self._current_global_parameter_metas[i] = metas_buckets
|
|
630
|
-
num_parameters += sum(
|
|
745
|
+
num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
|
|
631
746
|
if not self._all_hosts:
|
|
632
747
|
self._all_hosts = all_hosts
|
|
633
|
-
if not self.
|
|
634
|
-
self.
|
|
635
|
-
logger.info(
|
|
748
|
+
if not self._global_device_uuids:
|
|
749
|
+
self._global_device_uuids = global_device_uuids
|
|
750
|
+
logger.info(
|
|
751
|
+
f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
|
|
752
|
+
)
|
|
636
753
|
|
|
637
|
-
def init_process_group(
|
|
754
|
+
def init_process_group(
|
|
755
|
+
self,
|
|
756
|
+
*,
|
|
757
|
+
master_addr: str | None = None,
|
|
758
|
+
master_port: int | None = None,
|
|
759
|
+
timeout: timedelta = timedelta(minutes=10),
|
|
760
|
+
):
|
|
638
761
|
"""
|
|
639
762
|
Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
|
|
640
763
|
|
|
@@ -642,10 +765,22 @@ class ParameterServer:
|
|
|
642
765
|
master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
|
|
643
766
|
timeout: The timeout of the process group.
|
|
644
767
|
"""
|
|
768
|
+
master_addr = master_addr or os.getenv("MASTER_ADDR")
|
|
769
|
+
assert master_addr, "master_addr is required"
|
|
645
770
|
store = dist.TCPStore(
|
|
646
|
-
|
|
771
|
+
master_addr,
|
|
772
|
+
_get_master_port(master_port),
|
|
773
|
+
self._world_size,
|
|
774
|
+
timeout=timeout,
|
|
775
|
+
is_master=self._rank == 0,
|
|
776
|
+
)
|
|
777
|
+
dist.init_process_group(
|
|
778
|
+
backend="nccl",
|
|
779
|
+
world_size=self._world_size,
|
|
780
|
+
rank=self._rank,
|
|
781
|
+
timeout=timeout,
|
|
782
|
+
store=store,
|
|
647
783
|
)
|
|
648
|
-
dist.init_process_group(backend="nccl", world_size=self._world_size, rank=self._rank, timeout=timeout, store=store)
|
|
649
784
|
logger.info(f"[rank{self._rank}] init process group successfully.")
|
|
650
785
|
|
|
651
786
|
def update(
|
|
@@ -653,8 +788,8 @@ class ParameterServer:
|
|
|
653
788
|
checkpoint_name: str,
|
|
654
789
|
req_func: Callable[[list[tuple[str, str]]], None],
|
|
655
790
|
*,
|
|
656
|
-
ranks: list[int] =
|
|
657
|
-
):
|
|
791
|
+
ranks: list[int] | None = None,
|
|
792
|
+
) -> None:
|
|
658
793
|
"""
|
|
659
794
|
Update the checkpoint to inference engine. This function should be called after gather_metas.
|
|
660
795
|
|
|
@@ -667,18 +802,21 @@ class ParameterServer:
|
|
|
667
802
|
which is useful in disaggregated architecture.
|
|
668
803
|
"""
|
|
669
804
|
try:
|
|
805
|
+
# if both ranks is None or [], it will use fully broadcast to update to all ranks
|
|
670
806
|
if not ranks:
|
|
671
807
|
if self._auto_pg and not dist.is_initialized():
|
|
672
808
|
self.init_process_group()
|
|
673
809
|
self._update_per_bucket(checkpoint_name, req_func)
|
|
674
810
|
else:
|
|
675
|
-
if self._rank not in ranks:
|
|
811
|
+
if not self._auto_pg and self._rank not in ranks:
|
|
676
812
|
return
|
|
677
813
|
if self._auto_pg:
|
|
678
814
|
if dist.is_initialized():
|
|
679
815
|
dist.destroy_process_group()
|
|
680
816
|
# HACK: wait 2s to ensure destroy is finished
|
|
681
817
|
time.sleep(2)
|
|
818
|
+
if self._rank not in ranks:
|
|
819
|
+
return
|
|
682
820
|
self.init_process_group_for_ranks(ranks)
|
|
683
821
|
self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
|
|
684
822
|
if self._auto_pg:
|
|
@@ -692,21 +830,39 @@ class ParameterServer:
|
|
|
692
830
|
f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
|
|
693
831
|
)
|
|
694
832
|
except Exception as e:
|
|
695
|
-
logger.exception(
|
|
696
|
-
|
|
833
|
+
logger.exception(
|
|
834
|
+
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
|
|
835
|
+
)
|
|
836
|
+
raise
|
|
837
|
+
|
|
838
|
+
def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
|
|
839
|
+
def zmq_handle(device_uuid: str) -> str:
|
|
840
|
+
return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"
|
|
697
841
|
|
|
698
|
-
|
|
699
|
-
|
|
842
|
+
socket_paths = [(uid, zmq_handle(uid)) for uid in self._global_device_uuids]
|
|
843
|
+
socket = self._zmq_ctx.socket(zmq.REQ)
|
|
844
|
+
socket.bind(zmq_handle(self._device_uuid))
|
|
845
|
+
self._zmq_addr_counter += 1
|
|
846
|
+
return socket, socket_paths
|
|
847
|
+
|
|
848
|
+
def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
|
|
849
|
+
GiB = 1 << 30 # noqa: N806
|
|
700
850
|
# auto detect bucket size
|
|
701
|
-
|
|
702
|
-
|
|
851
|
+
tensor = torch.tensor(
|
|
852
|
+
[
|
|
853
|
+
# proportion of current cuda free memory bytes
|
|
854
|
+
int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
|
|
855
|
+
# we use negative value to reuse allreduce min operation
|
|
856
|
+
# for getting the max value of zmq_addr_counter in all ranks
|
|
857
|
+
-self._zmq_addr_counter,
|
|
858
|
+
],
|
|
703
859
|
dtype=torch.int64,
|
|
704
860
|
device="cuda",
|
|
705
861
|
)
|
|
706
|
-
dist.all_reduce(
|
|
707
|
-
|
|
862
|
+
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
|
|
863
|
+
tensor = tensor.cpu()
|
|
864
|
+
free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
|
|
708
865
|
max_tensor_bytes = 0
|
|
709
|
-
max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB_bytes
|
|
710
866
|
for items in self._current_global_parameter_metas.values():
|
|
711
867
|
for metas_list in items.memory_buffer_metas_list:
|
|
712
868
|
for meta in metas_list.metas:
|
|
@@ -729,18 +885,27 @@ class ParameterServer:
|
|
|
729
885
|
f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
|
|
730
886
|
)
|
|
731
887
|
disable_h2d_buffer = True
|
|
888
|
+
max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
|
|
732
889
|
bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
|
|
733
|
-
logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size /
|
|
890
|
+
logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
|
|
734
891
|
return bucket_size, disable_h2d_buffer
|
|
735
892
|
|
|
736
|
-
def _copy_to_buffer(
|
|
893
|
+
def _copy_to_buffer(
|
|
894
|
+
self,
|
|
895
|
+
checkpoint_name: str,
|
|
896
|
+
bucket: H2DBucket,
|
|
897
|
+
buffer: torch.Tensor,
|
|
898
|
+
owner_rank: int | None = None,
|
|
899
|
+
):
|
|
737
900
|
offset = 0
|
|
738
901
|
if owner_rank is not None:
|
|
739
902
|
buf_ptrs, remote_ptrs, lens = [], [], []
|
|
740
903
|
ptr_base = buffer.data_ptr()
|
|
741
904
|
target_addr, ptrs = self._get_addr_ptrs(owner_rank)
|
|
742
905
|
for b in bucket.ranges:
|
|
743
|
-
assert offset + b.size <= bucket.size,
|
|
906
|
+
assert offset + b.size <= bucket.size, (
|
|
907
|
+
f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
|
|
908
|
+
)
|
|
744
909
|
if owner_rank is not None:
|
|
745
910
|
buf_ptrs.append(ptr_base + offset)
|
|
746
911
|
remote_ptrs.append(ptrs[b.idx][0] + b.offset)
|
|
@@ -758,7 +923,11 @@ class ParameterServer:
|
|
|
758
923
|
torch.cuda.synchronize()
|
|
759
924
|
|
|
760
925
|
def init_process_group_for_ranks(
|
|
761
|
-
self,
|
|
926
|
+
self,
|
|
927
|
+
ranks: list[int],
|
|
928
|
+
*,
|
|
929
|
+
master_port: int | None = None,
|
|
930
|
+
timeout: timedelta = timedelta(minutes=10),
|
|
762
931
|
):
|
|
763
932
|
"""
|
|
764
933
|
Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
|
|
@@ -787,8 +956,12 @@ class ParameterServer:
|
|
|
787
956
|
# and will not participate in this update. Since they have registered memory addresses
|
|
788
957
|
# to p2p_store at the beginning, update ranks can directly get the memory addresses
|
|
789
958
|
# from other nodes and put the weights into the buffer.
|
|
790
|
-
store = dist.TCPStore(
|
|
791
|
-
|
|
959
|
+
store = dist.TCPStore(
|
|
960
|
+
master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
|
|
961
|
+
)
|
|
962
|
+
dist.init_process_group(
|
|
963
|
+
backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
|
|
964
|
+
)
|
|
792
965
|
|
|
793
966
|
def _update_per_bucket_p2p(
|
|
794
967
|
self,
|
|
@@ -800,7 +973,9 @@ class ParameterServer:
|
|
|
800
973
|
assert ranks, "ranks should be set"
|
|
801
974
|
if len(self._current_global_parameter_metas) == 0:
|
|
802
975
|
raise ValueError("parameter metas is empty")
|
|
803
|
-
assert dist.is_initialized(),
|
|
976
|
+
assert dist.is_initialized(), (
|
|
977
|
+
"process group is not initialized when update model per bucket p2p"
|
|
978
|
+
)
|
|
804
979
|
|
|
805
980
|
need_update = self._rank in ranks
|
|
806
981
|
logger.info(
|
|
@@ -814,26 +989,24 @@ class ParameterServer:
|
|
|
814
989
|
# first execute a barrier to avoid subsequent cuda oom
|
|
815
990
|
dist.barrier()
|
|
816
991
|
|
|
817
|
-
bucket_size, _ = self.
|
|
992
|
+
bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
|
|
818
993
|
buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
|
|
819
|
-
|
|
820
|
-
self._p2p_store.register_named_tensors({
|
|
994
|
+
ipc_buffer_name = "__ipc_buffer___"
|
|
995
|
+
self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
|
|
821
996
|
logger.info(
|
|
822
997
|
f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
|
|
823
998
|
)
|
|
824
999
|
handle = reduce_tensor(buffer)
|
|
825
1000
|
|
|
826
|
-
gidx = 0
|
|
827
1001
|
buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
|
|
1002
|
+
socket, socket_paths = self._bind_zmq_socket()
|
|
828
1003
|
req_thread = threading.Thread(
|
|
829
1004
|
target=req_func,
|
|
830
|
-
args=(
|
|
1005
|
+
args=(socket_paths,),
|
|
831
1006
|
)
|
|
832
1007
|
req_thread.start()
|
|
833
|
-
socket = self._zmq_ctx.socket(zmq.REQ)
|
|
834
|
-
socket.bind(self._zmq_socket_path)
|
|
835
1008
|
socket.send_pyobj(handle)
|
|
836
|
-
for owner_rank, bucket in buckets:
|
|
1009
|
+
for gidx, (owner_rank, bucket) in enumerate(buckets):
|
|
837
1010
|
self._logger_rank0(
|
|
838
1011
|
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
|
|
839
1012
|
)
|
|
@@ -845,7 +1018,6 @@ class ParameterServer:
|
|
|
845
1018
|
socket.recv()
|
|
846
1019
|
dist.barrier()
|
|
847
1020
|
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
|
|
848
|
-
gidx += 1
|
|
849
1021
|
|
|
850
1022
|
socket.recv()
|
|
851
1023
|
socket.send_pyobj(None)
|
|
@@ -853,7 +1025,7 @@ class ParameterServer:
|
|
|
853
1025
|
req_thread.join()
|
|
854
1026
|
dist.barrier()
|
|
855
1027
|
socket.close()
|
|
856
|
-
self._p2p_store.unregister_named_tensors([
|
|
1028
|
+
self._p2p_store.unregister_named_tensors([ipc_buffer_name])
|
|
857
1029
|
torch.cuda.empty_cache()
|
|
858
1030
|
|
|
859
1031
|
def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
|
|
@@ -877,7 +1049,9 @@ class ParameterServer:
|
|
|
877
1049
|
pool = self._memory_pool[checkpoint_name]
|
|
878
1050
|
if len(pool) == 0:
|
|
879
1051
|
return 0
|
|
880
|
-
return self._p2p_store.unregister_named_tensors(
|
|
1052
|
+
return self._p2p_store.unregister_named_tensors(
|
|
1053
|
+
[f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
|
|
1054
|
+
)
|
|
881
1055
|
|
|
882
1056
|
def _update_per_bucket(
|
|
883
1057
|
self,
|
|
@@ -891,11 +1065,13 @@ class ParameterServer:
|
|
|
891
1065
|
|
|
892
1066
|
logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
|
|
893
1067
|
|
|
894
|
-
bucket_size, disable_h2d_buffer = self.
|
|
1068
|
+
bucket_size, disable_h2d_buffer = self._detect_bucket_size()
|
|
895
1069
|
buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
|
|
896
1070
|
|
|
897
1071
|
h2d_buffer: torch.Tensor | None = (
|
|
898
|
-
None
|
|
1072
|
+
None
|
|
1073
|
+
if disable_h2d_buffer
|
|
1074
|
+
else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
|
|
899
1075
|
)
|
|
900
1076
|
|
|
901
1077
|
owner_rank_buckets: list[H2DBucket] = []
|
|
@@ -914,13 +1090,12 @@ class ParameterServer:
|
|
|
914
1090
|
if len(buckets_by_owner_rank[owner_rank]) > max_len:
|
|
915
1091
|
max_len = len(buckets_by_owner_rank[owner_rank])
|
|
916
1092
|
|
|
1093
|
+
socket, socket_paths = self._bind_zmq_socket()
|
|
917
1094
|
req_thread = threading.Thread(
|
|
918
1095
|
target=req_func,
|
|
919
|
-
args=(
|
|
1096
|
+
args=(socket_paths,),
|
|
920
1097
|
)
|
|
921
1098
|
req_thread.start()
|
|
922
|
-
socket = self._zmq_ctx.socket(zmq.REQ)
|
|
923
|
-
socket.bind(self._zmq_socket_path)
|
|
924
1099
|
socket.send_pyobj(handle)
|
|
925
1100
|
|
|
926
1101
|
gidx = 0
|
|
@@ -932,7 +1107,10 @@ class ParameterServer:
|
|
|
932
1107
|
if i >= len(_buckets):
|
|
933
1108
|
continue
|
|
934
1109
|
bucket = _buckets[i]
|
|
935
|
-
alloc, reserved =
|
|
1110
|
+
alloc, reserved = (
|
|
1111
|
+
torch.cuda.memory_allocated() / 1024 / 1024,
|
|
1112
|
+
torch.cuda.memory_reserved() / 1024 / 1024,
|
|
1113
|
+
)
|
|
936
1114
|
self._logger_rank0(
|
|
937
1115
|
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
|
|
938
1116
|
f"Current CUDA allocated {alloc:.2f} MB, "
|
|
@@ -960,7 +1138,7 @@ class ParameterServer:
|
|
|
960
1138
|
torch.cuda.empty_cache()
|
|
961
1139
|
|
|
962
1140
|
|
|
963
|
-
def _init_api(ps: ParameterServer):
|
|
1141
|
+
def _init_api(ps: ParameterServer) -> Any:
|
|
964
1142
|
import fastapi
|
|
965
1143
|
from fastapi import Request
|
|
966
1144
|
from fastapi.responses import JSONResponse, Response
|
|
@@ -976,32 +1154,32 @@ def _init_api(ps: ParameterServer):
|
|
|
976
1154
|
inference_group_ranks: list[int] = []
|
|
977
1155
|
timeout: float = 300.0
|
|
978
1156
|
|
|
979
|
-
def wrap_exception(func):
|
|
1157
|
+
def wrap_exception(func: Callable[[], None]) -> Response:
|
|
980
1158
|
try:
|
|
981
1159
|
func()
|
|
982
|
-
except Exception as e:
|
|
1160
|
+
except Exception as e: # noqa: BLE001
|
|
983
1161
|
logger.exception(f"wrap exception {func} failed")
|
|
984
1162
|
return JSONResponse(content=str(e), status_code=500)
|
|
985
1163
|
return Response(status_code=200)
|
|
986
1164
|
|
|
987
1165
|
@app.post("/v1/checkpoints/{checkpoint_name}/files")
|
|
988
|
-
async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request):
|
|
1166
|
+
async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
|
|
989
1167
|
return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
|
|
990
1168
|
|
|
991
1169
|
@app.delete("/v1/checkpoints/{checkpoint_name}")
|
|
992
|
-
async def unregister_checkpoint(checkpoint_name: str):
|
|
1170
|
+
async def unregister_checkpoint(checkpoint_name: str) -> Response:
|
|
993
1171
|
return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
|
|
994
1172
|
|
|
995
1173
|
@app.get("/v1/healthz")
|
|
996
|
-
async def healthz():
|
|
1174
|
+
async def healthz() -> Response:
|
|
997
1175
|
return Response(status_code=200)
|
|
998
1176
|
|
|
999
1177
|
@app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
|
|
1000
|
-
async def gather_metas(checkpoint_name: str):
|
|
1178
|
+
async def gather_metas(checkpoint_name: str) -> Response:
|
|
1001
1179
|
return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
|
|
1002
1180
|
|
|
1003
1181
|
@app.post("/v1/checkpoints/{checkpoint_name}/update")
|
|
1004
|
-
async def update(checkpoint_name: str, req: UpdateRequest):
|
|
1182
|
+
async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
|
|
1005
1183
|
def update_func(socket_paths: list[tuple[str, str]]):
|
|
1006
1184
|
if req.update_url is None:
|
|
1007
1185
|
return
|
|
@@ -1018,11 +1196,13 @@ def _init_api(ps: ParameterServer):
|
|
|
1018
1196
|
def run_from_cli():
|
|
1019
1197
|
import uvicorn
|
|
1020
1198
|
|
|
1021
|
-
parser = argparse.ArgumentParser(description="
|
|
1199
|
+
parser = argparse.ArgumentParser(description="Parameter Server")
|
|
1022
1200
|
parser.add_argument("--uds", type=str)
|
|
1023
1201
|
|
|
1024
1202
|
args = parser.parse_args()
|
|
1025
|
-
logger.info(
|
|
1203
|
+
logger.info(
|
|
1204
|
+
f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
|
|
1205
|
+
)
|
|
1026
1206
|
|
|
1027
1207
|
assert args.uds and len(args.uds) > 0, args.uds
|
|
1028
1208
|
ps = ParameterServer(auto_pg=True)
|
checkpoint_engine/worker.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import gc
|
|
2
|
-
from
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from typing import TypedDict
|
|
3
4
|
|
|
4
5
|
import torch
|
|
5
6
|
import zmq
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id:
|
|
9
|
+
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
|
|
9
10
|
func, args = handle
|
|
10
11
|
list_args = list(args)
|
|
11
12
|
if device_id is not None:
|
|
@@ -24,12 +25,14 @@ class FlattenedTensorMetadata(TypedDict):
|
|
|
24
25
|
offset: int
|
|
25
26
|
|
|
26
27
|
|
|
27
|
-
def _extract_weights(
|
|
28
|
+
def _extract_weights(
|
|
29
|
+
payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
|
|
30
|
+
) -> list[tuple[str, torch.Tensor]]:
|
|
28
31
|
assert buffer is not None
|
|
29
32
|
weights: list[tuple[str, torch.Tensor]] = []
|
|
30
33
|
for item in payload:
|
|
31
34
|
shape = item["shape"]
|
|
32
|
-
if isinstance(shape,
|
|
35
|
+
if isinstance(shape, list | tuple):
|
|
33
36
|
shape = torch.Size(shape)
|
|
34
37
|
assert isinstance(shape, torch.Size)
|
|
35
38
|
dtype, offset = item["dtype"], item["offset"]
|
|
@@ -45,11 +48,11 @@ def update_weights_from_ipc(
|
|
|
45
48
|
device_id: int,
|
|
46
49
|
*,
|
|
47
50
|
run: Callable[[list[tuple[str, torch.Tensor]]], None],
|
|
48
|
-
post_hook: Callable[[], None] = None,
|
|
51
|
+
post_hook: Callable[[], None] | None = None,
|
|
49
52
|
):
|
|
50
53
|
socket = zmq_ctx.socket(zmq.REP)
|
|
51
54
|
socket.connect(zmq_handle)
|
|
52
|
-
buffer:
|
|
55
|
+
buffer: torch.Tensor | None = None
|
|
53
56
|
while True:
|
|
54
57
|
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
|
|
55
58
|
if payload is None:
|
|
@@ -100,5 +103,7 @@ class VllmColocateWorkerExtension:
|
|
|
100
103
|
zmq_handles[device_uuid],
|
|
101
104
|
device_id=self.device.index,
|
|
102
105
|
run=self.model_runner.model.load_weights,
|
|
103
|
-
post_hook=lambda: process_weights_after_loading(
|
|
106
|
+
post_hook=lambda: process_weights_after_loading(
|
|
107
|
+
self.model_runner.model, self.model_config, self.device
|
|
108
|
+
),
|
|
104
109
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: checkpoint-engine
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
5
|
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
6
|
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
@@ -11,14 +11,15 @@ Description-Content-Type: text/markdown
|
|
|
11
11
|
License-File: LICENCE
|
|
12
12
|
Requires-Dist: torch>=2.5.0
|
|
13
13
|
Requires-Dist: fastapi
|
|
14
|
-
Requires-Dist: pydantic
|
|
14
|
+
Requires-Dist: pydantic>=2.0.0
|
|
15
15
|
Requires-Dist: safetensors
|
|
16
16
|
Requires-Dist: pyzmq
|
|
17
17
|
Requires-Dist: uvicorn
|
|
18
18
|
Requires-Dist: loguru
|
|
19
19
|
Requires-Dist: numpy
|
|
20
|
+
Requires-Dist: httpx
|
|
20
21
|
Provides-Extra: p2p
|
|
21
|
-
Requires-Dist: mooncake-transfer-engine; extra == "p2p"
|
|
22
|
+
Requires-Dist: mooncake-transfer-engine>=0.3.5; extra == "p2p"
|
|
22
23
|
Dynamic: license-file
|
|
23
24
|
|
|
24
25
|
# Checkpoint Engine
|
|
@@ -41,7 +42,7 @@ The core weight update logic is in `ParameterServer` class, a service colocated
|
|
|
41
42
|
- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
|
|
42
43
|
|
|
43
44
|
### Optimized Weight Broadcast
|
|
44
|
-
In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
|
|
45
|
+
In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
|
|
45
46
|
We arrange the data transfer into 3 stages:
|
|
46
47
|
1. H2D: moving weights to GPU memory. These weights may come from disk or the training engine.
|
|
47
48
|
2. broadcast: broadcast among checkpoint engine workers; the data results in a CUDA IPC buffer shared with inference engine.
|
|
@@ -73,9 +74,9 @@ Pipelining naturally requires more GPU memory. When memory is not enough, checkp
|
|
|
73
74
|
All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
|
|
74
75
|
|
|
75
76
|
* FP8 test needs additional vLLM patches, see [FP8 quantization](#fp8-quantization).
|
|
76
|
-
* Device Info: we tested various combination of devices and
|
|
77
|
+
* Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
|
|
77
78
|
* Since update duration is related to IPC bucket size, we provide the bucket size in the table.
|
|
78
|
-
* The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
|
|
79
|
+
* The P2P times were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
|
|
79
80
|
|
|
80
81
|
## Installation
|
|
81
82
|
|
|
@@ -88,7 +89,7 @@ pip install checkpoint-engine
|
|
|
88
89
|
Use the flexible P2P implementation. Note that this will install `mooncake-transfer-engine` to support RDMA transfer between different ranks.
|
|
89
90
|
|
|
90
91
|
```Bash
|
|
91
|
-
pip install checkpoint-engine[p2p]
|
|
92
|
+
pip install 'checkpoint-engine[p2p]'
|
|
92
93
|
```
|
|
93
94
|
|
|
94
95
|
If the `NCCL_IB_HCA` environment variable is set, checkpoint-engine will use it to automatically select network devices for different ranks. If it is not set, checkpoint-engine will read all RDMA devices and try to divide them among the ranks.
|
|
@@ -107,7 +108,7 @@ VLLM_USE_PRECOMPILED=1 uv pip install --editable .
|
|
|
107
108
|
Install checkpoint-engine
|
|
108
109
|
|
|
109
110
|
```Bash
|
|
110
|
-
uv pip install checkpoint-engine[p2p]
|
|
111
|
+
uv pip install 'checkpoint-engine[p2p]'
|
|
111
112
|
```
|
|
112
113
|
|
|
113
114
|
We use `Qwen/Qwen3-235B-A22B-Instruct-2507` (BF16) as the test model
|
|
@@ -133,7 +134,7 @@ torchrun --nproc-per-node 8 examples/update.py --update-method all --checkpoint-
|
|
|
133
134
|
|
|
134
135
|
### Reuse weights from existing instances
|
|
135
136
|
|
|
136
|
-
New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
|
|
137
|
+
New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
|
|
137
138
|
|
|
138
139
|
First, start the existing instances with `--save-metas-file global_metas.pkl` to save global metas to a file and use `--sleep-time 300` to make sure they stay alive.
|
|
139
140
|
|
|
@@ -150,7 +151,7 @@ torchrun --nproc-per-node 8 examples/update.py --load-metas-file global_metas.pk
|
|
|
150
151
|
|
|
151
152
|
### FP8 quantization
|
|
152
153
|
|
|
153
|
-
FP8 quantization currently do not natively work in vLLM when updating weights.
|
|
154
|
+
FP8 quantization currently does not natively work in vLLM when updating weights.
|
|
154
155
|
We provide a simple patch in [`patches/vllm_fp8.patch`](./patches/vllm_fp8.patch) to handle the correct weight update.
|
|
155
156
|
Note that this patch has only been tested with DeepSeek-V3.1 and Kimi-K2. Other models may run into compatibility issues.
|
|
156
157
|
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
|
|
2
|
+
checkpoint_engine/_version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
|
|
3
|
+
checkpoint_engine/ps.py,sha256=9dXRXi0QDPoRYrgGKAYvdmDFBXejgusjR0ltbii5_B0,49134
|
|
4
|
+
checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
|
|
5
|
+
checkpoint_engine-0.1.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
|
|
6
|
+
checkpoint_engine-0.1.3.dist-info/METADATA,sha256=y96dMjEOKWaO_PA0h5BX8G3Ku7Tt1jCU09uIf8iYgic,9322
|
|
7
|
+
checkpoint_engine-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
8
|
+
checkpoint_engine-0.1.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
9
|
+
checkpoint_engine-0.1.3.dist-info/RECORD,,
|
|
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
|
18
18
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
19
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
20
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
-
SOFTWARE
|
|
21
|
+
SOFTWARE
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
|
|
2
|
-
checkpoint_engine/_version.py,sha256=yfx_VE-4lpqM4jnWOSq-8rihMWIwMaX9CQ7tNEpA4T0,712
|
|
3
|
-
checkpoint_engine/ps.py,sha256=9u2rLOj-oQrXsnpYhdXFjv7ak2-f4BRUXB6KYlG3ah0,44422
|
|
4
|
-
checkpoint_engine/worker.py,sha256=OrSeknjtECnO88I-YMdfkZj70TIRhjvTEeZkNyZk21M,3695
|
|
5
|
-
checkpoint_engine-0.1.1.dist-info/licenses/LICENCE,sha256=0jqA0jrA_i9VUqd7FTVoI1KnN1ZRENwG_tlMRCjC63k,1066
|
|
6
|
-
checkpoint_engine-0.1.1.dist-info/METADATA,sha256=WDc5tg3RQiCthzbEzI4D3t3I0AQhGUPcDpJn7TMhYbI,9286
|
|
7
|
-
checkpoint_engine-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
8
|
-
checkpoint_engine-0.1.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
9
|
-
checkpoint_engine-0.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|