checkpoint-engine 0.1.3__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/device_utils.py +86 -0
- checkpoint_engine/ps.py +368 -192
- checkpoint_engine/worker.py +86 -30
- {checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/METADATA +70 -13
- checkpoint_engine-0.2.1.dist-info/RECORD +10 -0
- checkpoint_engine-0.1.3.dist-info/RECORD +0 -9
- {checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.1.3'
-__version_tuple__ = version_tuple = (0, 1, 3)
+__version__ = version = '0.2.1'
+__version_tuple__ = version_tuple = (0, 2, 1)

 __commit_id__ = commit_id = None
checkpoint_engine/device_utils.py
ADDED
@@ -0,0 +1,86 @@
+import os
+import re
+import socket
+import subprocess
+from functools import lru_cache
+
+import torch
+from loguru import logger
+
+
+@lru_cache(maxsize=1)
+def get_ip() -> str:
+    try:
+        # try to get ip from network interface
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+            s.connect(("8.8.8.8", 80))
+            return s.getsockname()[0]
+    except Exception as e:  # noqa: BLE001
+        # fallback to get ip from hostname
+        logger.warning(
+            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
+        )
+        return socket.gethostbyname(socket.gethostname())
+
+
+def npu_generate_uuid() -> str:
+    str_pid = str(os.getpid())
+    npu_num = 8
+    try:
+        for npu_id in range(npu_num):
+            cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
+            result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
+            str_result = str(result.stdout)
+            if str_pid in str_result:
+                # In A3 server, one NPU has two chips.
+                match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
+                chip_count = int(match_chip_count.group(1))
+                search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
+                match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
+                chip_id = int(match_chip_id.group(1))
+                return f"{get_ip()}-{npu_id * chip_count + chip_id}"
+        raise ValueError("The current process is not running on the npu device")
+    except subprocess.CalledProcessError as e:
+        raise ValueError("The current process is not running on the npu device") from e
+
+
+class DeviceManager:
+    def __init__(self):
+        self.device_type = self._detect_device_type()
+        self._setup_device_module()
+
+    def _is_torch_npu_available(self) -> bool:
+        try:
+            if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
+                return torch.npu.is_available()
+            else:
+                return False
+        except ImportError:
+            return False
+
+    def _detect_device_type(self) -> str:
+        if self._is_torch_npu_available():
+            return "npu"
+        elif torch.cuda.is_available():
+            return "cuda"
+        else:
+            raise TypeError("The current device type is not supported")
+
+    def _setup_device_module(self):
+        if self.device_type == "npu":
+            import torch_npu
+
+            self.device_module = torch_npu.npu
+        elif self.device_type == "cuda":
+            self.device_module = torch.cuda
+        else:
+            raise TypeError("The current device type is not supported")
+
+    @property
+    def backend(self) -> str:
+        if self.device_type == "npu":
+            return "hccl"
+        elif self.device_type == "cuda":
+            return "nccl"
+        else:
+            raise TypeError("The current device type is not supported")
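The new `DeviceManager` gives every other module one place to resolve the accelerator in use: the device type string, the matching `torch.cuda`/`torch_npu.npu` module, and the collective backend. A minimal usage sketch (the call sites below are illustrative, not lifted from the diff):

```python
from checkpoint_engine.device_utils import DeviceManager

dm = DeviceManager()            # raises TypeError if neither NPU nor CUDA is available
print(dm.device_type)           # "cuda" or "npu"
print(dm.backend)               # "nccl" on CUDA, "hccl" on NPU

dm.device_module.set_device(0)  # same surface as torch.cuda / torch_npu.npu
free_bytes, total_bytes = dm.device_module.mem_get_info()
print(f"free {free_bytes / 2**30:.1f} GiB of {total_bytes / 2**30:.1f} GiB")
```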
checkpoint_engine/ps.py
CHANGED
@@ -4,13 +4,11 @@ import ctypes
 import os
 import pickle
 import random
-import socket
 import threading
 import time
 from collections import defaultdict
 from collections.abc import Callable
 from datetime import timedelta
-from functools import lru_cache
 from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple

 import httpx
@@ -23,8 +21,12 @@ from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
 from safetensors.torch import safe_open
 from torch.multiprocessing.reductions import reduce_tensor

+from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
+

 if TYPE_CHECKING:
+    from typing import TypeVar
+
     from typing_extensions import TypedDict

     class FileMeta(TypedDict):
@@ -34,6 +36,8 @@ if TYPE_CHECKING:
         type: type
         tp_concat_dim: int

+    T = TypeVar("T")
+

 def _dt_validate(value: Any) -> torch.dtype:
     if isinstance(value, str):
@@ -117,6 +121,7 @@ class MemoryBuffer(BaseModel):
 class MemoryBufferMetaList(BaseModel):
     p2p_store_addr: str | None
     memory_buffer_metas_list: list[MemoryBufferMetas]
+    rdma_device: str


 class DataToGather(MemoryBufferMetaList):
@@ -249,28 +254,16 @@ def _concat_tp_weights(
     return torch.cat([w for w in tp_weights], dim=tp_concat_dim)


-def _get_physical_gpu_id(device_index: int | None = None) -> str:
+def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
     try:
-        return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
+        if device_manager.device_type == "npu":
+            return f"NPU-{npu_generate_uuid()}"
+        else:
+            return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
     except AssertionError as e:
         raise ValueError(f"fail to get physical gpu id {device_index}") from e


-@lru_cache(maxsize=1)
-def _get_ip() -> str:
-    try:
-        # try to get ip from network interface
-        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
-            s.connect(("8.8.8.8", 80))
-            return s.getsockname()[0]
-    except Exception as e:  # noqa: BLE001
-        # fallback to get ip from hostname
-        logger.warning(
-            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
-        )
-        return socket.gethostbyname(socket.gethostname())
-
-
 def _ibv_get_device_list() -> list[str]:
     lib = ctypes.CDLL("libibverbs.so.1")
     lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
@@ -303,14 +296,7 @@ def _get_rdma_devices() -> list[str]:
         return devices_str.split(",")
     # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
     hca = os.getenv("NCCL_IB_HCA", None)
-    if hca is not None:
-        hca_list = hca.split(",")
-        if len(hca_list) > 1:
-            # if NCCL_IB_HCA has multiple values, just return
-            return hca_list
-        else:
-            hca = hca_list[0]
-    return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
+    return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()


 def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
@@ -319,13 +305,90 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
     """
     if not devices:
         raise RuntimeError("no rdma devices found")
-    assert len(devices) <= gpu_count, (
-        f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
-    )
-    assert gpu_count % len(devices) == 0, (
-        f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
-    )
-    return devices[local_rank // (gpu_count // len(devices))]
+    try:
+        assert len(devices) <= gpu_count, (
+            f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
+        )
+        assert gpu_count % len(devices) == 0, (
+            f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
+        )
+        return devices[local_rank // (gpu_count // len(devices))]
+    except AssertionError:
+        logger.error(
+            "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
+            "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
+            "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
+        )
+        raise
+
+
+def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
+    """
+    The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
+    The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
+
+    The list is comma-separated; port numbers are NOT supported yet.
+    An optional prefix '^' indicates the list is an exclude list.
+    A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
+    Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
+
+    Examples:
+        - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
+        - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
+        - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
+        - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
+    """
+    max_hcas = 32
+    if not value or value.strip() == "":
+        return available_devices[:max_hcas]
+
+    value = value.strip()
+    result = []
+    is_exclude = value.startswith("^")
+    if is_exclude:
+        value = value.removeprefix("^")
+    is_exact_match = value.startswith("=")
+    if is_exact_match:
+        value = value.removeprefix("=")
+
+    device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
+
+    result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
+    if is_exclude:
+        result = [dev for dev in available_devices if dev not in result]
+    if len(result) > max_hcas:
+        result = result[:max_hcas]
+
+    logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
+
+    return result
+
+
+def _resolve_device_specs(
+    device_specs: list[str], is_exact_match: bool, available_devices: list[str]
+) -> list[str]:
+    devices = set()
+    for spec in device_specs:
+        parts = spec.split(":", 1)
+        device_name = parts[0].strip()
+        # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
+        # port = parts[1].strip() if len(parts) > 1 else None
+        base_devices = (
+            [device_name]
+            if device_name in available_devices
+            else []
+            if is_exact_match
+            else [dev for dev in available_devices if dev.startswith(device_name)]
+        )
+
+        if not base_devices:
+            logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
+            continue
+
+        for base_dev in base_devices:
+            devices.add(base_dev)
+
+    return sorted(devices)


 def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
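The docstring above lists the accepted `NCCL_IB_HCA` forms; the sketch below spells out the expected results of the private `_parse_NCCL_IB_HCA` helper on a made-up device list, derived by tracing the parser in this hunk:

```python
from checkpoint_engine.ps import _parse_NCCL_IB_HCA

available = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_bond_0"]

# default prefix matching
assert _parse_NCCL_IB_HCA("mlx5_1", available) == ["mlx5_1"]

# '=' selects exact device names
assert _parse_NCCL_IB_HCA("=mlx5_0,mlx5_2", available) == ["mlx5_0", "mlx5_2"]

# '^' excludes matching devices (prefix semantics per token)
assert _parse_NCCL_IB_HCA("^mlx5_bond", available) == ["mlx5_0", "mlx5_1", "mlx5_2"]

# an empty value falls back to every available device (capped at 32 HCAs)
assert _parse_NCCL_IB_HCA("", available) == available
```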
@@ -490,8 +553,12 @@ def request_inference_to_update(


 def _gen_h2d_buckets(
-    global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
-) -> list[tuple[int, H2DBucket]]:
+    global_metas: dict[int, MemoryBufferMetaList],
+    bucket_size: int,
+    local_topo: dict[str, set[int]],
+    remote_topo: dict[str, set[int]],
+    ranks: list[int] | None = None,
+) -> list[tuple[int, int, H2DBucket]]:
     buckets: list[tuple[int, H2DBucket]] = []

     for owner_rank, items in global_metas.items():
@@ -514,7 +581,73 @@ def _gen_h2d_buckets(
         assert buckets[-1][1].size > 0, (
             f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
         )
-    return buckets
+    ranks_set = set(ranks) if ranks else set()
+    actual_local_topo = (
+        {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
+    )
+    # if ranks is empty, assign the owner_rank as receiver_rank, this is used for colocate architecture
+    if not ranks:
+        return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
+    else:
+        return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
+
+
+def _assign_receiver_ranks(
+    buckets: list[tuple[int, "T"]],
+    local_topo: dict[str, set[int]],
+    remote_topo: dict[str, set[int]],
+) -> list[tuple[int, int, "T"]]:
+    """
+    (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
+
+    Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
+    GPU-rdma_device topology will be considered to make full use of the bandwidth.
+    """
+    if not buckets:
+        logger.warning("bucket list is empty, no need to assign receiver ranks")
+        return []
+    rank_to_rdma_device = {
+        rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
+    }
+
+    # group buckets by owner RDMA devices
+    buckets_by_rdma_device = defaultdict(list)
+    for owner_rank, bucket in buckets:
+        owner_rdma_device = rank_to_rdma_device[owner_rank]
+        buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
+
+    buckets_matrix = list(buckets_by_rdma_device.values())
+    assert buckets_matrix, "buckets_matrix should not be empty"
+
+    # Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
+    num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
+    receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
+
+    flattened_buckets = [
+        buckets_matrix[row][col]
+        for col in range(
+            max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
+        )
+        for row in range(len(buckets_matrix))
+        if col < len(buckets_matrix[row])
+    ]
+
+    buckets_with_receiver = []
+    assigned_cnt = 0
+    while assigned_cnt < len(flattened_buckets):
+        occupied_devices = set()
+        for receiver_rank in receiver_list:
+            if assigned_cnt >= len(flattened_buckets):
+                break
+            owner_rank, bucket = flattened_buckets[assigned_cnt]
+            rdma_device = rank_to_rdma_device[owner_rank]
+            if rdma_device in occupied_devices:
+                break
+            buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
+            occupied_devices.add(rdma_device)
+            assigned_cnt += 1
+
+    return buckets_with_receiver


 def _get_master_port(master_port: int | None = None) -> int:
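To make the topology-aware pairing concrete, here is a small hypothetical `_assign_receiver_ranks` scenario traced against the logic above (host IPs, NIC names, and bucket labels are invented): two sender NICs each own two buckets, and the receiver side exposes two NIC groups, so each round pairs both receiver leaders with buckets coming from different sender NICs.

```python
from checkpoint_engine.ps import _assign_receiver_ranks

# sender side: ranks 0/1 sit behind mlx5_0, ranks 4/5 behind mlx5_1 ("<nic>@<host_ip>" keys)
remote_topo = {"mlx5_0@10.0.0.1": {0, 1}, "mlx5_1@10.0.0.1": {4, 5}}
# receiver side: ranks 8/9 behind one NIC, ranks 10/11 behind the other
local_topo = {"mlx5_0@10.0.0.2": {8, 9}, "mlx5_1@10.0.0.2": {10, 11}}

buckets = [(0, "bucket-a"), (1, "bucket-b"), (4, "bucket-c"), (5, "bucket-d")]

assigned = _assign_receiver_ranks(buckets, local_topo, remote_topo)
# receivers are the minimum rank of each local NIC group (8 and 10), and no sender
# NIC is read twice within one round:
# [(8, 0, "bucket-a"), (10, 4, "bucket-c"), (8, 1, "bucket-b"), (10, 5, "bucket-d")]
```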
@@ -525,21 +658,44 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port


+def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
+    """
+    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
+    which are generated in self.init_process_group_for_ranks
+    """
+    bcast_rank_map: dict[int, int] = {}
+    if not ranks:
+        bcast_rank_map = {r: r for r in range(world_size)}
+    else:
+        for i, r in enumerate(ranks):
+            bcast_rank_map[r] = i
+    return bcast_rank_map
+
+
 class P2PStore:
-    def __init__(self):
+    def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine

         self.rank = int(os.getenv("RANK"))
-        gpu_count = torch.cuda.device_count()
+        gpu_count = device_manager.device_module.device_count()
         local_rank = self.rank % gpu_count
-        device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
-        self.ip = _get_ip()
+        device_type = device_manager.device_type
+        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
+            self.device = ""
+        else:
+            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
+        self.ip = get_ip()

         # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
         retry_count = 8
         for i in range(retry_count):
             self.engine = TransferEngine()
-            ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", device)
+            ret = self.engine.initialize(
+                self.ip,
+                "P2PHANDSHAKE",
+                "ascend_direct" if device_type == "npu" else "rdma",
+                self.device,
+            )
             if ret == 0:
                 break
             # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
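`_get_bcast_rank_map` above simply renumbers the selected receiver ranks into the compact 0..len(ranks)-1 space used by the ad-hoc process group built in `init_process_group_for_ranks`; without `ranks` it is the identity map. A quick illustration:

```python
from checkpoint_engine.ps import _get_bcast_rank_map

# p2p update restricted to ranks 8..11 of a 16-rank world
assert _get_bcast_rank_map(16, [8, 9, 10, 11]) == {8: 0, 9: 1, 10: 2, 11: 3}

# full-world broadcast: identity mapping
assert _get_bcast_rank_map(16, None) == {r: r for r in range(16)}
```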
@@ -553,7 +709,7 @@ class P2PStore:
         self.port = self.engine.get_rpc_port()
         self.named_tensors: dict[str, torch.Tensor] = {}
         logger.info(
-            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
+            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
         )

     @property
@@ -606,15 +762,18 @@ class ParameterServer:
         Args:
             auto_pg: Whether to automatically initialize the process group.
                 Notice that if auto_pg is True, will destroy the process group after update.
-            mem_fraction: The proportion (as a fraction) of the current free GPU memory for allocation.
+            mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-        self._gpu_count = gpu_count or torch.cuda.device_count()
+        self.device_manager = DeviceManager()
+        self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
         self._local_rank = self._rank % self._gpu_count
         self._auto_pg = auto_pg
         self._all_hosts = []
         self._global_device_uuids: list[str] = []
+        self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
+        self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
         self._mem_fraction = mem_fraction or 0.9

         assert self._rank is not None and self._rank >= 0, self._rank
@@ -622,7 +781,7 @@ class ParameterServer:
         assert (
             self._gpu_count is not None
             and self._gpu_count > 0
-            and self._gpu_count <= torch.cuda.device_count()
+            and self._gpu_count <= self.device_manager.device_module.device_count()
         ), self._gpu_count
         assert (
             self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -634,15 +793,17 @@ class ParameterServer:
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
+        # NPU transfer engine initialization requires prior set_device.
+        device_index = self._local_rank
+        self.device_manager.device_module.set_device(device_index)
         try:
-            self._p2p_store = P2PStore()
+            self._p2p_store = P2PStore(self.device_manager)
         except ImportError as e:
             logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
             self._p2p_store = None

-        device_index = self._local_rank
-        torch.cuda.set_device(device_index)
-        self._device_uuid = _get_physical_gpu_id(device_index)
+        self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
+        self._rdma_device = None if self._p2p_store is None else self._p2p_store.device

     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
@@ -653,6 +814,13 @@ class ParameterServer:

     def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
         self._current_global_parameter_metas = metas
+        self._remote_rdma_devices = defaultdict(set)
+        for i, meta in self._current_global_parameter_metas.items():
+            assert meta.rdma_device is not None, "meta.rdma_device should not be None"
+            assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
+            self._remote_rdma_devices[
+                meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
+            ].add(i)

     def register_checkpoint(
         self,
@@ -724,13 +892,15 @@ class ParameterServer:
                 for x in self._memory_pool.get(checkpoint_name, [])
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
-            host_ip=_get_ip(),
+            host_ip=get_ip(),
             device_uuid=self._device_uuid,
+            rdma_device=self._rdma_device or "",
         )

         dist.all_gather_object(metas_lst, metas)

         self._current_global_parameter_metas = {}
+
         num_parameters = 0
         all_hosts: list[str] = []
         global_device_uuids: list[str] = []
@@ -741,12 +911,24 @@ class ParameterServer:
             if not self._global_device_uuids:
                 global_device_uuids.append(metas_buckets.device_uuid)
             if metas_buckets.memory_buffer_metas_list:
-                self._current_global_parameter_metas[i] = metas_buckets
+                self._current_global_parameter_metas[i] = MemoryBufferMetaList(
+                    memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
+                    p2p_store_addr=metas_buckets.p2p_store_addr,
+                    rdma_device=metas_buckets.rdma_device,
+                )
                 num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
+                self._local_rdma_devices[
+                    metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
+                    if metas_buckets.p2p_store_addr
+                    else metas_buckets.host_ip
+                ].add(i)
         if not self._all_hosts:
             self._all_hosts = all_hosts
         if not self._global_device_uuids:
             self._global_device_uuids = global_device_uuids
+        # Sender node and Receiver node have the same GPU-rdma_device topology is considered as default.
+        # Rewrite the sender's topology (_remote_rdma_devices) by calling load_metas.
+        self._remote_rdma_devices = self._local_rdma_devices.copy()
         logger.info(
             f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
         )
@@ -775,7 +957,7 @@ class ParameterServer:
             is_master=self._rank == 0,
         )
         dist.init_process_group(
-            backend="nccl",
+            backend=self.device_manager.backend,
             world_size=self._world_size,
             rank=self._rank,
             timeout=timeout,
@@ -801,6 +983,7 @@ class ParameterServer:
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
         """
+        assert req_func is not None, "req_func is required"
         try:
             # if both ranks is None or [], it will use fully broadcast to update to all ranks
             if not ranks:
@@ -808,32 +991,31 @@ class ParameterServer:
                 self.init_process_group()
                 self._update_per_bucket(checkpoint_name, req_func)
             else:
-                if not self._auto_pg and self._rank not in ranks:
-                    return
                 if self._auto_pg:
                     if dist.is_initialized():
                         dist.destroy_process_group()
                         # HACK: wait 2s to ensure destroy is finished
                         time.sleep(2)
-                    if self._rank not in ranks:
-                        return
                     self.init_process_group_for_ranks(ranks)
-                self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
-            if self._auto_pg:
-                dist.destroy_process_group()
-
-            torch.cuda.empty_cache()
+                if self._rank not in ranks:
+                    return
+                self._update_per_bucket(checkpoint_name, req_func, ranks)

-            logger.info(
-                f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
-                f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
-                f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
-            )
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
+        finally:
+            if self._auto_pg and (not ranks or self._rank in ranks):
+                dist.destroy_process_group()
+
+            self.device_manager.device_module.empty_cache()
+            logger.info(
+                f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
+                f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
+                f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
+            )

     def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
         def zmq_handle(device_uuid: str) -> str:
@@ -850,14 +1032,16 @@ class ParameterServer:
         # auto detect bucket size
         tensor = torch.tensor(
             [
-                # proportion of current cuda free memory bytes
-                int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
+                # proportion of current device free memory bytes
+                int(
+                    float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
+                ),
                 # we use negative value to reuse allreduce min operation
                 # for getting the max value of zmq_addr_counter in all ranks
                 -self._zmq_addr_counter,
             ],
             dtype=torch.int64,
-            device="cuda",
+            device=self.device_manager.device_type,
         )
         dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
         tensor = tensor.cpu()
@@ -920,7 +1104,7 @@ class ParameterServer:
         assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
         if owner_rank is not None:
             self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
-        torch.cuda.synchronize()
+        self.device_manager.device_module.synchronize()

     def init_process_group_for_ranks(
         self,
@@ -960,73 +1144,12 @@ class ParameterServer:
             master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
         )
         dist.init_process_group(
-            backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
-        )
-
-    def _update_per_bucket_p2p(
-        self,
-        checkpoint_name: str,
-        req_func: Callable[[list[tuple[str, str]]], None],
-        ranks: list[int],
-    ):
-        assert self._p2p_store is not None, "p2p store is not initialized"
-        assert ranks, "ranks should be set"
-        if len(self._current_global_parameter_metas) == 0:
-            raise ValueError("parameter metas is empty")
-        assert dist.is_initialized(), (
-            "process group is not initialized when update model per bucket p2p"
-        )
-
-        need_update = self._rank in ranks
-        logger.info(
-            f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
-            f"gpu_count {self._gpu_count}, world_size {self._world_size}"
-        )
-
-        if not need_update:
-            return
-
-        # first execute a barrier to avoid subsequent cuda oom
-        dist.barrier()
-
-        bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
-        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
-        ipc_buffer_name = "__ipc_buffer___"
-        self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
-        logger.info(
-            f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
+            backend=self.device_manager.backend,
+            world_size=len(ranks),
+            rank=rank,
+            timeout=timeout,
+            store=store,
         )
-        handle = reduce_tensor(buffer)
-
-        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
-        socket, socket_paths = self._bind_zmq_socket()
-        req_thread = threading.Thread(
-            target=req_func,
-            args=(socket_paths,),
-        )
-        req_thread.start()
-        socket.send_pyobj(handle)
-        for gidx, (owner_rank, bucket) in enumerate(buckets):
-            self._logger_rank0(
-                f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
-            )
-            _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
-            if dist.get_rank() == 0:
-                self._copy_to_buffer(checkpoint_name, bucket, _buffer, owner_rank)
-            # broadcast the collected data to all ranks
-            dist.broadcast(_buffer, src=0)
-            socket.recv()
-            dist.barrier()
-            socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
-
-        socket.recv()
-        socket.send_pyobj(None)
-        socket.recv()
-        req_thread.join()
-        dist.barrier()
-        socket.close()
-        self._p2p_store.unregister_named_tensors([ipc_buffer_name])
-        torch.cuda.empty_cache()

     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
@@ -1057,38 +1180,65 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks: list[int] | None = None,
     ):
-        if len(self._current_global_parameter_metas) == 0:
-            raise ValueError("parameter metas is empty")
+        assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
-
         assert dist.is_initialized(), "process group is not initialized"
+        # if both ranks is None or [], it will use fully broadcast to update to all ranks
+        if not ranks:
+            logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
+        # if ranks is set, it will use p2p to update to the ranks
+        else:
+            assert self._p2p_store is not None, "p2p store is not initialized"
+            assert ranks, "ranks should be set"

-        logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
+            need_update = self._rank in ranks
+            logger.info(
+                f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
+                f"gpu_count {self._gpu_count}, world_size {self._world_size}"
+            )
+
+            if not need_update:
+                return
+            # first execute a barrier to avoid subsequent device oom
+            dist.barrier()

         bucket_size, disable_h2d_buffer = self._detect_bucket_size()
-        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
+        buckets = _gen_h2d_buckets(
+            self._current_global_parameter_metas,
+            bucket_size,
+            self._local_rdma_devices,
+            self._remote_rdma_devices,
+            ranks,
+        )

         h2d_buffer: torch.Tensor | None = (
             None
             if disable_h2d_buffer
-            else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
+            else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
         )
-
-        owner_rank_buckets: list[H2DBucket] = []
-        for owner_rank, bucket in buckets:
-            if owner_rank != self._rank:
+        # p2p store need to register h2d_buffer to let other ranks read
+        if ranks:
+            h2d_buffer_name = "__h2d_buffer__"
+            if h2d_buffer is not None and self._p2p_store is not None:
+                self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
+        receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
+        for receiver_rank, owner_rank, bucket in buckets:
+            if receiver_rank != self._rank:
                 continue
-            owner_rank_buckets.append(bucket)
+            receiver_rank_buckets.append((owner_rank, bucket))

-        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
+        buffer = torch.empty(
+            bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
+        )
         handle = reduce_tensor(buffer)

-        buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)
+        buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
         max_len = 0
-        for owner_rank, bucket in buckets:
-            buckets_by_owner_rank[owner_rank].append(bucket)
-            if len(buckets_by_owner_rank[owner_rank]) > max_len:
-                max_len = len(buckets_by_owner_rank[owner_rank])
+        for receiver_rank, _, bucket in buckets:
+            buckets_by_receiver_rank[receiver_rank].append(bucket)
+            if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
+                max_len = len(buckets_by_receiver_rank[receiver_rank])

         socket, socket_paths = self._bind_zmq_socket()
         req_thread = threading.Thread(
@@ -1099,43 +1249,66 @@ class ParameterServer:
         socket.send_pyobj(handle)

         gidx = 0
-        ...  # (37 lines of the previous broadcast send loop; their content is not captured in this diff view)
+        ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
+        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
+        try:
+            for i in range(max_len):
+                if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
+                    self._copy_to_buffer(
+                        checkpoint_name,
+                        receiver_rank_buckets[i][1],
+                        h2d_buffer,
+                        receiver_rank_buckets[i][0] if ranks else None,
+                    )
+                for receiver_rank, _buckets in buckets_by_receiver_rank.items():
+                    if i >= len(_buckets):
+                        continue
+                    bucket = _buckets[i]
+                    alloc, reserved = (
+                        self.device_manager.device_module.memory_allocated() / 1024 / 1024,
+                        self.device_manager.device_module.memory_reserved() / 1024 / 1024,
+                    )
+                    self._logger_rank0(
+                        f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
+                        f"Current device allocated {alloc:.2f} MB, "
+                        f"reserved {reserved:.2f} MB."
+                    )
+                    start = gidx % 2 * bucket_size
+                    buffer_b: torch.Tensor = buffer[start : start + bucket.size]
+                    if receiver_rank == self._rank:
+                        if disable_h2d_buffer:
+                            self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
+                        else:
+                            buffer_b.data.copy_(h2d_buffer[: bucket.size])
+                    brank = bcast_rank_map[receiver_rank]
+                    dist.broadcast(buffer_b, src=brank)
+                    resp = socket.recv()
+                    if resp != b"":
+                        exception_obj = pickle.loads(resp)
+                        logger.error(
+                            f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
+                        )
+                        ret_code.fill_(1)
+                    dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                    self.device_manager.device_module.synchronize()
+                    if ret_code.item() != 0:
+                        # quit early if any rank failed
+                        socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
+                        raise RuntimeError("Failed to update weights due to remote errors")
+                    socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
+                    gidx += 1
+
+            socket.recv()
+            socket.send_pyobj(None)
+            socket.recv()
+        finally:
+            req_thread.join()
+            dist.barrier()
+            socket.close()
+            if ranks and h2d_buffer is not None:
+                self._p2p_store.unregister_named_tensors([h2d_buffer_name])
+
+            self.device_manager.device_module.empty_cache()


 def _init_api(ps: ParameterServer) -> Any:
@@ -1153,6 +1326,7 @@ def _init_api(ps: ParameterServer) -> Any:
         update_url: str | None = None
         inference_group_ranks: list[int] = []
         timeout: float = 300.0
+        uds: str | None = None

     def wrap_exception(func: Callable[[], None]) -> Response:
         try:
@@ -1185,7 +1359,9 @@ def _init_api(ps: ParameterServer) -> Any:
             return
         if req.inference_group_ranks:
             socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
-        request_inference_to_update(req.update_url, dict(socket_paths), timeout=req.timeout)
+        request_inference_to_update(
+            req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
+        )

     return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))

checkpoint_engine/worker.py
CHANGED
@@ -5,6 +5,8 @@ from typing import TypedDict
 import torch
 import zmq

+from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
+

 def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
     func, args = handle
@@ -53,51 +55,105 @@ def update_weights_from_ipc(
     socket = zmq_ctx.socket(zmq.REP)
     socket.connect(zmq_handle)
     buffer: torch.Tensor | None = None
-
-    while True:
-        payload = socket.recv_pyobj()
-        if payload is None:
-            if post_hook is not None:
-                post_hook()
-            torch.cuda.synchronize()
-            socket.send(b"")
-            break
-        if isinstance(payload, tuple):
-            # an ipc handle that vLLM can use `func, args = handle`
-            # and `func(*args)` to rebuild GPU tensor.
-            buffer = _rebuild_ipc(payload, device_id)
-            assert buffer.dtype == torch.uint8
-            socket.send(b"")
-            continue
-        assert isinstance(payload, list)
-        run(_extract_weights(payload, buffer))
-        torch.cuda.synchronize()
+    device_manager = DeviceManager()
+    try:
+        ipc_handle: tuple[Callable, tuple] = socket.recv_pyobj()
+        assert isinstance(ipc_handle, tuple)
+        buffer = _rebuild_ipc(ipc_handle, device_id)
+        assert buffer.dtype == torch.uint8
         socket.send(b"")
+    except Exception as e:
+        socket.send_pyobj(e)
+        socket.recv()  # wait for ack
+        raise
+    try:
+        while True:
+            payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
+            if payload is None:  # done signal
+                if post_hook is not None:
+                    post_hook()
+                device_manager.device_module.synchronize()
+                socket.send(b"")
+                break
+            if isinstance(payload, list):  # still updating weights
+                try:
+                    run(_extract_weights(payload, buffer))
+                    device_manager.device_module.synchronize()
+                    socket.send(b"")
+                except Exception as e:  # noqa: BLE001
+                    # Send exception back to Parameter Server.
+                    # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
+                    socket.send_pyobj(e)
+            elif isinstance(
+                payload, Exception
+            ):  # error occurred, got force quit signal from Parameter Server
+                raise payload
+            else:
+                raise TypeError(f"Unexpected payload type: {type(payload)}")

-    socket.close()
-    del buffer
-    gc.collect()
-    torch.cuda.empty_cache()
+    finally:
+        socket.close()
+        del buffer
+        gc.collect()
+        device_manager.device_module.empty_cache()


 class VllmColocateWorkerExtension:
     """
-
-
-
-
-
-
+    Worker extension for vLLM to update weights from checkpoint-engine.
+
+    This class provides a worker extension mechanism that allows vLLM workers to receive
+    and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
+    The methods in this worker extension will be injected into the vLLM worker class and
+    are callable from the `collective_rpc` API, enabling seamless weight updates for both
+    vLLM V0 and V1 versions.
+
+    Note:
+        This class is defined in a separate module. The fully qualified name
+        `checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
+        `worker_extension_cls` argument when initializing the vLLM worker.
     """

     def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
+        """
+        Update model weights from checkpoint-engine via IPC communication.
+
+        This method establishes a ZMQ connection to the checkpoint-engine and receives
+        weight updates through a shared memory buffer. The update process includes:
+        1. Receiving IPC handles to reconstruct shared memory tensors
+        2. Extracting flattened metadata describing tensor weights in the shared memory tensor
+        3. Loading weights into the model
+        4. Post-processing weights after loading
+
+        Args:
+            zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
+                The device UUID is platform-specific:
+                - For CUDA: UUID from `current_platform.get_device_uuid()`
+                - For NPU: Format "NPU-{generated_uuid}"
+
+        Raises:
+            ValueError: If the device type is not supported (not CUDA or NPU).
+            AssertionError: If the device is not properly initialized.
+
+        Note:
+            This method is called by vLLM's collective RPC mechanism. The ZMQ context
+            is lazily initialized on first call and reused for subsequent updates.
+        """
         from vllm.model_executor.model_loader.utils import process_weights_after_loading
         from vllm.platforms import current_platform

+        # vllm-ascend not init device
+        if current_platform.device_type == "npu" and self.device is None:
+            self.device = torch.device(f"npu:{self.local_rank}")
         assert self.device is not None
         if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
             self._zmq_ctx = zmq.Context()
-        device_uuid = current_platform.get_device_uuid(self.device.index)
+        if current_platform.device_type == "cuda":
+            device_uuid = current_platform.get_device_uuid(self.device.index)
+        elif current_platform.device_type == "npu":
+            device_uuid = f"NPU-{npu_generate_uuid()}"
+        else:
+            raise ValueError(f"Unsupported device type: {current_platform.device_type}")
         update_weights_from_ipc(
             self._zmq_ctx,
             zmq_handles[device_uuid],
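The enlarged docstrings above describe how the extension is meant to be wired into vLLM. A hedged sketch of that wiring from the inference side (the model name and the exact engine-argument surface are assumptions and can differ between vLLM versions; `zmq_handles` is produced by the `ParameterServer`):

```python
from vllm import LLM

# Inject the extension so its methods become reachable via collective_rpc on every worker.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    worker_extension_cls="checkpoint_engine.worker.VllmColocateWorkerExtension",
)

# Filled with device_uuid -> ZMQ IPC path, published by the checkpoint-engine ParameterServer.
zmq_handles: dict[str, str] = {}

llm.collective_rpc("update_weights_from_ipc", args=(zmq_handles,))
```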
{checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.1.3
+Version: 0.2.1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -38,8 +38,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion

 The core weight update logic is in `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.

-- **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
-- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
+- **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
+- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.

 ### Optimized Weight Broadcast
 In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
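The Broadcast and P2P bullets above funnel into the same `ParameterServer.update` entry point on the trainer side; the sketch below shows how the two paths are selected (checkpoint name and rank values are made up, and the checkpoint registration / metas gathering steps are omitted):

```python
from checkpoint_engine.ps import ParameterServer

def req_func(socket_paths: list[tuple[str, str]]) -> None:
    # Tell every inference worker to call `update_weights_from_ipc` with these
    # ZMQ socket paths (e.g. over an HTTP endpoint or vLLM's collective_rpc).
    ...

ps = ParameterServer(auto_pg=True)

# Broadcast path: no ranks given, every rank joins the collective update.
ps.update("ckpt-step-100", req_func)

# P2P path: only the newly started instances (ranks 0..15 here) receive weights
# through mooncake-transfer-engine while the rest keep serving.
ps.update("ckpt-step-100", req_func, ranks=list(range(0, 16)))
```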
@@ -60,16 +60,22 @@ It then executes the transfer, where it controls the inference engine through a

 Pipelining naturally requires more GPU memory. When memory is not enough, checkpoint-engine will fallback to serial execution.

+### Optimized P2P Bucket Assignment
+In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
+To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
+The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
+See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25)
+
 ## Benchmark

 | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
 | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
-| GLM-4.5-Air (BF16) | 8xH800 TP8
-| Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8
-| DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.
-| Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.
-| DeepSeek-V3.1 (FP8) | 256xH20 TP16 |
-| Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.
+| GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
+| Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
+| DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
+| Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
+| DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
+| Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |

 All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:

@@ -77,6 +83,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
 * Device Info: we tested various combination of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
 * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
 * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
+* We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.

 ## Installation

@@ -92,7 +99,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer-engine`:
 pip install 'checkpoint-engine[p2p]'
 ```

-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
+If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.

 ## Getting Started

@@ -162,14 +169,64 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM project.
 Run a simple correctness test for checkpoint_engine

 ```bash
-
+pytest tests/test_update.py
+```
+
+`test_update.py` are only designed to run with `pytest`. Please don't run it directly with `torchrun`.
+
+Other unit tests can also be done with pytest. Only test_update.py requires GPUs, other tests can be run on CPUs. Only to run CPU tests, use:
+
+```bash
+pytest tests/ -m "not gpu"
+```
+
+## SGLang Integration
+
+Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
+
+### Quick Start
+
+**1. Install checkpoint-engine:**
+```bash
+pip install 'checkpoint-engine[p2p]'
+```
+
+**2. Launch SGLang server:**
+```bash
+python -m sglang.launch_server \
+    --model-path $MODEL_PATH \
+    --tp 8 \
+    --load-format dummy \
+    --wait-for-initial-weights
+```
+
+**3. Run checkpoint engine:**
+```bash
+python -m sglang.srt.checkpoint_engine.update \
+    --update-method broadcast \
+    --checkpoint-path $MODEL_PATH \
+    --inference-parallel-size 8
 ```

+### Multi-Node Setup
+
+For 2-node setup, run the same commands on both nodes with appropriate `--host` and distributed training parameters.
+
+### Key Options
+
+**SGLang Server:**
+- `--wait-for-initial-weights`: Wait for checkpoint engine before becoming ready
+- `--load-format dummy`: Enable overlapping initialization tasks
+
+**Checkpoint Engine:**
+- `--update-method`: Choose `broadcast`, `p2p`, or `all`
+- `--inference-parallel-size`: Number of parallel processes
+- `--checkpoint-path`: Model checkpoint directory
+
 ## Limitations and Future Work

-- This project is currently
+- This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
 - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
-- The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.

 ## Acknowledgments

checkpoint_engine-0.2.1.dist-info/RECORD
ADDED
@@ -0,0 +1,10 @@
+checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+checkpoint_engine/_version.py,sha256=vYqoJTG51NOUmYyL0xt8asRK8vUT4lGAdal_EZ59mvw,704
+checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+checkpoint_engine/ps.py,sha256=Mgfr_MXYYZ_6JKqD5kIGDBWCYNWAtIPfAppP_cFu604,57781
+checkpoint_engine/worker.py,sha256=5TzDgTPew6Ts9sMOzecalLCR1p_ZwfeKPdzzr68kAQg,6564
+checkpoint_engine-0.2.1.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.2.1.dist-info/METADATA,sha256=7E7NhehWHpS6QVkLup-oCm350wdbiZX8CY3jmmJP0bU,11315
+checkpoint_engine-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+checkpoint_engine-0.2.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.2.1.dist-info/RECORD,,
checkpoint_engine-0.1.3.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
-checkpoint_engine/_version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
-checkpoint_engine/ps.py,sha256=9dXRXi0QDPoRYrgGKAYvdmDFBXejgusjR0ltbii5_B0,49134
-checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
-checkpoint_engine-0.1.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.1.3.dist-info/METADATA,sha256=y96dMjEOKWaO_PA0h5BX8G3Ku7Tt1jCU09uIf8iYgic,9322
-checkpoint_engine-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.1.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.1.3.dist-info/RECORD,,
{checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/WHEEL
File without changes
{checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/licenses/LICENCE
File without changes
{checkpoint_engine-0.1.3.dist-info → checkpoint_engine-0.2.1.dist-info}/top_level.txt
File without changes