checkpoint-engine 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +3 -3
- checkpoint_engine/ps.py +337 -172
- checkpoint_engine/worker.py +12 -7
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/METADATA +11 -10
- checkpoint_engine-0.1.2.dist-info/RECORD +9 -0
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/licenses/LICENCE +1 -1
- checkpoint_engine-0.1.1.dist-info/RECORD +0 -9
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.1.1'
-__version_tuple__ = version_tuple = (0, 1, 1)
+__version__ = version = '0.1.2'
+__version_tuple__ = version_tuple = (0, 1, 2)
 
-__commit_id__ = commit_id =
+__commit_id__ = commit_id = None
checkpoint_engine/ps.py
CHANGED
@@ -2,31 +2,32 @@ from __future__ import annotations
 
 import argparse
 import concurrent.futures
+import ctypes
 import os
 import pickle
 import random
 import socket
-import subprocess
 import threading
 import time
-import uuid
 from collections import defaultdict
 from datetime import timedelta
-from functools import cached_property, lru_cache
-from typing import
+from functools import lru_cache
+from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
 
+import httpx
 import numpy as np
-import requests
 import torch
 import torch.distributed as dist
 import zmq
 from loguru import logger
-from pydantic import BaseModel,
+from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
 from safetensors.torch import safe_open
 from torch.multiprocessing.reductions import reduce_tensor
-
+
 
 if TYPE_CHECKING:
+    from collections.abc import Callable
+
    from typing_extensions import TypedDict
 
    class FileMeta(TypedDict):
@@ -37,16 +38,59 @@ if TYPE_CHECKING:
        tp_concat_dim: int
 
 
-
-
-
-
-
-
+def _dt_validate(value: Any) -> torch.dtype:
+    if isinstance(value, str):
+        if not value.startswith("torch."):
+            raise ValueError(f"dtype {value} should start with torch.")
+        try:
+            value = getattr(torch, value.split(".")[1])
+        except AttributeError as e:
+            raise ValueError(f"unknown dtype: {value}") from e
+    if not isinstance(value, torch.dtype):
+        raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
+    return value
+
+
+_TorchDtype = Annotated[
+    torch.dtype,
+    PlainValidator(_dt_validate),
+    PlainSerializer(lambda x: str(x), return_type=str),
+    WithJsonSchema({"type": "string"}, mode="serialization"),
+]
+
+
+def _size_validate(value: Any) -> torch.Size:
+    if isinstance(value, list | tuple):
+        return torch.Size(value)
+    if not isinstance(value, torch.Size):
+        raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
+    return value
+
+
+_TorchSize = Annotated[
+    torch.Size,
+    PlainValidator(_size_validate),
+    PlainSerializer(lambda x: tuple(x), return_type=tuple),
+    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
+]
+
 
+def _tensor_validate(value: Any) -> torch.Tensor:
+    if isinstance(value, torch.Tensor):
+        return value
+    raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
+
+
+_TorchTensor = Annotated[
+    torch.Tensor,
+    PlainValidator(_tensor_validate),
+]
+
+
+class ParameterMeta(BaseModel):
    name: str
-    dtype:
-    shape:
+    dtype: _TorchDtype
+    shape: _TorchSize
 
 
 class BucketRange(NamedTuple):
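The `_TorchDtype`/`_TorchSize`/`_TorchTensor` annotations introduced here are the core of this release's pydantic rework: they let `ParameterMeta` and `MemoryBuffer` hold real torch objects while still validating from and serializing to plain JSON. A minimal sketch of the round trip (assumes pydantic v2; `Meta` is a stand-in for the package's `ParameterMeta`, not its actual class):

```python
from typing import Annotated, Any

import torch
from pydantic import BaseModel, PlainSerializer, PlainValidator

def _dt_validate(value: Any) -> torch.dtype:
    # accept "torch.bfloat16"-style strings as well as torch.dtype instances
    if isinstance(value, str):
        value = getattr(torch, value.split(".")[1])
    assert isinstance(value, torch.dtype)
    return value

TorchDtype = Annotated[
    torch.dtype,
    PlainValidator(_dt_validate),                        # string -> torch.dtype on input
    PlainSerializer(lambda x: str(x), return_type=str),  # torch.dtype -> "torch.bfloat16" on output
]

class Meta(BaseModel):
    name: str
    dtype: TorchDtype

m = Meta(name="w", dtype="torch.bfloat16")
print(m.model_dump_json())  # {"name":"w","dtype":"torch.bfloat16"}
print(Meta.model_validate_json(m.model_dump_json()).dtype is torch.bfloat16)  # True
```

Because `PlainValidator` replaces pydantic's own schema generation for the annotated type, no `arbitrary_types_allowed` configuration is needed, which is presumably what the removed lines above used to provide.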
@@ -68,9 +112,7 @@ class MemoryBufferMetas(BaseModel):
 
 
 class MemoryBuffer(BaseModel):
-
-
-    buffer: torch.Tensor
+    buffer: _TorchTensor
    size: int
    metas: list[ParameterMeta]
 
@@ -82,7 +124,7 @@ class MemoryBufferMetaList(BaseModel):
 
 class DataToGather(MemoryBufferMetaList):
    host_ip: str
-
+    device_uuid: str
 
 
 # 256 bytes alignment when flatten torch tensors to uint8 buffer
@@ -93,7 +135,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
 
 
-def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
+def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
    ret = []
    for meta in metas:
        size = _align_size(meta.dtype, meta.shape)
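As the comment above notes, `_ALIGN_SIZE` is 256: `_align_size` rounds every flattened tensor up to the next 256-byte boundary so each tensor starts at an aligned offset inside the shared uint8 buffer. A worked example with made-up shapes (the function body is reproduced from the context line above):

```python
import torch

_ALIGN_SIZE = 256

def align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    # round nbytes up to the next multiple of 256
    return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE

print(align_size(torch.bfloat16, torch.Size((3, 5))))       # 30 bytes -> padded to 256
print(align_size(torch.float32, torch.Size((1024, 1024))))  # exactly 4 MiB, already aligned
```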
@@ -110,10 +152,10 @@ def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
 
 
 def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
-    def _safetensors_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
+    def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
        ret = {}
        with safe_open(fn, framework="pt") as f:
-            for name in f.keys():
+            for name in f.keys():  # noqa: SIM118
                weight = f.get_tensor(name)
                meta = {
                    "key": name,
@@ -126,10 +168,10 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
        return ret
 
    # deprecated, will be removed in the future
-    def _fast_np_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
+    def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
        """load *.np file and return memmap and related tensor meta"""
 
-        def parse_npy_header(fin):
+        def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
            start = fin.tell()
            major, minor = np.lib.format.read_magic(fin)
            if major == 1 and minor == 0:
@@ -137,7 +179,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
            elif major == 2 and minor == 0:
                read_header_fn = np.lib.format.read_array_header_2_0
            else:
-                raise ValueError(
+                raise ValueError(
+                    f"unknown version {major}.{minor} when parsing npy header from {fn}"
+                )
            shape, is_fortran, dtype = read_header_fn(fin)
            return {
                "shape": shape,
@@ -193,7 +237,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
    return tp_rank, ret
 
 
-def _concat_tp_weights(tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int):
+def _concat_tp_weights(
+    tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
+) -> torch.Tensor:
    """Concat tp weights with meta info.
    If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights.
    Else we will cat weights in concat_dim.
@@ -206,39 +252,54 @@ def _concat_tp_weights(tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_si
    return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
 
 
-def _get_physical_gpu_id(
-
-
-
-
-    for line in lines:
-        if f"GPU {rank}" in line:
-            uuid = line.split("UUID: ")[1].strip(")")
-            return uuid
-    raise ValueError(f"not found gpu{rank} uuid")
+def _get_physical_gpu_id(device_index: int | None = None) -> str:
+    try:
+        return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
+    except AssertionError as e:
+        raise ValueError(f"fail to get physical gpu id {device_index}") from e
 
 
 @lru_cache(maxsize=1)
-def _get_ip():
+def _get_ip() -> str:
    try:
        # try to get ip from network interface
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
-    except:
+    except Exception as e:  # noqa: BLE001
        # fallback to get ip from hostname
-        logger.warning(
+        logger.warning(
+            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
+        )
        return socket.gethostbyname(socket.gethostname())
 
 
+def _ibv_get_device_list() -> list[str]:
+    lib = ctypes.CDLL("libibverbs.so.1")
+    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
+    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **
+
+    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
+    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
+    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *
+
+    num = ctypes.c_int()
+    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
+    if not dev_array or num.value <= 0:
+        return []
+
+    devices = []
+    for i in range(num.value):
+        dev_ptr = dev_array[i]  # struct ibv_device *
+        name = lib.ibv_get_device_name(dev_ptr)  # const char *
+        devices.append(name.decode())
+    lib.ibv_free_device_list(dev_array)
+    return devices
+
+
 def _get_rdma_devices() -> list[str]:
    """
-    use
-    ```bash
-    pushd /sys/class/infiniband/ > /dev/null;
-    for i in mlx5_*; do cat "$i"/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo "$i" ; done;
-    popd > /dev/null;
-    ```
+    use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
    """
    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
    if devices_str:
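This release drops the `nvidia-smi` subprocess parsing in favor of `torch.cuda.get_device_properties(i).uuid`, and queries RDMA NICs through `libibverbs` with ctypes instead of walking `/sys/class/infiniband`. A hedged sketch of the new GPU-identifier call (needs a CUDA build of torch and a visible GPU; `physical_gpu_id` is my name for the helper, not the package's):

```python
import torch

def physical_gpu_id(device_index: int | None = None) -> str:
    # torch exposes the device UUID directly; "GPU-<uuid>" matches `nvidia-smi -L` output
    return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"

if torch.cuda.is_available():
    print(physical_gpu_id(0))  # e.g. GPU-1f3c0a2e-....
```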
@@ -246,41 +307,27 @@ def _get_rdma_devices() -> list[str]:
    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
    hca = os.getenv("NCCL_IB_HCA", None)
    if hca:
-
-        if len(
+        hca_list = hca.split(",")
+        if len(hca_list) > 1:
            # if NCCL_IB_HCA has multiple values, just return
-            return
+            return hca_list
        else:
-            hca =
-
-    port_path = "ports/1/gid_attrs/types"
-    devices = []
-    for device in sorted(os.listdir(basepath)):
-        if hca is not None and hca not in device:
-            continue
-        path = os.path.join(basepath, device, port_path)
-        if not os.path.exists(path) or not os.path.isdir(path):
-            continue
-        for port in os.listdir(path):
-            try:
-                content = open(os.path.join(path, port)).read()
-                if "v" in content:
-                    print(f"found rdma device {device} in port {port}: {content.strip()}")
-                    devices.append(device)
-                    break
-            except Exception:
-                pass
-    return devices
+            hca = hca_list[0]
+    return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
 
 
-def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]):
+def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    """
    implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
    """
    if not devices:
        raise RuntimeError("no rdma devices found")
-    assert len(devices) <= gpu_count,
-
+    assert len(devices) <= gpu_count, (
+        f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
+    )
+    assert gpu_count % len(devices) == 0, (
+        f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
+    )
    return devices[local_rank // (gpu_count // len(devices))]
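`_get_my_rdma_device` maps GPUs to NICs by plain integer division, matching the docstring: with two NICs and eight GPUs, local ranks 0-3 share the first NIC and 4-7 the second. A quick standalone check of the arithmetic:

```python
def my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    # each NIC serves a contiguous block of gpu_count // len(devices) local ranks
    return devices[local_rank // (gpu_count // len(devices))]

print([my_rdma_device(r, 8, ["mlx5_0", "mlx5_1"]) for r in range(8)])
# ['mlx5_0', 'mlx5_0', 'mlx5_0', 'mlx5_0', 'mlx5_1', 'mlx5_1', 'mlx5_1', 'mlx5_1']
```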
@@ -305,8 +352,12 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
                size=1,
            )
        if parameter_name not in parameter_metas:
-            assert isinstance(meta["dtype"], torch.dtype),
-
+            assert isinstance(meta["dtype"], torch.dtype), (
+                f"meta {meta} dtype should be torch.dtype"
+            )
+            assert isinstance(meta["shape"], torch.Size), (
+                f"meta {meta} shape should be torch.Size"
+            )
            parameter_metas[parameter_name] = ParameterMeta(
                name=parameter_name,
                shape=meta["shape"],
@@ -319,7 +370,9 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
        if tp_meta.concat_dim != -1:
            shape = list(parameter_metas[name].shape)
            shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
-            parameter_metas[name] = ParameterMeta(
+            parameter_metas[name] = ParameterMeta(
+                name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
+            )
        weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
        # TODO: here concat is serial, which may be slow
        # but since tp storage is not used in the future
@@ -338,17 +391,19 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
 
 def _register_checkpoint(
    *,
-    files: list[str]
-    named_tensors: dict[str, torch.Tensor]
+    files: list[str],
+    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
 ) -> list[MemoryBuffer]:
-    logger.info(
+    logger.info(
+        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+    )
    if not files and not named_tensors:
        return []
    parameters = _load_checkpoint(files)
    if named_tensors:
        parameters.update(named_tensors)
-    bucket_size = max(4 << 30, max(
+    bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
 
    class MemoryBucket(BaseModel):
        size: int
@@ -363,7 +418,10 @@ def _register_checkpoint(
        buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
        buckets[-1].size += size
 
-    memory_buffers = [
+    memory_buffers = [
+        MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
+        for bucket in buckets
+    ]
 
    def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
        buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
@@ -373,7 +431,10 @@ def _register_checkpoint(
            buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
 
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-        futures = [
+        futures = [
+            executor.submit(register_pin_memory, idx, bucket.size)
+            for idx, bucket in enumerate(buckets)
+        ]
        new_futures = []
        for future in concurrent.futures.as_completed(futures):
            idx, buffer = future.result()
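Registration flattens each tensor's bytes into a pinned uint8 bucket via `tensor.view(-1).view(dtype=torch.uint8)`; the recorded metas (name, dtype, shape, aligned offset) are all that is needed to view slices of that buffer back as typed tensors later. A toy round trip of the same trick (no pinned memory required for the demo):

```python
import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
buf = torch.empty(t.nbytes, dtype=torch.uint8)

# flatten: reinterpret the tensor's storage as raw bytes and copy into the bucket
buf[: t.nbytes] = t.view(-1).view(dtype=torch.uint8)

# restore: view the byte slice back through the recorded dtype and shape
back = buf[: t.nbytes].view(dtype=torch.float32).view(2, 3)
print(torch.equal(t, back))  # True
```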
@@ -400,8 +461,26 @@ def _register_checkpoint(
    return memory_buffers
 
 
-def request_inference_to_update(
-
+def request_inference_to_update(
+    url: str,
+    socket_paths: dict[str, str],
+    timeout: float = 300.0,
+    uds: str | None = None,
+):
+    """Send an inference update request to inference server via HTTP or Unix socket.
+
+    Args:
+        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
+        socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
+        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
+        uds (str, optional): Path to a Unix domain socket. If provided, the request
+            will be sent via the Unix socket instead of HTTP. Defaults to None.
+
+    Raises:
+        httpx.HTTPStatusError: If the response contains an HTTP error status.
+        httpx.RequestError: If there was an issue while making the request.
+    """
+    resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
        url,
        json={
            "method": "update_weights_from_ipc",
@@ -413,7 +492,9 @@ def request_inference_to_update(url: str, socket_paths: dict[str, str], timeout:
    resp.raise_for_status()
 
 
-def _gen_h2d_buckets(global_metas: dict[int, MemoryBufferMetaList], bucket_size: int) -> list[tuple[int, H2DBucket]]:
+def _gen_h2d_buckets(
+    global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
+) -> list[tuple[int, H2DBucket]]:
    buckets: list[tuple[int, H2DBucket]] = []
 
    for owner_rank, items in global_metas.items():
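The move from `requests` to `httpx` is what enables the new `uds` parameter: `httpx.HTTPTransport(uds=...)` routes the same POST over a Unix domain socket when a path is given, and over TCP otherwise. A sketch of the call shape (the JSON body beyond `method` is abridged in the diff, so treat it as illustrative):

```python
import httpx

def post_update(url: str, socket_paths: dict[str, str], uds: str | None = None) -> None:
    # uds=None falls back to a normal TCP connection
    client = httpx.Client(transport=httpx.HTTPTransport(uds=uds))
    resp = client.post(
        url,
        json={
            "method": "update_weights_from_ipc",
            "socket_paths": socket_paths,  # device uuid -> zmq ipc path
        },
        timeout=300.0,
    )
    resp.raise_for_status()

# post_update("http://localhost:19730/inference", paths)                    # over HTTP
# post_update("http://localhost/inference", paths, uds="/tmp/engine.sock")  # over UDS
```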
@@ -424,14 +505,18 @@ def _gen_h2d_buckets(global_metas: dict[int, MemoryBufferMetaList], bucket_size:
            s = _align_size(meta.dtype, meta.shape)
            if buckets[-1][1].size + s > bucket_size:
                if offset - start_offset > 0:
-                    buckets[-1][1].ranges.append(
+                    buckets[-1][1].ranges.append(
+                        BucketRange(idx, start_offset, offset - start_offset)
+                    )
                    start_offset = offset
                buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
            offset += s
            buckets[-1][1].size += s
            buckets[-1][1].items.append(meta)
        buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
-    assert buckets[-1][1].size > 0,
+    assert buckets[-1][1].size > 0, (
+        f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
+    )
    return buckets
@@ -470,7 +555,9 @@ class P2PStore:
            raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
        self.port = self.engine.get_rpc_port()
        self.named_tensors: dict[str, torch.Tensor] = {}
-        logger.info(
+        logger.info(
+            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
+        )
 
    @property
    def addr(self) -> str:
@@ -492,16 +579,24 @@ class P2PStore:
        num_unregistered = 0
        for i, name in enumerate(names):
            del self.named_tensors[name]
-            logger.info(
+            logger.info(
+                f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
+            )
            num_unregistered += 1
        return num_unregistered
 
-    def batch_transfer_sync_read(
-
+    def batch_transfer_sync_read(
+        self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
+    ):
+        assert (
+            self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
+        )
 
 
 class ParameterServer:
-    def __init__(
+    def __init__(
+        self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
+    ):
        """
        Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
 
@@ -509,20 +604,19 @@ class ParameterServer:
            auto_pg: Whether to automatically initialize the process group.
            Notice that if auto_pg is True, will destroy the process group after update.
        """
-        self._rank = int(os.environ.get("RANK", None))
-        self._world_size = int(os.environ.get("WORLD_SIZE", None))
-        self._master_addr = os.getenv("MASTER_ADDR")
+        self._rank = rank or int(os.environ.get("RANK", None))
+        self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
        self._gpu_count = torch.cuda.device_count()
        self._local_rank = self._rank % self._gpu_count
        self._auto_pg = auto_pg
        self._all_hosts = []
-        self.
+        self._global_device_uuids: list[str] = []
 
        assert self._rank is not None and self._rank >= 0, self._rank
        assert self._world_size and self._world_size > 0, self._world_size
 
-        self._device_uuid = _get_physical_gpu_id(self._local_rank)
        self._zmq_ctx = zmq.Context()
+        self._zmq_addr_counter = 0
 
        self._memory_pool: dict[str, list[MemoryBuffer]] = {}
        # dict key is owner_rank, value is a bucket metas list in owner_rank
@@ -533,19 +627,27 @@ class ParameterServer:
            logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
            self._p2p_store = None
 
-
+        device_index = self._local_rank
+        torch.cuda.set_device(device_index)
+        self._device_uuid = _get_physical_gpu_id(device_index)
 
-    def _logger_rank0(self, msg):
+    def _logger_rank0(self, msg: str):
        if self._local_rank == 0:
            logger.info(msg)
 
-    def get_metas(self):
+    def get_metas(self) -> dict[int, MemoryBufferMetaList]:
        return self._current_global_parameter_metas
 
    def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
        self._current_global_parameter_metas = metas
 
-    def register_checkpoint(
+    def register_checkpoint(
+        self,
+        checkpoint_name: str,
+        *,
+        files: list[str] | None = None,
+        named_tensors: dict[str, torch.Tensor] | None = None,
+    ) -> None:
        """
        Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
@@ -555,12 +657,18 @@ class ParameterServer:
            named_tensors: The named tensors to register.
        """
        try:
-            assert checkpoint_name not in self._memory_pool,
-
+            assert checkpoint_name not in self._memory_pool, (
+                f"checkpoint {checkpoint_name} already registered"
+            )
+            self._memory_pool[checkpoint_name] = _register_checkpoint(
+                files=files or [], named_tensors=named_tensors or {}, rank=self._rank
+            )
            if self._p2p_store is not None:
                self._register_parameters_to_p2p_store(checkpoint_name)
        except Exception:
-            logger.exception(
+            logger.exception(
+                f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
+            )
            if self._p2p_store is not None:
                self._unregister_parameters_from_p2p_store(checkpoint_name)
            self.unregister_checkpoint(checkpoint_name)
@@ -583,10 +691,6 @@ class ParameterServer:
        # this works by using torch>=2.5.0
        torch._C._host_emptyCache()
 
-    @cached_property
-    def _zmq_socket_path(self) -> str:
-        return f"ipc://@checkpoint-engine-{uuid.uuid4()}.sock"
-
    def gather_metas(self, checkpoint_name: str):
        """
        Gather the parameter metas from all ranks. This will gather memory_buffer, and other metadatas.
@@ -598,19 +702,17 @@ class ParameterServer:
        assert dist.is_initialized(), "process group is not initialized"
        metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
        metas = DataToGather(
-            memory_buffer_metas_list=
-
-
-
-
-
-
-            ),
-            ),
+            memory_buffer_metas_list=[
+                MemoryBufferMetas(
+                    metas=x.metas,
+                    ptr=x.buffer.data_ptr(),
+                    size=x.size,
+                )
+                for x in self._memory_pool.get(checkpoint_name, [])
+            ],
            p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
            host_ip=_get_ip(),
-
+            device_uuid=self._device_uuid,
        )
 
        dist.all_gather_object(metas_lst, metas)
@@ -618,23 +720,31 @@ class ParameterServer:
        self._current_global_parameter_metas = {}
        num_parameters = 0
        all_hosts: list[str] = []
-
+        global_device_uuids: list[str] = []
        for i, metas_buckets in enumerate(metas_lst):
            assert metas_buckets is not None, f"metas_buckets {i} should not be None"
            if i % self._gpu_count == 0 and not self._all_hosts:
                all_hosts.append(metas_buckets.host_ip)
-            if not self.
-
+            if not self._global_device_uuids:
+                global_device_uuids.append(metas_buckets.device_uuid)
            if metas_buckets.memory_buffer_metas_list:
                self._current_global_parameter_metas[i] = metas_buckets
-                num_parameters += sum(
+                num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
        if not self._all_hosts:
            self._all_hosts = all_hosts
-        if not self.
-            self.
-        logger.info(
+        if not self._global_device_uuids:
+            self._global_device_uuids = global_device_uuids
+        logger.info(
+            f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
+        )
 
-    def init_process_group(
+    def init_process_group(
+        self,
+        *,
+        master_addr: str | None = None,
+        master_port: int | None = None,
+        timeout: timedelta = timedelta(minutes=10),
+    ):
        """
        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
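`gather_metas` leans on `dist.all_gather_object`, which pickles one arbitrary Python object per rank into a pre-sized list so every rank ends up with all ranks' metadata. A single-process illustration on the gloo backend (real use is multi-rank over NCCL; the dict is a stand-in for the `DataToGather` model):

```python
import os

import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

metas = {"host_ip": "10.0.0.1", "device_uuid": "GPU-demo"}
gathered: list[dict | None] = [None] * dist.get_world_size()
dist.all_gather_object(gathered, metas)  # every rank receives all ranks' objects
print(gathered)

dist.destroy_process_group()
```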
@@ -642,10 +752,22 @@ class ParameterServer:
            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
            timeout: The timeout of the process group.
        """
+        master_addr = master_addr or os.getenv("MASTER_ADDR")
+        assert master_addr, "master_addr is required"
        store = dist.TCPStore(
-
+            master_addr,
+            _get_master_port(master_port),
+            self._world_size,
+            timeout=timeout,
+            is_master=self._rank == 0,
+        )
+        dist.init_process_group(
+            backend="nccl",
+            world_size=self._world_size,
+            rank=self._rank,
+            timeout=timeout,
+            store=store,
        )
-        dist.init_process_group(backend="nccl", world_size=self._world_size, rank=self._rank, timeout=timeout, store=store)
        logger.info(f"[rank{self._rank}] init process group successfully.")
 
    def update(
@@ -653,8 +775,8 @@ class ParameterServer:
        checkpoint_name: str,
        req_func: Callable[[list[tuple[str, str]]], None],
        *,
-        ranks: list[int] =
-    ):
+        ranks: list[int] | None = None,
+    ) -> None:
        """
        Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -667,6 +789,7 @@ class ParameterServer:
        which is useful in disaggregated architecture.
        """
        try:
+            # if both ranks is None or [], it will use fully broadcast to update to all ranks
            if not ranks:
                if self._auto_pg and not dist.is_initialized():
                    self.init_process_group()
@@ -692,21 +815,39 @@ class ParameterServer:
                f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
            )
        except Exception as e:
-            logger.exception(
-
+            logger.exception(
+                f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
+            )
+            raise
+
+    def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
+        def zmq_handle(device_uuid: str) -> str:
+            return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"
 
-
-
+        socket_paths = [(uid, zmq_handle(uid)) for uid in self._global_device_uuids]
+        socket = self._zmq_ctx.socket(zmq.REQ)
+        socket.bind(zmq_handle(self._device_uuid))
+        self._zmq_addr_counter += 1
+        return socket, socket_paths
+
+    def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
+        GiB = 1 << 30  # noqa: N806
        # auto detect bucket size
-
-
+        tensor = torch.tensor(
+            [
+                # 90% of current cuda free memory bytes
+                int(float(torch.cuda.mem_get_info()[0]) * 0.9),
+                # we use negative value to reuse allreduce min operation
+                # for getting the max value of zmq_addr_counter in all ranks
+                -self._zmq_addr_counter,
+            ],
            dtype=torch.int64,
            device="cuda",
        )
-        dist.all_reduce(
-
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        tensor = tensor.cpu()
+        free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
        max_tensor_bytes = 0
-        max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB_bytes
        for items in self._current_global_parameter_metas.values():
            for metas_list in items.memory_buffer_metas_list:
                for meta in metas_list.metas:
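The comment in `_detect_bucket_size` is worth unpacking: a single `all_reduce` with `ReduceOp.MIN` serves two reductions at once, because negating `zmq_addr_counter` turns a global MIN into a global MAX. The same trick in scalar form, with invented per-rank values:

```python
# per-rank pairs of (free_bytes, zmq_addr_counter); numbers invented for the demo
ranks = [(96, 3), (80, 5), (120, 4)]

# the element-wise MIN that all_reduce(ReduceOp.MIN) would compute over [free, -counter]
reduced = [min(col) for col in zip(*[(free, -ctr) for free, ctr in ranks])]
free_bytes, counter = reduced[0], -reduced[1]
print(free_bytes, counter)  # 80 5 -> min of free memory, max of the counter
```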
@@ -729,18 +870,27 @@ class ParameterServer:
                    f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
                )
            disable_h2d_buffer = True
+        max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
        bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
-        logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size /
+        logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
        return bucket_size, disable_h2d_buffer
 
-    def _copy_to_buffer(
+    def _copy_to_buffer(
+        self,
+        checkpoint_name: str,
+        bucket: H2DBucket,
+        buffer: torch.Tensor,
+        owner_rank: int | None = None,
+    ):
        offset = 0
        if owner_rank is not None:
            buf_ptrs, remote_ptrs, lens = [], [], []
            ptr_base = buffer.data_ptr()
            target_addr, ptrs = self._get_addr_ptrs(owner_rank)
        for b in bucket.ranges:
-            assert offset + b.size <= bucket.size,
+            assert offset + b.size <= bucket.size, (
+                f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
+            )
            if owner_rank is not None:
                buf_ptrs.append(ptr_base + offset)
                remote_ptrs.append(ptrs[b.idx][0] + b.offset)
@@ -758,7 +908,11 @@ class ParameterServer:
        torch.cuda.synchronize()
 
    def init_process_group_for_ranks(
-        self,
+        self,
+        ranks: list[int],
+        *,
+        master_port: int | None = None,
+        timeout: timedelta = timedelta(minutes=10),
    ):
        """
        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
@@ -787,8 +941,12 @@ class ParameterServer:
        # and will not participate in this update. Since they have registered memory addresses
        # to p2p_store at the beginning, update ranks can directly get the memory addresses
        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-
+        store = dist.TCPStore(
+            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
+        )
+        dist.init_process_group(
+            backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
+        )
 
    def _update_per_bucket_p2p(
        self,
@@ -800,7 +958,9 @@ class ParameterServer:
        assert ranks, "ranks should be set"
        if len(self._current_global_parameter_metas) == 0:
            raise ValueError("parameter metas is empty")
-        assert dist.is_initialized(),
+        assert dist.is_initialized(), (
+            "process group is not initialized when update model per bucket p2p"
+        )
 
        need_update = self._rank in ranks
        logger.info(
@@ -814,26 +974,24 @@ class ParameterServer:
        # first execute a barrier to avoid subsequent cuda oom
        dist.barrier()
 
-        bucket_size, _ = self.
+        bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
-
-        self._p2p_store.register_named_tensors({
+        ipc_buffer_name = "__ipc_buffer___"
+        self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
        logger.info(
            f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
        )
        handle = reduce_tensor(buffer)
 
-        gidx = 0
        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
+        socket, socket_paths = self._bind_zmq_socket()
        req_thread = threading.Thread(
            target=req_func,
-            args=(
+            args=(socket_paths,),
        )
        req_thread.start()
-        socket = self._zmq_ctx.socket(zmq.REQ)
-        socket.bind(self._zmq_socket_path)
        socket.send_pyobj(handle)
-        for owner_rank, bucket in buckets:
+        for gidx, (owner_rank, bucket) in enumerate(buckets):
            self._logger_rank0(
                f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
            )
@@ -845,7 +1003,6 @@ class ParameterServer:
            socket.recv()
            dist.barrier()
            socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
-            gidx += 1
 
        socket.recv()
        socket.send_pyobj(None)
@@ -853,7 +1010,7 @@ class ParameterServer:
        req_thread.join()
        dist.barrier()
        socket.close()
-        self._p2p_store.unregister_named_tensors([
+        self._p2p_store.unregister_named_tensors([ipc_buffer_name])
        torch.cuda.empty_cache()
 
    def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
@@ -877,7 +1034,9 @@ class ParameterServer:
        pool = self._memory_pool[checkpoint_name]
        if len(pool) == 0:
            return 0
-        return self._p2p_store.unregister_named_tensors(
+        return self._p2p_store.unregister_named_tensors(
+            [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+        )
 
    def _update_per_bucket(
        self,
@@ -891,11 +1050,13 @@ class ParameterServer:
 
        logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
 
-        bucket_size, disable_h2d_buffer = self.
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
 
        h2d_buffer: torch.Tensor | None = (
-            None
+            None
+            if disable_h2d_buffer
+            else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
        )
 
        owner_rank_buckets: list[H2DBucket] = []
@@ -914,13 +1075,12 @@ class ParameterServer:
            if len(buckets_by_owner_rank[owner_rank]) > max_len:
                max_len = len(buckets_by_owner_rank[owner_rank])
 
+        socket, socket_paths = self._bind_zmq_socket()
        req_thread = threading.Thread(
            target=req_func,
-            args=(
+            args=(socket_paths,),
        )
        req_thread.start()
-        socket = self._zmq_ctx.socket(zmq.REQ)
-        socket.bind(self._zmq_socket_path)
        socket.send_pyobj(handle)
 
        gidx = 0
@@ -932,7 +1092,10 @@ class ParameterServer:
                if i >= len(_buckets):
                    continue
                bucket = _buckets[i]
-                alloc, reserved =
+                alloc, reserved = (
+                    torch.cuda.memory_allocated() / 1024 / 1024,
+                    torch.cuda.memory_reserved() / 1024 / 1024,
+                )
                self._logger_rank0(
                    f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
                    f"Current CUDA allocated {alloc:.2f} MB, "
@@ -960,7 +1123,7 @@ class ParameterServer:
        torch.cuda.empty_cache()
 
 
-def _init_api(ps: ParameterServer):
+def _init_api(ps: ParameterServer) -> Any:
    import fastapi
    from fastapi import Request
    from fastapi.responses import JSONResponse, Response
@@ -976,32 +1139,32 @@ def _init_api(ps: ParameterServer):
    inference_group_ranks: list[int] = []
    timeout: float = 300.0
 
-    def wrap_exception(func):
+    def wrap_exception(func: Callable[[], None]) -> Response:
        try:
            func()
-        except Exception as e:
+        except Exception as e:  # noqa: BLE001
            logger.exception(f"wrap exception {func} failed")
            return JSONResponse(content=str(e), status_code=500)
        return Response(status_code=200)
 
    @app.post("/v1/checkpoints/{checkpoint_name}/files")
-    async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request):
+    async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
        return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
 
    @app.delete("/v1/checkpoints/{checkpoint_name}")
-    async def unregister_checkpoint(checkpoint_name: str):
+    async def unregister_checkpoint(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
 
    @app.get("/v1/healthz")
-    async def healthz():
+    async def healthz() -> Response:
        return Response(status_code=200)
 
    @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
-    async def gather_metas(checkpoint_name: str):
+    async def gather_metas(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
 
    @app.post("/v1/checkpoints/{checkpoint_name}/update")
-    async def update(checkpoint_name: str, req: UpdateRequest):
+    async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
        def update_func(socket_paths: list[tuple[str, str]]):
            if req.update_url is None:
                return
@@ -1018,11 +1181,13 @@ def _init_api(ps: ParameterServer):
 def run_from_cli():
    import uvicorn
 
-    parser = argparse.ArgumentParser(description="
+    parser = argparse.ArgumentParser(description="Parameter Server")
    parser.add_argument("--uds", type=str)
 
    args = parser.parse_args()
-    logger.info(
+    logger.info(
+        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
+    )
 
    assert args.uds and len(args.uds) > 0, args.uds
    ps = ParameterServer(auto_pg=True)
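With `run_from_cli` serving this app over `--uds`, a controller can drive a full weight update with three POSTs per checkpoint. A hedged client sketch (the route paths come from the decorators above; the socket path and request bodies are illustrative — `RegisterRequest` takes `files`, and `UpdateRequest`'s visible fields all have defaults):

```python
import httpx

# talk to the parameter server over its --uds socket; the host in base_url is ignored
client = httpx.Client(
    transport=httpx.HTTPTransport(uds="/tmp/ps.sock"),  # hypothetical path
    base_url="http://checkpoint-engine",
)

name = "my-ckpt"
client.post(f"/v1/checkpoints/{name}/files",
            json={"files": ["/data/model.safetensors"]}).raise_for_status()
client.post(f"/v1/checkpoints/{name}/gather-metas").raise_for_status()
# fields default per UpdateRequest; set update_url for the server to call back
client.post(f"/v1/checkpoints/{name}/update", json={}).raise_for_status()
```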
checkpoint_engine/worker.py
CHANGED
@@ -1,11 +1,12 @@
 import gc
-from typing import Callable, TypedDict
+from collections.abc import Callable
+from typing import TypedDict
 
 import torch
 import zmq
 
 
-def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int = None) -> torch.Tensor:
+def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
    func, args = handle
    list_args = list(args)
    if device_id is not None:
@@ -24,12 +25,14 @@ class FlattenedTensorMetadata(TypedDict):
    offset: int
 
 
-def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> list[tuple[str, torch.Tensor]]:
+def _extract_weights(
+    payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
+) -> list[tuple[str, torch.Tensor]]:
    assert buffer is not None
    weights: list[tuple[str, torch.Tensor]] = []
    for item in payload:
        shape = item["shape"]
-        if isinstance(shape, (list, tuple)):
+        if isinstance(shape, list | tuple):
            shape = torch.Size(shape)
        assert isinstance(shape, torch.Size)
        dtype, offset = item["dtype"], item["offset"]
@@ -45,11 +48,11 @@ def update_weights_from_ipc(
    device_id: int,
    *,
    run: Callable[[list[tuple[str, torch.Tensor]]], None],
-    post_hook: Callable[[], None] = None,
+    post_hook: Callable[[], None] | None = None,
 ):
    socket = zmq_ctx.socket(zmq.REP)
    socket.connect(zmq_handle)
-    buffer: torch.Tensor = None
+    buffer: torch.Tensor | None = None
    while True:
        payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
        if payload is None:
@@ -100,5 +103,7 @@ class VllmColocateWorkerExtension:
            zmq_handles[device_uuid],
            device_id=self.device.index,
            run=self.model_runner.model.load_weights,
-            post_hook=lambda: process_weights_after_loading(
+            post_hook=lambda: process_weights_after_loading(
+                self.model_runner.model, self.model_config, self.device
+            ),
        )
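Reading `ps.py` and `worker.py` together, the update handshake over the ipc socket is: the server REQ-sends the CUDA IPC handle from `reduce_tensor`, then one flattened-metadata list per bucket, then `None` as a sentinel; the worker REP-acks each message. A same-process sketch of just the framing (plain objects stand in for tensors, so it runs without a GPU; the `ipc://@...` abstract endpoint is Linux-only):

```python
import threading

import zmq

ctx = zmq.Context()
path = "ipc://@checkpoint-engine-demo.sock"

def worker() -> None:
    s = ctx.socket(zmq.REP)
    s.connect(path)
    handle = s.recv_pyobj()      # 1. IPC handle (reduce_tensor output in real code)
    assert handle is not None
    while True:
        s.send(b"")              # ack, then wait for the next bucket
        payload = s.recv_pyobj()
        if payload is None:      # 3. sentinel: update finished
            s.send(b"")
            return
        print("bucket metas:", payload)  # 2. _to_named_tensor output in real code

t = threading.Thread(target=worker)
t.start()
req = ctx.socket(zmq.REQ)
req.bind(path)
req.send_pyobj("fake-ipc-handle")
req.recv()
req.send_pyobj([{"name": "w", "offset": 0}])
req.recv()
req.send_pyobj(None)
req.recv()
t.join()
```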
{checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.1.1
+Version: 0.1.2
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -11,14 +11,15 @@ Description-Content-Type: text/markdown
 License-File: LICENCE
 Requires-Dist: torch>=2.5.0
 Requires-Dist: fastapi
-Requires-Dist: pydantic
+Requires-Dist: pydantic>=2.0.0
 Requires-Dist: safetensors
 Requires-Dist: pyzmq
 Requires-Dist: uvicorn
 Requires-Dist: loguru
 Requires-Dist: numpy
+Requires-Dist: httpx
 Provides-Extra: p2p
-Requires-Dist: mooncake-transfer-engine; extra == "p2p"
+Requires-Dist: mooncake-transfer-engine>=0.3.5; extra == "p2p"
 Dynamic: license-file
 
 # Checkpoint Engine
@@ -41,7 +42,7 @@ The core weight update logic is in `ParameterServer` class, a service colocated
 - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
 
 ### Optimized Weight Broadcast
-In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern. 
+In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
 We arrange the data transfer into 3 stages:
 1. H2D: moving weights to GPU memory. These weights may come from disk or the training engine.
 2. broadcast: broadcast among checkpoint engine workers; the data results in a CUDA IPC buffer shared with inference engine.
@@ -73,9 +74,9 @@ Pipelining naturally requires more GPU memory. When memory is not enough, checkp
 All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
 
 * FP8 test needs additional vLLM patches, see [FP8 quantization](#fp8-quantization).
-* Device Info: we tested various combination of devices and
+* Device Info: we tested various combination of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
 * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
-* The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster. 
+* The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
 
 ## Installation
 
@@ -88,7 +89,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer-engine` to support RDMA transfer between different ranks.
 
 ```Bash
-pip install checkpoint-engine[p2p]
+pip install 'checkpoint-engine[p2p]'
 ```
 
 If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
@@ -107,7 +108,7 @@ VLLM_USE_PRECOMPILED=1 uv pip install --editable .
 Install checkpoint-engine
 
 ```Bash
-uv pip install checkpoint-engine[p2p]
+uv pip install 'checkpoint-engine[p2p]'
 ```
 
 We use `Qwen/Qwen3-235B-A22B-Instruct-2507` (BF16) as the test model
@@ -133,7 +134,7 @@ torchrun --nproc-per-node 8 examples/update.py --update-method all --checkpoint-
 
 ### Reuse weights from existing instances
 
-New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve. 
+New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
 
 First, start the existing instances with `--save-metas-file global_metas.pkl` to save global metas to a file and use `--sleep-time 300` to make sure they stay alive.
@@ -150,7 +151,7 @@ torchrun --nproc-per-node 8 examples/update.py --load-metas-file global_metas.pk
 
 ### FP8 quantization
 
-FP8 quantization currently do not natively work in vLLM when updating weights. 
+FP8 quantization currently do not natively work in vLLM when updating weights.
 We provide a simple patch in [`patches/vllm_fp8.patch`](./patches/vllm_fp8.patch) to handle the correct weight update.
 Notice this patch is only tested in DeepSeek-V3.1 and Kimi-K2. Other models may meet some compatible issues.
checkpoint_engine-0.1.2.dist-info/RECORD
ADDED

@@ -0,0 +1,9 @@
+checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+checkpoint_engine/_version.py,sha256=Ok5oAXdWgR9aghaFXTafTeDW6sYO3uVe6d2Nket57R4,704
+checkpoint_engine/ps.py,sha256=ckM2vdLg3aeOKmM_vTcbIPKcT-r-E4s73yPaCESKdwg,48439
+checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
+checkpoint_engine-0.1.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.1.2.dist-info/METADATA,sha256=9FUb4s1KMSzrDOOV-C18q3gqcVWD_qZ4t_DaLuai_M4,9322
+checkpoint_engine-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+checkpoint_engine-0.1.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.1.2.dist-info/RECORD,,
{checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/licenses/LICENCE
CHANGED

@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE
+SOFTWARE
checkpoint_engine-0.1.1.dist-info/RECORD
REMOVED

@@ -1,9 +0,0 @@
-checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
-checkpoint_engine/_version.py,sha256=yfx_VE-4lpqM4jnWOSq-8rihMWIwMaX9CQ7tNEpA4T0,712
-checkpoint_engine/ps.py,sha256=9u2rLOj-oQrXsnpYhdXFjv7ak2-f4BRUXB6KYlG3ah0,44422
-checkpoint_engine/worker.py,sha256=OrSeknjtECnO88I-YMdfkZj70TIRhjvTEeZkNyZk21M,3695
-checkpoint_engine-0.1.1.dist-info/licenses/LICENCE,sha256=0jqA0jrA_i9VUqd7FTVoI1KnN1ZRENwG_tlMRCjC63k,1066
-checkpoint_engine-0.1.1.dist-info/METADATA,sha256=WDc5tg3RQiCthzbEzI4D3t3I0AQhGUPcDpJn7TMhYbI,9286
-checkpoint_engine-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.1.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.1.1.dist-info/RECORD,,
{checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/WHEEL
File without changes

{checkpoint_engine-0.1.1.dist-info → checkpoint_engine-0.1.2.dist-info}/top_level.txt
File without changes