checkpoint-engine 0.1.3-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpoint_engine/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '0.1.3'
- __version_tuple__ = version_tuple = (0, 1, 3)
+ __version__ = version = '0.2.1'
+ __version_tuple__ = version_tuple = (0, 2, 1)
 
  __commit_id__ = commit_id = None
checkpoint_engine/device_utils.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import re
+ import socket
+ import subprocess
+ from functools import lru_cache
+
+ import torch
+ from loguru import logger
+
+
+ @lru_cache(maxsize=1)
+ def get_ip() -> str:
+     try:
+         # try to get ip from network interface
+         with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+             s.connect(("8.8.8.8", 80))
+             return s.getsockname()[0]
+     except Exception as e:  # noqa: BLE001
+         # fallback to get ip from hostname
+         logger.warning(
+             f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
+         )
+         return socket.gethostbyname(socket.gethostname())
+
+
+ def npu_generate_uuid() -> str:
+     str_pid = str(os.getpid())
+     npu_num = 8
+     try:
+         for npu_id in range(npu_num):
+             cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
+             result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
+             str_result = str(result.stdout)
+             if str_pid in str_result:
+                 # In A3 server, one NPU has two chips.
+                 match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
+                 chip_count = int(match_chip_count.group(1))
+                 search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
+                 match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
+                 chip_id = int(match_chip_id.group(1))
+                 return f"{get_ip()}-{npu_id * chip_count + chip_id}"
+         raise ValueError("The current process is not running on the npu device")
+     except subprocess.CalledProcessError as e:
+         raise ValueError("The current process is not running on the npu device") from e
+
+
+ class DeviceManager:
+     def __init__(self):
+         self.device_type = self._detect_device_type()
+         self._setup_device_module()
+
+     def _is_torch_npu_available(self) -> bool:
+         try:
+             if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
+                 return torch.npu.is_available()
+             else:
+                 return False
+         except ImportError:
+             return False
+
+     def _detect_device_type(self) -> str:
+         if self._is_torch_npu_available():
+             return "npu"
+         elif torch.cuda.is_available():
+             return "cuda"
+         else:
+             raise TypeError("The current device type is not supported")
+
+     def _setup_device_module(self):
+         if self.device_type == "npu":
+             import torch_npu
+
+             self.device_module = torch_npu.npu
+         elif self.device_type == "cuda":
+             self.device_module = torch.cuda
+         else:
+             raise TypeError("The current device type is not supported")
+
+     @property
+     def backend(self) -> str:
+         if self.device_type == "npu":
+             return "hccl"
+         elif self.device_type == "cuda":
+             return "nccl"
+         else:
+             raise TypeError("The current device type is not supported")
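A minimal usage sketch of the new `DeviceManager` (not part of the wheel, and the device index is assumed); it mirrors how `ps.py` below consumes it to stay device-agnostic between CUDA and Ascend NPU:

```python
# Hedged sketch: DeviceManager, get_ip and backend come from the file above;
# the device index 0 is an assumption for illustration.
from checkpoint_engine.device_utils import DeviceManager, get_ip

dm = DeviceManager()            # raises TypeError if neither NPU nor CUDA is available
dm.device_module.set_device(0)  # resolves to torch.cuda or torch_npu.npu at runtime
print(get_ip(), dm.device_type, dm.backend)
# backend is "hccl" on NPU and "nccl" on CUDA, and is what ParameterServer feeds
# into dist.init_process_group(backend=...).
```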
checkpoint_engine/ps.py CHANGED
@@ -4,13 +4,11 @@ import ctypes
4
4
  import os
5
5
  import pickle
6
6
  import random
7
- import socket
8
7
  import threading
9
8
  import time
10
9
  from collections import defaultdict
11
10
  from collections.abc import Callable
12
11
  from datetime import timedelta
13
- from functools import lru_cache
14
12
  from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
15
13
 
16
14
  import httpx
@@ -23,8 +21,12 @@ from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
23
21
  from safetensors.torch import safe_open
24
22
  from torch.multiprocessing.reductions import reduce_tensor
25
23
 
24
+ from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
25
+
26
26
 
27
27
  if TYPE_CHECKING:
28
+ from typing import TypeVar
29
+
28
30
  from typing_extensions import TypedDict
29
31
 
30
32
  class FileMeta(TypedDict):
@@ -34,6 +36,8 @@ if TYPE_CHECKING:
34
36
  type: type
35
37
  tp_concat_dim: int
36
38
 
39
+ T = TypeVar("T")
40
+
37
41
 
38
42
  def _dt_validate(value: Any) -> torch.dtype:
39
43
  if isinstance(value, str):
@@ -117,6 +121,7 @@ class MemoryBuffer(BaseModel):
117
121
  class MemoryBufferMetaList(BaseModel):
118
122
  p2p_store_addr: str | None
119
123
  memory_buffer_metas_list: list[MemoryBufferMetas]
124
+ rdma_device: str
120
125
 
121
126
 
122
127
  class DataToGather(MemoryBufferMetaList):
@@ -249,28 +254,16 @@ def _concat_tp_weights(
249
254
  return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
250
255
 
251
256
 
252
- def _get_physical_gpu_id(device_index: int | None = None) -> str:
257
+ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
253
258
  try:
254
- return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
259
+ if device_manager.device_type == "npu":
260
+ return f"NPU-{npu_generate_uuid()}"
261
+ else:
262
+ return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
255
263
  except AssertionError as e:
256
264
  raise ValueError(f"fail to get physical gpu id {device_index}") from e
257
265
 
258
266
 
259
- @lru_cache(maxsize=1)
260
- def _get_ip() -> str:
261
- try:
262
- # try to get ip from network interface
263
- with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
264
- s.connect(("8.8.8.8", 80))
265
- return s.getsockname()[0]
266
- except Exception as e: # noqa: BLE001
267
- # fallback to get ip from hostname
268
- logger.warning(
269
- f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
270
- )
271
- return socket.gethostbyname(socket.gethostname())
272
-
273
-
274
267
  def _ibv_get_device_list() -> list[str]:
275
268
  lib = ctypes.CDLL("libibverbs.so.1")
276
269
  lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
@@ -303,14 +296,7 @@ def _get_rdma_devices() -> list[str]:
303
296
  return devices_str.split(",")
304
297
  # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
305
298
  hca = os.getenv("NCCL_IB_HCA", None)
306
- if hca:
307
- hca_list = hca.split(",")
308
- if len(hca_list) > 1:
309
- # if NCCL_IB_HCA has multiple values, just return
310
- return hca_list
311
- else:
312
- hca = hca_list[0]
313
- return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
299
+ return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
314
300
 
315
301
 
316
302
  def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
@@ -319,13 +305,90 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
319
305
  """
320
306
  if not devices:
321
307
  raise RuntimeError("no rdma devices found")
322
- assert len(devices) <= gpu_count, (
323
- f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
324
- )
325
- assert gpu_count % len(devices) == 0, (
326
- f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
327
- )
328
- return devices[local_rank // (gpu_count // len(devices))]
308
+ try:
309
+ assert len(devices) <= gpu_count, (
310
+ f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
311
+ )
312
+ assert gpu_count % len(devices) == 0, (
313
+ f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
314
+ )
315
+ return devices[local_rank // (gpu_count // len(devices))]
316
+ except AssertionError:
317
+ logger.error(
318
+ "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose a proper number of RDMA devices. "
+ "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices. "
+ "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
321
+ )
322
+ raise
323
+
324
+
325
+ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
326
+ """
327
+ The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
328
+ This Python parser mirrors the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
329
+
330
+ The list is comma-separated; port numbers are NOT supported yet.
331
+ An optional prefix '^' indicates the list is an exclude list.
332
+ A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
333
+ Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
334
+
335
+ Examples:
336
+ - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
337
+ - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
338
+ - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
339
+ - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
340
+ """
341
+ max_hcas = 32
342
+ if not value or value.strip() == "":
343
+ return available_devices[:max_hcas]
344
+
345
+ value = value.strip()
346
+ result = []
347
+ is_exclude = value.startswith("^")
348
+ if is_exclude:
349
+ value = value.removeprefix("^")
350
+ is_exact_match = value.startswith("=")
351
+ if is_exact_match:
352
+ value = value.removeprefix("=")
353
+
354
+ device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
355
+
356
+ result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
357
+ if is_exclude:
358
+ result = [dev for dev in available_devices if dev not in result]
359
+ if len(result) > max_hcas:
360
+ result = result[:max_hcas]
361
+
362
+ logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
363
+
364
+ return result
365
+
366
+
367
+ def _resolve_device_specs(
368
+ device_specs: list[str], is_exact_match: bool, available_devices: list[str]
369
+ ) -> list[str]:
370
+ devices = set()
371
+ for spec in device_specs:
372
+ parts = spec.split(":", 1)
373
+ device_name = parts[0].strip()
374
+ # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
375
+ # port = parts[1].strip() if len(parts) > 1 else None
376
+ base_devices = (
377
+ [device_name]
378
+ if device_name in available_devices
379
+ else []
380
+ if is_exact_match
381
+ else [dev for dev in available_devices if dev.startswith(device_name)]
382
+ )
383
+
384
+ if not base_devices:
385
+ logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
386
+ continue
387
+
388
+ for base_dev in base_devices:
389
+ devices.add(base_dev)
390
+
391
+ return sorted(devices)
329
392
 
330
393
 
331
394
  def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
@@ -490,8 +553,12 @@ def request_inference_to_update(
490
553
 
491
554
 
492
555
  def _gen_h2d_buckets(
493
- global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
494
- ) -> list[tuple[int, H2DBucket]]:
556
+ global_metas: dict[int, MemoryBufferMetaList],
557
+ bucket_size: int,
558
+ local_topo: dict[str, set[int]],
559
+ remote_topo: dict[str, set[int]],
560
+ ranks: list[int] | None = None,
561
+ ) -> list[tuple[int, int, H2DBucket]]:
495
562
  buckets: list[tuple[int, H2DBucket]] = []
496
563
 
497
564
  for owner_rank, items in global_metas.items():
@@ -514,7 +581,73 @@ def _gen_h2d_buckets(
514
581
  assert buckets[-1][1].size > 0, (
515
582
  f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
516
583
  )
517
- return buckets
584
+ ranks_set = set(ranks) if ranks else set()
585
+ actual_local_topo = (
586
+ {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
587
+ )
588
+ # if ranks is empty, assign the owner_rank as receiver_rank, this is used for colocate architecture
589
+ if not ranks:
590
+ return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
591
+ else:
592
+ return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
593
+
594
+
595
+ def _assign_receiver_ranks(
596
+ buckets: list[tuple[int, "T"]],
597
+ local_topo: dict[str, set[int]],
598
+ remote_topo: dict[str, set[int]],
599
+ ) -> list[tuple[int, int, "T"]]:
600
+ """
601
+ (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
602
+
603
+ Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
604
+ GPU-rdma_device topology will be considered to make full use of the bandwidth.
605
+ """
606
+ if not buckets:
607
+ logger.warning("bucket list is empty, no need to assign receiver ranks")
608
+ return []
609
+ rank_to_rdma_device = {
610
+ rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
611
+ }
612
+
613
+ # group buckets by owner RDMA devices
614
+ buckets_by_rdma_device = defaultdict(list)
615
+ for owner_rank, bucket in buckets:
616
+ owner_rdma_device = rank_to_rdma_device[owner_rank]
617
+ buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
618
+
619
+ buckets_matrix = list(buckets_by_rdma_device.values())
620
+ assert buckets_matrix, "buckets_matrix should not be empty"
621
+
622
+ # Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
623
+ num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
624
+ receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
625
+
626
+ flattened_buckets = [
627
+ buckets_matrix[row][col]
628
+ for col in range(
629
+ max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
630
+ )
631
+ for row in range(len(buckets_matrix))
632
+ if col < len(buckets_matrix[row])
633
+ ]
634
+
635
+ buckets_with_receiver = []
636
+ assigned_cnt = 0
637
+ while assigned_cnt < len(flattened_buckets):
638
+ occupied_devices = set()
639
+ for receiver_rank in receiver_list:
640
+ if assigned_cnt >= len(flattened_buckets):
641
+ break
642
+ owner_rank, bucket = flattened_buckets[assigned_cnt]
643
+ rdma_device = rank_to_rdma_device[owner_rank]
644
+ if rdma_device in occupied_devices:
645
+ break
646
+ buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
647
+ occupied_devices.add(rdma_device)
648
+ assigned_cnt += 1
649
+
650
+ return buckets_with_receiver
518
651
 
519
652
 
520
653
  def _get_master_port(master_port: int | None = None) -> int:
@@ -525,21 +658,44 @@ def _get_master_port(master_port: int | None = None) -> int:
525
658
  return master_port
526
659
 
527
660
 
661
+ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
662
+ """
663
+ map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
664
+ which are generated in self.init_process_group_for_ranks
665
+ """
666
+ bcast_rank_map: dict[int, int] = {}
667
+ if not ranks:
668
+ bcast_rank_map = {r: r for r in range(world_size)}
669
+ else:
670
+ for i, r in enumerate(ranks):
671
+ bcast_rank_map[r] = i
672
+ return bcast_rank_map
673
+
674
+
528
675
  class P2PStore:
529
- def __init__(self):
676
+ def __init__(self, device_manager: DeviceManager):
530
677
  from mooncake.engine import TransferEngine
531
678
 
532
679
  self.rank = int(os.getenv("RANK"))
533
- gpu_count = torch.cuda.device_count()
680
+ gpu_count = device_manager.device_module.device_count()
534
681
  local_rank = self.rank % gpu_count
535
- device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
536
- self.ip = _get_ip()
682
+ device_type = device_manager.device_type
683
+ if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
684
+ self.device = ""
685
+ else:
686
+ self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
687
+ self.ip = get_ip()
537
688
 
538
689
  # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
539
690
  retry_count = 8
540
691
  for i in range(retry_count):
541
692
  self.engine = TransferEngine()
542
- ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", device)
693
+ ret = self.engine.initialize(
694
+ self.ip,
695
+ "P2PHANDSHAKE",
696
+ "ascend_direct" if device_type == "npu" else "rdma",
697
+ self.device,
698
+ )
543
699
  if ret == 0:
544
700
  break
545
701
  # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
@@ -553,7 +709,7 @@ class P2PStore:
553
709
  self.port = self.engine.get_rpc_port()
554
710
  self.named_tensors: dict[str, torch.Tensor] = {}
555
711
  logger.info(
556
- f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
712
+ f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
557
713
  )
558
714
 
559
715
  @property
@@ -606,15 +762,18 @@ class ParameterServer:
606
762
  Args:
607
763
  auto_pg: Whether to automatically initialize the process group.
608
764
  Notice that if auto_pg is True, will destroy the process group after update.
609
- mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
765
+ mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
610
766
  """
611
767
  self._rank = rank or int(os.environ.get("RANK", None))
612
768
  self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
613
- self._gpu_count = gpu_count or torch.cuda.device_count()
769
+ self.device_manager = DeviceManager()
770
+ self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
614
771
  self._local_rank = self._rank % self._gpu_count
615
772
  self._auto_pg = auto_pg
616
773
  self._all_hosts = []
617
774
  self._global_device_uuids: list[str] = []
775
+ self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
776
+ self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
618
777
  self._mem_fraction = mem_fraction or 0.9
619
778
 
620
779
  assert self._rank is not None and self._rank >= 0, self._rank
@@ -622,7 +781,7 @@ class ParameterServer:
622
781
  assert (
623
782
  self._gpu_count is not None
624
783
  and self._gpu_count > 0
625
- and self._gpu_count <= torch.cuda.device_count()
784
+ and self._gpu_count <= self.device_manager.device_module.device_count()
626
785
  ), self._gpu_count
627
786
  assert (
628
787
  self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -634,15 +793,17 @@ class ParameterServer:
634
793
  self._memory_pool: dict[str, list[MemoryBuffer]] = {}
635
794
  # dict key is owner_rank, value is a bucket metas list in owner_rank
636
795
  self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
796
+ # NPU transfer engine initialization requires prior set_device.
797
+ device_index = self._local_rank
798
+ self.device_manager.device_module.set_device(device_index)
637
799
  try:
638
- self._p2p_store = P2PStore()
800
+ self._p2p_store = P2PStore(self.device_manager)
639
801
  except ImportError as e:
640
802
  logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
641
803
  self._p2p_store = None
642
804
 
643
- device_index = self._local_rank
644
- torch.cuda.set_device(device_index)
645
- self._device_uuid = _get_physical_gpu_id(device_index)
805
+ self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
806
+ self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
646
807
 
647
808
  def _logger_rank0(self, msg: str):
648
809
  if self._local_rank == 0:
@@ -653,6 +814,13 @@ class ParameterServer:
653
814
 
654
815
  def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
655
816
  self._current_global_parameter_metas = metas
817
+ self._remote_rdma_devices = defaultdict(set)
818
+ for i, meta in self._current_global_parameter_metas.items():
819
+ assert meta.rdma_device is not None, "meta.rdma_device should not be None"
820
+ assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
821
+ self._remote_rdma_devices[
822
+ meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
823
+ ].add(i)
656
824
 
657
825
  def register_checkpoint(
658
826
  self,
@@ -724,13 +892,15 @@ class ParameterServer:
724
892
  for x in self._memory_pool.get(checkpoint_name, [])
725
893
  ],
726
894
  p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
727
- host_ip=_get_ip(),
895
+ host_ip=get_ip(),
728
896
  device_uuid=self._device_uuid,
897
+ rdma_device=self._rdma_device or "",
729
898
  )
730
899
 
731
900
  dist.all_gather_object(metas_lst, metas)
732
901
 
733
902
  self._current_global_parameter_metas = {}
903
+
734
904
  num_parameters = 0
735
905
  all_hosts: list[str] = []
736
906
  global_device_uuids: list[str] = []
@@ -741,12 +911,24 @@ class ParameterServer:
741
911
  if not self._global_device_uuids:
742
912
  global_device_uuids.append(metas_buckets.device_uuid)
743
913
  if metas_buckets.memory_buffer_metas_list:
744
- self._current_global_parameter_metas[i] = metas_buckets
914
+ self._current_global_parameter_metas[i] = MemoryBufferMetaList(
915
+ memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
916
+ p2p_store_addr=metas_buckets.p2p_store_addr,
917
+ rdma_device=metas_buckets.rdma_device,
918
+ )
745
919
  num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
920
+ self._local_rdma_devices[
921
+ metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
922
+ if metas_buckets.p2p_store_addr
923
+ else metas_buckets.host_ip
924
+ ].add(i)
746
925
  if not self._all_hosts:
747
926
  self._all_hosts = all_hosts
748
927
  if not self._global_device_uuids:
749
928
  self._global_device_uuids = global_device_uuids
929
+ # Sender node and Receiver node have the same GPU-rdma_device topology is considered as default.
930
+ # Rewrite the sender's topology (_remote_rdma_devices) by calling load_metas.
931
+ self._remote_rdma_devices = self._local_rdma_devices.copy()
750
932
  logger.info(
751
933
  f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
752
934
  )
@@ -775,7 +957,7 @@ class ParameterServer:
775
957
  is_master=self._rank == 0,
776
958
  )
777
959
  dist.init_process_group(
778
- backend="nccl",
960
+ backend=self.device_manager.backend,
779
961
  world_size=self._world_size,
780
962
  rank=self._rank,
781
963
  timeout=timeout,
@@ -801,6 +983,7 @@ class ParameterServer:
801
983
  If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
802
984
  which is useful in disaggregated architecture.
803
985
  """
986
+ assert req_func is not None, "req_func is required"
804
987
  try:
805
988
  # if both ranks is None or [], it will use fully broadcast to update to all ranks
806
989
  if not ranks:
@@ -808,32 +991,31 @@ class ParameterServer:
808
991
  self.init_process_group()
809
992
  self._update_per_bucket(checkpoint_name, req_func)
810
993
  else:
811
- if not self._auto_pg and self._rank not in ranks:
812
- return
813
994
  if self._auto_pg:
814
995
  if dist.is_initialized():
815
996
  dist.destroy_process_group()
816
997
  # HACK: wait 2s to ensure destroy is finished
817
998
  time.sleep(2)
818
- if self._rank not in ranks:
819
- return
820
999
  self.init_process_group_for_ranks(ranks)
821
- self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
822
- if self._auto_pg:
823
- dist.destroy_process_group()
824
-
825
- torch.cuda.empty_cache()
1000
+ if self._rank not in ranks:
1001
+ return
1002
+ self._update_per_bucket(checkpoint_name, req_func, ranks)
826
1003
 
827
- logger.info(
828
- f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
829
- f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
830
- f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
831
- )
832
1004
  except Exception as e:
833
1005
  logger.exception(
834
1006
  f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
835
1007
  )
836
1008
  raise
1009
+ finally:
1010
+ if self._auto_pg and (not ranks or self._rank in ranks):
1011
+ dist.destroy_process_group()
1012
+
1013
+ self.device_manager.device_module.empty_cache()
1014
+ logger.info(
1015
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
1016
+ f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
1017
+ f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
1018
+ )
837
1019
 
838
1020
  def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
839
1021
  def zmq_handle(device_uuid: str) -> str:
@@ -850,14 +1032,16 @@ class ParameterServer:
850
1032
  # auto detect bucket size
851
1033
  tensor = torch.tensor(
852
1034
  [
853
- # proportion of current cuda free memory bytes
854
- int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
1035
+ # proportion of current device free memory bytes
1036
+ int(
1037
+ float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
1038
+ ),
855
1039
  # we use negative value to reuse allreduce min operation
856
1040
  # for getting the max value of zmq_addr_counter in all ranks
857
1041
  -self._zmq_addr_counter,
858
1042
  ],
859
1043
  dtype=torch.int64,
860
- device="cuda",
1044
+ device=self.device_manager.device_type,
861
1045
  )
862
1046
  dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
863
1047
  tensor = tensor.cpu()
@@ -920,7 +1104,7 @@ class ParameterServer:
920
1104
  assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
921
1105
  if owner_rank is not None:
922
1106
  self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
923
- torch.cuda.synchronize()
1107
+ self.device_manager.device_module.synchronize()
924
1108
 
925
1109
  def init_process_group_for_ranks(
926
1110
  self,
@@ -960,73 +1144,12 @@ class ParameterServer:
960
1144
  master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
961
1145
  )
962
1146
  dist.init_process_group(
963
- backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
964
- )
965
-
966
- def _update_per_bucket_p2p(
967
- self,
968
- checkpoint_name: str,
969
- req_func: Callable[[list[tuple[str, str]]], None],
970
- ranks: list[int],
971
- ):
972
- assert self._p2p_store is not None, "p2p store is not initialized"
973
- assert ranks, "ranks should be set"
974
- if len(self._current_global_parameter_metas) == 0:
975
- raise ValueError("parameter metas is empty")
976
- assert dist.is_initialized(), (
977
- "process group is not initialized when update model per bucket p2p"
978
- )
979
-
980
- need_update = self._rank in ranks
981
- logger.info(
982
- f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
983
- f"gpu_count {self._gpu_count}, world_size {self._world_size}"
984
- )
985
-
986
- if not need_update:
987
- return
988
-
989
- # first execute a barrier to avoid subsequent cuda oom
990
- dist.barrier()
991
-
992
- bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
993
- buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
994
- ipc_buffer_name = "__ipc_buffer___"
995
- self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
996
- logger.info(
997
- f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
1147
+ backend=self.device_manager.backend,
1148
+ world_size=len(ranks),
1149
+ rank=rank,
1150
+ timeout=timeout,
1151
+ store=store,
998
1152
  )
999
- handle = reduce_tensor(buffer)
1000
-
1001
- buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
1002
- socket, socket_paths = self._bind_zmq_socket()
1003
- req_thread = threading.Thread(
1004
- target=req_func,
1005
- args=(socket_paths,),
1006
- )
1007
- req_thread.start()
1008
- socket.send_pyobj(handle)
1009
- for gidx, (owner_rank, bucket) in enumerate(buckets):
1010
- self._logger_rank0(
1011
- f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1012
- )
1013
- _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
1014
- if dist.get_rank() == 0:
1015
- self._copy_to_buffer(checkpoint_name, bucket, _buffer, owner_rank)
1016
- # broadcast the collected data to all ranks
1017
- dist.broadcast(_buffer, src=0)
1018
- socket.recv()
1019
- dist.barrier()
1020
- socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1021
-
1022
- socket.recv()
1023
- socket.send_pyobj(None)
1024
- socket.recv()
1025
- req_thread.join()
1026
- dist.barrier()
1027
- socket.close()
1028
- self._p2p_store.unregister_named_tensors([ipc_buffer_name])
1029
- torch.cuda.empty_cache()
1030
1153
 
1031
1154
  def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
1032
1155
  addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
@@ -1057,38 +1180,65 @@ class ParameterServer:
1057
1180
  self,
1058
1181
  checkpoint_name: str,
1059
1182
  req_func: Callable[[list[tuple[str, str]]], None],
1183
+ ranks: list[int] | None = None,
1060
1184
  ):
1061
- if len(self._current_global_parameter_metas) == 0:
1062
- raise ValueError("parameter metas is empty")
1063
-
1185
+ assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
1064
1186
  assert dist.is_initialized(), "process group is not initialized"
1187
+ # if both ranks is None or [], it will use fully broadcast to update to all ranks
1188
+ if not ranks:
1189
+ logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
1190
+ # if ranks is set, it will use p2p to update to the ranks
1191
+ else:
1192
+ assert self._p2p_store is not None, "p2p store is not initialized"
1193
+ assert ranks, "ranks should be set"
1065
1194
 
1066
- logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
1195
+ need_update = self._rank in ranks
1196
+ logger.info(
1197
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
1198
+ f"gpu_count {self._gpu_count}, world_size {self._world_size}"
1199
+ )
1200
+
1201
+ if not need_update:
1202
+ return
1203
+ # first execute a barrier to avoid subsequent device oom
1204
+ dist.barrier()
1067
1205
 
1068
1206
  bucket_size, disable_h2d_buffer = self._detect_bucket_size()
1069
- buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
1207
+ buckets = _gen_h2d_buckets(
1208
+ self._current_global_parameter_metas,
1209
+ bucket_size,
1210
+ self._local_rdma_devices,
1211
+ self._remote_rdma_devices,
1212
+ ranks,
1213
+ )
1070
1214
 
1071
1215
  h2d_buffer: torch.Tensor | None = (
1072
1216
  None
1073
1217
  if disable_h2d_buffer
1074
- else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
1218
+ else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
1075
1219
  )
1076
-
1077
- owner_rank_buckets: list[H2DBucket] = []
1078
- for owner_rank, bucket in buckets:
1079
- if owner_rank != self._rank:
1220
+ # p2p store need to register h2d_buffer to let other ranks read
1221
+ if ranks:
1222
+ h2d_buffer_name = "__h2d_buffer__"
1223
+ if h2d_buffer is not None and self._p2p_store is not None:
1224
+ self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
1225
+ receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
1226
+ for receiver_rank, owner_rank, bucket in buckets:
1227
+ if receiver_rank != self._rank:
1080
1228
  continue
1081
- owner_rank_buckets.append(bucket)
1229
+ receiver_rank_buckets.append((owner_rank, bucket))
1082
1230
 
1083
- buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
1231
+ buffer = torch.empty(
1232
+ bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
1233
+ )
1084
1234
  handle = reduce_tensor(buffer)
1085
1235
 
1086
- buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)
1236
+ buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
1087
1237
  max_len = 0
1088
- for owner_rank, bucket in buckets:
1089
- buckets_by_owner_rank[owner_rank].append(bucket)
1090
- if len(buckets_by_owner_rank[owner_rank]) > max_len:
1091
- max_len = len(buckets_by_owner_rank[owner_rank])
1238
+ for receiver_rank, _, bucket in buckets:
1239
+ buckets_by_receiver_rank[receiver_rank].append(bucket)
1240
+ if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
1241
+ max_len = len(buckets_by_receiver_rank[receiver_rank])
1092
1242
 
1093
1243
  socket, socket_paths = self._bind_zmq_socket()
1094
1244
  req_thread = threading.Thread(
@@ -1099,43 +1249,66 @@ class ParameterServer:
1099
1249
  socket.send_pyobj(handle)
1100
1250
 
1101
1251
  gidx = 0
1102
- for i in range(max_len):
1103
- if i < len(owner_rank_buckets) and not disable_h2d_buffer:
1104
- self._copy_to_buffer(checkpoint_name, owner_rank_buckets[i], h2d_buffer)
1105
-
1106
- for owner_rank, _buckets in buckets_by_owner_rank.items():
1107
- if i >= len(_buckets):
1108
- continue
1109
- bucket = _buckets[i]
1110
- alloc, reserved = (
1111
- torch.cuda.memory_allocated() / 1024 / 1024,
1112
- torch.cuda.memory_reserved() / 1024 / 1024,
1113
- )
1114
- self._logger_rank0(
1115
- f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1116
- f"Current CUDA allocated {alloc:.2f} MB, "
1117
- f"reserved {reserved:.2f} MB."
1118
- )
1119
- start = gidx % 2 * bucket_size
1120
- buffer_b: torch.Tensor = buffer[start : start + bucket.size]
1121
- if owner_rank == self._rank:
1122
- if disable_h2d_buffer:
1123
- self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1124
- else:
1125
- buffer_b.data.copy_(h2d_buffer[: bucket.size])
1126
- dist.broadcast(buffer_b, src=owner_rank)
1127
- socket.recv()
1128
- dist.barrier()
1129
- socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1130
- gidx += 1
1131
-
1132
- socket.recv()
1133
- socket.send_pyobj(None)
1134
- socket.recv()
1135
- req_thread.join()
1136
- dist.barrier()
1137
- socket.close()
1138
- torch.cuda.empty_cache()
1252
+ ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
1253
+ bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
1254
+ try:
1255
+ for i in range(max_len):
1256
+ if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
1257
+ self._copy_to_buffer(
1258
+ checkpoint_name,
1259
+ receiver_rank_buckets[i][1],
1260
+ h2d_buffer,
1261
+ receiver_rank_buckets[i][0] if ranks else None,
1262
+ )
1263
+ for receiver_rank, _buckets in buckets_by_receiver_rank.items():
1264
+ if i >= len(_buckets):
1265
+ continue
1266
+ bucket = _buckets[i]
1267
+ alloc, reserved = (
1268
+ self.device_manager.device_module.memory_allocated() / 1024 / 1024,
1269
+ self.device_manager.device_module.memory_reserved() / 1024 / 1024,
1270
+ )
1271
+ self._logger_rank0(
1272
+ f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1273
+ f"Current device allocated {alloc:.2f} MB, "
1274
+ f"reserved {reserved:.2f} MB."
1275
+ )
1276
+ start = gidx % 2 * bucket_size
1277
+ buffer_b: torch.Tensor = buffer[start : start + bucket.size]
1278
+ if receiver_rank == self._rank:
1279
+ if disable_h2d_buffer:
1280
+ self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1281
+ else:
1282
+ buffer_b.data.copy_(h2d_buffer[: bucket.size])
1283
+ brank = bcast_rank_map[receiver_rank]
1284
+ dist.broadcast(buffer_b, src=brank)
1285
+ resp = socket.recv()
1286
+ if resp != b"":
1287
+ exception_obj = pickle.loads(resp)
1288
+ logger.error(
1289
+ f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
1290
+ )
1291
+ ret_code.fill_(1)
1292
+ dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
1293
+ self.device_manager.device_module.synchronize()
1294
+ if ret_code.item() != 0:
1295
+ # quit early if any rank failed
1296
+ socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
1297
+ raise RuntimeError("Failed to update weights due to remote errors")
1298
+ socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1299
+ gidx += 1
1300
+
1301
+ socket.recv()
1302
+ socket.send_pyobj(None)
1303
+ socket.recv()
1304
+ finally:
1305
+ req_thread.join()
1306
+ dist.barrier()
1307
+ socket.close()
1308
+ if ranks and h2d_buffer is not None:
1309
+ self._p2p_store.unregister_named_tensors([h2d_buffer_name])
1310
+
1311
+ self.device_manager.device_module.empty_cache()
1139
1312
 
1140
1313
 
1141
1314
  def _init_api(ps: ParameterServer) -> Any:
@@ -1153,6 +1326,7 @@ def _init_api(ps: ParameterServer) -> Any:
1153
1326
  update_url: str | None = None
1154
1327
  inference_group_ranks: list[int] = []
1155
1328
  timeout: float = 300.0
1329
+ uds: str | None = None
1156
1330
 
1157
1331
  def wrap_exception(func: Callable[[], None]) -> Response:
1158
1332
  try:
@@ -1185,7 +1359,9 @@ def _init_api(ps: ParameterServer) -> Any:
1185
1359
  return
1186
1360
  if req.inference_group_ranks:
1187
1361
  socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
1188
- request_inference_to_update(req.update_url, dict(socket_paths), timeout=req.timeout)
1362
+ request_inference_to_update(
1363
+ req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
1364
+ )
1189
1365
 
1190
1366
  return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
1191
1367
 
checkpoint_engine/worker.py CHANGED
@@ -5,6 +5,8 @@ from typing import TypedDict
5
5
  import torch
6
6
  import zmq
7
7
 
8
+ from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
9
+
8
10
 
9
11
  def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
10
12
  func, args = handle
@@ -53,51 +55,105 @@ def update_weights_from_ipc(
53
55
  socket = zmq_ctx.socket(zmq.REP)
54
56
  socket.connect(zmq_handle)
55
57
  buffer: torch.Tensor | None = None
56
- while True:
57
- payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
58
- if payload is None:
59
- # means the update is done
60
- if post_hook is not None:
61
- post_hook()
62
- torch.cuda.synchronize()
63
- socket.send(b"")
64
- break
65
- if isinstance(payload, tuple):
66
- # an ipc handle that vLLM can use `func, args = handle`
67
- # and `func(*args)` to rebuild GPU tensor.
68
- buffer = _rebuild_ipc(payload, device_id)
69
- assert buffer.dtype == torch.uint8
70
- socket.send(b"")
71
- continue
72
- assert isinstance(payload, list)
73
- run(_extract_weights(payload, buffer))
74
- torch.cuda.synchronize()
58
+ device_manager = DeviceManager()
59
+ try:
60
+ ipc_handle: tuple[Callable, tuple] = socket.recv_pyobj()
61
+ assert isinstance(ipc_handle, tuple)
62
+ buffer = _rebuild_ipc(ipc_handle, device_id)
63
+ assert buffer.dtype == torch.uint8
75
64
  socket.send(b"")
65
+ except Exception as e:
66
+ socket.send_pyobj(e)
67
+ socket.recv() # wait for ack
68
+ raise
69
+ try:
70
+ while True:
71
+ payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
72
+ if payload is None: # done signal
73
+ if post_hook is not None:
74
+ post_hook()
75
+ device_manager.device_module.synchronize()
76
+ socket.send(b"")
77
+ break
78
+ if isinstance(payload, list): # still updating weights
79
+ try:
80
+ run(_extract_weights(payload, buffer))
81
+ device_manager.device_module.synchronize()
82
+ socket.send(b"")
83
+ except Exception as e: # noqa: BLE001
84
+ # Send exception back to Parameter Server.
85
+ # Don't raise here: all workers should quit in the same way by receiving the exception from the PS
86
+ socket.send_pyobj(e)
87
+ elif isinstance(
88
+ payload, Exception
89
+ ): # error occurred, got force quit signal from Parameter Server
90
+ raise payload
91
+ else:
92
+ raise TypeError(f"Unexpected payload type: {type(payload)}")
76
93
 
77
- socket.close()
78
- del buffer
79
- gc.collect()
80
- torch.cuda.empty_cache()
94
+ finally:
95
+ socket.close()
96
+ del buffer
97
+ gc.collect()
98
+ device_manager.device_module.empty_cache()
81
99
 
82
100
 
83
101
  class VllmColocateWorkerExtension:
84
102
  """
85
- The class for vLLM's worker to inherit from, in the colocate setting.
86
- By defining an extension class, the code can work no matter what is
87
- the underlying worker class. This way, the code can be compatible
88
- with both vLLM V0 and V1.
89
- NOTE: we define this class in a separate module, and the main module
90
- should pass the full qualified name as `worker_extension_cls` argument.
103
+ Worker extension for vLLM to update weights from checkpoint-engine.
104
+
105
+ This class provides a worker extension mechanism that allows vLLM workers to receive
106
+ and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
107
+ The methods in this worker extension will be injected into the vLLM worker class and
108
+ are callable from the `collective_rpc` API, enabling seamless weight updates for both
109
+ vLLM V0 and V1 versions.
110
+
111
+ Note:
112
+ This class is defined in a separate module. The fully qualified name
113
+ `checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
114
+ `worker_extension_cls` argument when initializing the vLLM worker.
91
115
  """
92
116
 
93
117
  def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
118
+ """
119
+ Update model weights from checkpoint-engine via IPC communication.
120
+
121
+ This method establishes a ZMQ connection to the checkpoint-engine and receives
122
+ weight updates through a shared memory buffer. The update process includes:
123
+ 1. Receiving IPC handles to reconstruct shared memory tensors
124
+ 2. Extracting flattened metadata describing tensor weights in the shared memory tensor
125
+ 3. Loading weights into the model
126
+ 4. Post-processing weights after loading
127
+
128
+ Args:
129
+ zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
130
+ The device UUID is platform-specific:
131
+ - For CUDA: UUID from `current_platform.get_device_uuid()`
132
+ - For NPU: Format "NPU-{generated_uuid}"
133
+
134
+ Raises:
135
+ ValueError: If the device type is not supported (not CUDA or NPU).
136
+ AssertionError: If the device is not properly initialized.
137
+
138
+ Note:
139
+ This method is called by vLLM's collective RPC mechanism. The ZMQ context
140
+ is lazily initialized on first call and reused for subsequent updates.
141
+ """
94
142
  from vllm.model_executor.model_loader.utils import process_weights_after_loading
95
143
  from vllm.platforms import current_platform
96
144
 
145
+ # vllm-ascend does not set self.device during init, so set it here
146
+ if current_platform.device_type == "npu" and self.device is None:
147
+ self.device = torch.device(f"npu:{self.local_rank}")
97
148
  assert self.device is not None
98
149
  if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
99
150
  self._zmq_ctx = zmq.Context()
100
- device_uuid = current_platform.get_device_uuid(self.device.index)
151
+ if current_platform.device_type == "cuda":
152
+ device_uuid = current_platform.get_device_uuid(self.device.index)
153
+ elif current_platform.device_type == "npu":
154
+ device_uuid = f"NPU-{npu_generate_uuid()}"
155
+ else:
156
+ raise ValueError(f"Unsupported device type: {current_platform.device_type}")
101
157
  update_weights_from_ipc(
102
158
  self._zmq_ctx,
103
159
  zmq_handles[device_uuid],
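A hedged sketch of how this worker extension is wired into vLLM on the inference side. The extension path comes from the docstring above; `worker_extension_cls`, `load_format` and `collective_rpc` are standard vLLM arguments in recent releases, but verify them against your vLLM version, and `zmq_handles` is produced by the checkpoint-engine side (`request_inference_to_update`):

```python
from vllm import LLM

llm = LLM(
    model="/path/to/model",
    load_format="dummy",  # weights are pushed by checkpoint-engine afterwards
    worker_extension_cls="checkpoint_engine.worker.VllmColocateWorkerExtension",
)

# once checkpoint-engine publishes the ZMQ handles (device_uuid -> socket path):
# llm.collective_rpc("update_weights_from_ipc", args=(zmq_handles,))
```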
checkpoint_engine-0.2.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.3
3
+ Version: 0.2.1
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -38,8 +38,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
38
38
 
39
39
  The core weight update logic is in `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
40
40
 
41
- - **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
42
- - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
41
+ - **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
42
+ - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
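Both paths are driven through the same `ParameterServer.update` call. A hedged sketch (the constructor arguments and the `req_func` body are assumptions; only the method name, the `req_func` signature and the `ranks` semantics come from `ps.py` in this diff, and `RANK`/`WORLD_SIZE` are assumed to come from the launcher environment):

```python
from checkpoint_engine.ps import ParameterServer

def req_func(socket_paths: list[tuple[str, str]]) -> None:
    # tell every inference worker to call update_weights_from_ipc with its ZMQ handle
    ...

ps = ParameterServer(auto_pg=True)
# ... checkpoint registration and meta gathering happen here in the real flow
# (see examples/update.py) ...

# Broadcast: ranks is None or [] -> all ranks join the broadcast group.
ps.update("my-checkpoint", req_func)

# P2P: only the listed ranks (e.g. a newly started instance) receive weights via the
# mooncake transfer engine while the remaining instances keep serving.
ps.update("my-checkpoint", req_func, ranks=list(range(0, 16)))
```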
43
43
 
44
44
  ### Optimized Weight Broadcast
45
45
  In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -60,16 +60,22 @@ It then executes the transfer, where it controls the inference engine through a
60
60
 
61
61
  Pipelining naturally requires more GPU memory. When memory is not enough, checkpoint-engine will fallback to serial execution.
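A rough sketch of the double-buffering this paragraph refers to, mirroring the staging tensor visible in `_update_per_bucket` (the bucket size is an assumed value; when `_detect_bucket_size` finds too little free memory the pipeline degrades to serial execution as described):

```python
import torch

bucket_size = 1 << 30  # assumed 1 GiB bucket, for illustration only
# Two halves of one staging tensor: while the inference engine is still copying
# bucket g out of half g % 2, bucket g + 1 can already be broadcast into the
# other half. This is the extra GPU memory the paragraph above refers to.
staging = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")

def bucket_view(gidx: int) -> torch.Tensor:
    start = gidx % 2 * bucket_size
    return staging[start : start + bucket_size]
```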
62
62
 
63
+ ### Optimized P2P Bucket Assignment
64
+ In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
65
+ To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
66
+ The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
67
+ See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25).
68
+
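A hedged illustration of the inputs this optimization works with, matching `_assign_receiver_ranks` in `ps.py` (the NIC names, IPs and rank sets are made up): both topologies map `"<rdma_device>@<host_ip>"` to the ranks behind that NIC, and the smallest rank of each local group becomes a receiver so that concurrent buckets land on different NICs.

```python
# made-up topology: one new 4-GPU instance (local) pulling from one existing instance (remote)
local_topo = {"mlx5_0@10.0.0.2": {0, 1}, "mlx5_1@10.0.0.2": {2, 3}}
remote_topo = {"mlx5_0@10.0.0.1": {0, 1}, "mlx5_1@10.0.0.1": {2, 3}}

# one receiver per local NIC group, as in _assign_receiver_ranks
receivers = [min(ranks) for ranks in local_topo.values()]   # -> [0, 2]
```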
63
69
  ## Benchmark
64
70
 
65
71
  | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
66
72
  | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
67
- | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.17s | 3.94s (1.42GiB) | 8.83s (4.77GiB) |
68
- | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.46s | 6.75s (2.69GiB) | 16.47s (4.05GiB) |
69
- | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.44s | 12.22s (2.38GiB) | 25.77s (3.61GiB) |
70
- | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.81s | 15.45s (2.93GiB) | 36.24s (4.46GiB) |
71
- | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 1.40s | 13.88s (2.54GiB) | 33.30s (3.86 GiB) |
72
- | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.88s | 21.50s (2.99GiB) | 34.49s (4.57 GiB) |
73
+ | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
74
+ | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
75
+ | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
76
+ | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
77
+ | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
78
+ | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
73
79
 
74
80
  All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
75
81
 
@@ -77,6 +83,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
77
83
  * Device Info: we tested various combination of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
78
84
  * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
79
85
  * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
86
+ * We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
80
87
 
81
88
  ## Installation
82
89
 
@@ -92,7 +99,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
92
99
  pip install 'checkpoint-engine[p2p]'
93
100
  ```
94
101
 
95
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
102
+ If the `NCCL_IB_HCA` env variable is set, checkpoint-engine will use it to automatically select net devices for different ranks. Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If it is not set, checkpoint-engine will read all RDMA devices and try to divide them among the ranks.
96
103
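As a hedged illustration (device names below are made up), this is how the `_parse_NCCL_IB_HCA` helper added in `ps.py` resolves a few patterns:

```python
from checkpoint_engine.ps import _parse_NCCL_IB_HCA

available = ["mlx5_0", "mlx5_1", "mlx5_2", "rocep8s0"]

_parse_NCCL_IB_HCA("mlx5", available)            # prefix match   -> ["mlx5_0", "mlx5_1", "mlx5_2"]
_parse_NCCL_IB_HCA("=mlx5_0,mlx5_1", available)  # exact names    -> ["mlx5_0", "mlx5_1"]
_parse_NCCL_IB_HCA("^mlx5", available)           # exclude prefix -> ["rocep8s0"]
_parse_NCCL_IB_HCA("", available)                # unset / empty  -> all devices (capped at 32)
```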
 
97
104
  ## Getting Started
98
105
 
@@ -162,14 +169,64 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM pr
162
169
  Run a simple correctness test for checkpoint_engine
163
170
 
164
171
  ```bash
165
- torchrun --nproc-per-node 8 tests/test_update.py
172
+ pytest tests/test_update.py
173
+ ```
174
+
175
+ `test_update.py` is only designed to run with `pytest`. Please don't run it directly with `torchrun`.
176
+
177
+ The other unit tests can also be run with `pytest`. Only `test_update.py` requires GPUs; the rest can run on CPUs. To run only the CPU tests, use:
178
+
179
+ ```bash
180
+ pytest tests/ -m "not gpu"
181
+ ```
182
+
183
+ ## SGLang Integration
184
+
185
+ Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
186
+
187
+ ### Quick Start
188
+
189
+ **1. Install checkpoint-engine:**
190
+ ```bash
191
+ pip install 'checkpoint-engine[p2p]'
192
+ ```
193
+
194
+ **2. Launch SGLang server:**
195
+ ```bash
196
+ python -m sglang.launch_server \
197
+ --model-path $MODEL_PATH \
198
+ --tp 8 \
199
+ --load-format dummy \
200
+ --wait-for-initial-weights
201
+ ```
202
+
203
+ **3. Run checkpoint engine:**
204
+ ```bash
205
+ python -m sglang.srt.checkpoint_engine.update \
206
+ --update-method broadcast \
207
+ --checkpoint-path $MODEL_PATH \
208
+ --inference-parallel-size 8
166
209
  ```
167
210
 
211
+ ### Multi-Node Setup
212
+
213
+ For a 2-node setup, run the same commands on both nodes with the appropriate `--host` and distributed training parameters.
214
+
215
+ ### Key Options
216
+
217
+ **SGLang Server:**
218
+ - `--wait-for-initial-weights`: Wait for checkpoint engine before becoming ready
219
+ - `--load-format dummy`: Enable overlapping initialization tasks
220
+
221
+ **Checkpoint Engine:**
222
+ - `--update-method`: Choose `broadcast`, `p2p`, or `all`
223
+ - `--inference-parallel-size`: Number of parallel processes
224
+ - `--checkpoint-path`: Model checkpoint directory
225
+
168
226
  ## Limitations and Future Work
169
227
 
170
- - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
228
+ - This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
171
229
  - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
172
- - The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
173
230
 
174
231
  ## Acknowledgments
175
232
 
checkpoint_engine-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+ checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+ checkpoint_engine/_version.py,sha256=vYqoJTG51NOUmYyL0xt8asRK8vUT4lGAdal_EZ59mvw,704
+ checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+ checkpoint_engine/ps.py,sha256=Mgfr_MXYYZ_6JKqD5kIGDBWCYNWAtIPfAppP_cFu604,57781
+ checkpoint_engine/worker.py,sha256=5TzDgTPew6Ts9sMOzecalLCR1p_ZwfeKPdzzr68kAQg,6564
+ checkpoint_engine-0.2.1.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+ checkpoint_engine-0.2.1.dist-info/METADATA,sha256=7E7NhehWHpS6QVkLup-oCm350wdbiZX8CY3jmmJP0bU,11315
+ checkpoint_engine-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ checkpoint_engine-0.2.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+ checkpoint_engine-0.2.1.dist-info/RECORD,,
checkpoint_engine-0.1.3.dist-info/RECORD DELETED
@@ -1,9 +0,0 @@
- checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
- checkpoint_engine/_version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
- checkpoint_engine/ps.py,sha256=9dXRXi0QDPoRYrgGKAYvdmDFBXejgusjR0ltbii5_B0,49134
- checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
- checkpoint_engine-0.1.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
- checkpoint_engine-0.1.3.dist-info/METADATA,sha256=y96dMjEOKWaO_PA0h5BX8G3Ku7Tt1jCU09uIf8iYgic,9322
- checkpoint_engine-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- checkpoint_engine-0.1.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
- checkpoint_engine-0.1.3.dist-info/RECORD,,