checkpoint-engine 0.2.0-py3-none-any.whl → 0.2.1-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.0'
32
- __version_tuple__ = version_tuple = (0, 2, 0)
31
+ __version__ = version = '0.2.1'
32
+ __version_tuple__ = version_tuple = (0, 2, 1)
33
33
 
34
34
  __commit_id__ = commit_id = None
checkpoint_engine/device_utils.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import re
3
+ import socket
4
+ import subprocess
5
+ from functools import lru_cache
6
+
7
+ import torch
8
+ from loguru import logger
9
+
10
+
11
+ @lru_cache(maxsize=1)
12
+ def get_ip() -> str:
13
+ try:
14
+ # try to get ip from network interface
15
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
16
+ s.connect(("8.8.8.8", 80))
17
+ return s.getsockname()[0]
18
+ except Exception as e: # noqa: BLE001
19
+ # fallback to get ip from hostname
20
+ logger.warning(
21
+ f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
22
+ )
23
+ return socket.gethostbyname(socket.gethostname())
24
+
25
+
26
+ def npu_generate_uuid() -> str:
27
+ str_pid = str(os.getpid())
28
+ npu_num = 8
29
+ try:
30
+ for npu_id in range(npu_num):
31
+ cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
32
+ result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603
33
+ str_result = str(result.stdout)
34
+ if str_pid in str_result:
35
+ # In A3 server, one NPU has two chips.
36
+ match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
37
+ chip_count = int(match_chip_count.group(1))
38
+ search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
39
+ match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
40
+ chip_id = int(match_chip_id.group(1))
41
+ return f"{get_ip()}-{npu_id * chip_count + chip_id}"
42
+ raise ValueError("The current process is not running on the npu device")
43
+ except subprocess.CalledProcessError as e:
44
+ raise ValueError("The current process is not running on the npu device") from e
45
+
46
+
47
+ class DeviceManager:
48
+ def __init__(self):
49
+ self.device_type = self._detect_device_type()
50
+ self._setup_device_module()
51
+
52
+ def _is_torch_npu_available(self) -> bool:
53
+ try:
54
+ if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
55
+ return torch.npu.is_available()
56
+ else:
57
+ return False
58
+ except ImportError:
59
+ return False
60
+
61
+ def _detect_device_type(self) -> str:
62
+ if self._is_torch_npu_available():
63
+ return "npu"
64
+ elif torch.cuda.is_available():
65
+ return "cuda"
66
+ else:
67
+ raise TypeError("The current device type is not supported")
68
+
69
+ def _setup_device_module(self):
70
+ if self.device_type == "npu":
71
+ import torch_npu
72
+
73
+ self.device_module = torch_npu.npu
74
+ elif self.device_type == "cuda":
75
+ self.device_module = torch.cuda
76
+ else:
77
+ raise TypeError("The current device type is not supported")
78
+
79
+ @property
80
+ def backend(self) -> str:
81
+ if self.device_type == "npu":
82
+ return "hccl"
83
+ elif self.device_type == "cuda":
84
+ return "nccl"
85
+ else:
86
+ raise TypeError("The current device type is not supported")
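The new `device_utils` module above lets the rest of the package stay backend-agnostic: `DeviceManager` picks `torch.cuda` or `torch_npu.npu` and the matching collective backend. A minimal consumption sketch (not part of the wheel; assumes a `torchrun`-style environment with `LOCAL_RANK`, `MASTER_ADDR`, etc. already set):

```python
import os

import torch.distributed as dist

from checkpoint_engine.device_utils import DeviceManager, get_ip

local_rank = int(os.environ.get("LOCAL_RANK", "0"))

dm = DeviceManager()                         # raises TypeError if neither CUDA nor NPU is usable
dm.device_module.set_device(local_rank)      # torch.cuda on GPU hosts, torch_npu.npu on Ascend hosts
dist.init_process_group(backend=dm.backend)  # "nccl" on CUDA, "hccl" on NPU
print(f"{get_ip()} initialized a {dm.device_type} process group")
```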
checkpoint_engine/ps.py CHANGED
@@ -4,13 +4,11 @@ import ctypes
4
4
  import os
5
5
  import pickle
6
6
  import random
7
- import socket
8
7
  import threading
9
8
  import time
10
9
  from collections import defaultdict
11
10
  from collections.abc import Callable
12
11
  from datetime import timedelta
13
- from functools import lru_cache
14
12
  from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
15
13
 
16
14
  import httpx
@@ -23,6 +21,8 @@ from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
23
21
  from safetensors.torch import safe_open
24
22
  from torch.multiprocessing.reductions import reduce_tensor
25
23
 
24
+ from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
25
+
26
26
 
27
27
  if TYPE_CHECKING:
28
28
  from typing import TypeVar
@@ -254,28 +254,16 @@ def _concat_tp_weights(
254
254
  return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
255
255
 
256
256
 
257
- def _get_physical_gpu_id(device_index: int | None = None) -> str:
257
+ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
258
258
  try:
259
- return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
259
+ if device_manager.device_type == "npu":
260
+ return f"NPU-{npu_generate_uuid()}"
261
+ else:
262
+ return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
260
263
  except AssertionError as e:
261
264
  raise ValueError(f"fail to get physical gpu id {device_index}") from e
262
265
 
263
266
 
264
- @lru_cache(maxsize=1)
265
- def _get_ip() -> str:
266
- try:
267
- # try to get ip from network interface
268
- with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
269
- s.connect(("8.8.8.8", 80))
270
- return s.getsockname()[0]
271
- except Exception as e: # noqa: BLE001
272
- # fallback to get ip from hostname
273
- logger.warning(
274
- f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
275
- )
276
- return socket.gethostbyname(socket.gethostname())
277
-
278
-
279
267
  def _ibv_get_device_list() -> list[str]:
280
268
  lib = ctypes.CDLL("libibverbs.so.1")
281
269
  lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
@@ -317,13 +305,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
317
305
  """
318
306
  if not devices:
319
307
  raise RuntimeError("no rdma devices found")
320
- assert len(devices) <= gpu_count, (
321
- f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
322
- )
323
- assert gpu_count % len(devices) == 0, (
324
- f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
325
- )
326
- return devices[local_rank // (gpu_count // len(devices))]
308
+ try:
309
+ assert len(devices) <= gpu_count, (
310
+ f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
311
+ )
312
+ assert gpu_count % len(devices) == 0, (
313
+ f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
314
+ )
315
+ return devices[local_rank // (gpu_count // len(devices))]
316
+ except AssertionError:
317
+ logger.error(
318
+ "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
319
+ "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
320
+ "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
321
+ )
322
+ raise
327
323
 
328
324
 
329
325
  def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
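The selection rule `devices[local_rank // (gpu_count // len(devices))]` above assigns each RDMA device to a contiguous block of local ranks, which is why the two assertions are needed. A tiny worked example (the `mlx5_*` names are made up):

```python
devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"]  # 4 HCAs (hypothetical names)
gpu_count = 8                                       # 8 local GPUs -> 2 ranks per HCA

for local_rank in range(gpu_count):
    hca = devices[local_rank // (gpu_count // len(devices))]
    print(local_rank, hca)
# ranks 0-1 -> mlx5_0, 2-3 -> mlx5_1, 4-5 -> mlx5_2, 6-7 -> mlx5_3
```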
@@ -677,20 +673,29 @@ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, i
677
673
 
678
674
 
679
675
  class P2PStore:
680
- def __init__(self):
676
+ def __init__(self, device_manager: DeviceManager):
681
677
  from mooncake.engine import TransferEngine
682
678
 
683
679
  self.rank = int(os.getenv("RANK"))
684
- gpu_count = torch.cuda.device_count()
680
+ gpu_count = device_manager.device_module.device_count()
685
681
  local_rank = self.rank % gpu_count
686
- self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
687
- self.ip = _get_ip()
682
+ device_type = device_manager.device_type
683
+ if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
684
+ self.device = ""
685
+ else:
686
+ self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
687
+ self.ip = get_ip()
688
688
 
689
689
  # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
690
690
  retry_count = 8
691
691
  for i in range(retry_count):
692
692
  self.engine = TransferEngine()
693
- ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", self.device)
693
+ ret = self.engine.initialize(
694
+ self.ip,
695
+ "P2PHANDSHAKE",
696
+ "ascend_direct" if device_type == "npu" else "rdma",
697
+ self.device,
698
+ )
694
699
  if ret == 0:
695
700
  break
696
701
  # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
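The retry loop in `P2PStore.__init__` above re-creates the transfer engine and sleeps a random 0.5–2.0 s between attempts to dodge port conflicts when several parameter-server processes start at once. A generic sketch of that pattern (not the package's code; `try_init` is a hypothetical callable returning the engine's status code):

```python
import random
import time


def init_with_retries(try_init, retry_count: int = 8) -> None:
    """Retry `try_init` with randomized backoff so two processes that collide
    once are unlikely to collide again on the next attempt."""
    for _ in range(retry_count):
        if try_init() == 0:  # 0 means success, mirroring TransferEngine.initialize
            return
        time.sleep(random.uniform(0.5, 2.0))
    raise RuntimeError(f"initialization failed after {retry_count} retries")
```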
@@ -757,11 +762,12 @@ class ParameterServer:
757
762
  Args:
758
763
  auto_pg: Whether to automatically initialize the process group.
759
764
  Notice that if auto_pg is True, will destroy the process group after update.
760
- mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
765
+ mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
761
766
  """
762
767
  self._rank = rank or int(os.environ.get("RANK", None))
763
768
  self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
764
- self._gpu_count = gpu_count or torch.cuda.device_count()
769
+ self.device_manager = DeviceManager()
770
+ self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
765
771
  self._local_rank = self._rank % self._gpu_count
766
772
  self._auto_pg = auto_pg
767
773
  self._all_hosts = []
@@ -775,7 +781,7 @@ class ParameterServer:
775
781
  assert (
776
782
  self._gpu_count is not None
777
783
  and self._gpu_count > 0
778
- and self._gpu_count <= torch.cuda.device_count()
784
+ and self._gpu_count <= self.device_manager.device_module.device_count()
779
785
  ), self._gpu_count
780
786
  assert (
781
787
  self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -787,15 +793,16 @@ class ParameterServer:
787
793
  self._memory_pool: dict[str, list[MemoryBuffer]] = {}
788
794
  # dict key is owner_rank, value is a bucket metas list in owner_rank
789
795
  self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
796
+ # NPU transfer engine initialization requires prior set_device.
797
+ device_index = self._local_rank
798
+ self.device_manager.device_module.set_device(device_index)
790
799
  try:
791
- self._p2p_store = P2PStore()
800
+ self._p2p_store = P2PStore(self.device_manager)
792
801
  except ImportError as e:
793
802
  logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
794
803
  self._p2p_store = None
795
804
 
796
- device_index = self._local_rank
797
- torch.cuda.set_device(device_index)
798
- self._device_uuid = _get_physical_gpu_id(device_index)
805
+ self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
799
806
  self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
800
807
 
801
808
  def _logger_rank0(self, msg: str):
@@ -885,13 +892,15 @@ class ParameterServer:
885
892
  for x in self._memory_pool.get(checkpoint_name, [])
886
893
  ],
887
894
  p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
888
- host_ip=_get_ip(),
895
+ host_ip=get_ip(),
889
896
  device_uuid=self._device_uuid,
890
897
  rdma_device=self._rdma_device or "",
891
898
  )
892
899
 
893
900
  dist.all_gather_object(metas_lst, metas)
894
901
 
902
+ self._current_global_parameter_metas = {}
903
+
895
904
  num_parameters = 0
896
905
  all_hosts: list[str] = []
897
906
  global_device_uuids: list[str] = []
@@ -948,7 +957,7 @@ class ParameterServer:
948
957
  is_master=self._rank == 0,
949
958
  )
950
959
  dist.init_process_group(
951
- backend="nccl",
960
+ backend=self.device_manager.backend,
952
961
  world_size=self._world_size,
953
962
  rank=self._rank,
954
963
  timeout=timeout,
@@ -991,21 +1000,22 @@ class ParameterServer:
991
1000
  if self._rank not in ranks:
992
1001
  return
993
1002
  self._update_per_bucket(checkpoint_name, req_func, ranks)
994
- if self._auto_pg:
995
- dist.destroy_process_group()
996
-
997
- torch.cuda.empty_cache()
998
1003
 
999
- logger.info(
1000
- f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
1001
- f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
1002
- f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
1003
- )
1004
1004
  except Exception as e:
1005
1005
  logger.exception(
1006
1006
  f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
1007
1007
  )
1008
1008
  raise
1009
+ finally:
1010
+ if self._auto_pg and (not ranks or self._rank in ranks):
1011
+ dist.destroy_process_group()
1012
+
1013
+ self.device_manager.device_module.empty_cache()
1014
+ logger.info(
1015
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
1016
+ f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
1017
+ f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
1018
+ )
1009
1019
 
1010
1020
  def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
1011
1021
  def zmq_handle(device_uuid: str) -> str:
@@ -1022,14 +1032,16 @@ class ParameterServer:
1022
1032
  # auto detect bucket size
1023
1033
  tensor = torch.tensor(
1024
1034
  [
1025
- # proportion of current cuda free memory bytes
1026
- int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
1035
+ # proportion of current device free memory bytes
1036
+ int(
1037
+ float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
1038
+ ),
1027
1039
  # we use negative value to reuse allreduce min operation
1028
1040
  # for getting the max value of zmq_addr_counter in all ranks
1029
1041
  -self._zmq_addr_counter,
1030
1042
  ],
1031
1043
  dtype=torch.int64,
1032
- device="cuda",
1044
+ device=self.device_manager.device_type,
1033
1045
  )
1034
1046
  dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
1035
1047
  tensor = tensor.cpu()
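The bucket-size detection above folds two reductions into a single `MIN` all-reduce by negating the counter, since `min(-x) == -max(x)`. A plain-Python illustration of the identity (values are made up):

```python
free_bytes = [900, 750, 820]   # per-rank free-memory estimates (made up)
counters = [3, 7, 5]           # per-rank _zmq_addr_counter values (made up)

# Element-wise MIN over [free, -counter] across "ranks":
reduced = [min(free_bytes), min(-c for c in counters)]

assert reduced[0] == 750       # smallest free memory wins -> safe bucket size everywhere
assert -reduced[1] == 7        # negating back recovers the maximum counter
```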
@@ -1092,7 +1104,7 @@ class ParameterServer:
1092
1104
  assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
1093
1105
  if owner_rank is not None:
1094
1106
  self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
1095
- torch.cuda.synchronize()
1107
+ self.device_manager.device_module.synchronize()
1096
1108
 
1097
1109
  def init_process_group_for_ranks(
1098
1110
  self,
@@ -1132,7 +1144,11 @@ class ParameterServer:
1132
1144
  master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
1133
1145
  )
1134
1146
  dist.init_process_group(
1135
- backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
1147
+ backend=self.device_manager.backend,
1148
+ world_size=len(ranks),
1149
+ rank=rank,
1150
+ timeout=timeout,
1151
+ store=store,
1136
1152
  )
1137
1153
 
1138
1154
  def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
@@ -1184,7 +1200,7 @@ class ParameterServer:
1184
1200
 
1185
1201
  if not need_update:
1186
1202
  return
1187
- # first execute a barrier to avoid subsequent cuda oom
1203
+ # first execute a barrier to avoid subsequent device oom
1188
1204
  dist.barrier()
1189
1205
 
1190
1206
  bucket_size, disable_h2d_buffer = self._detect_bucket_size()
@@ -1199,7 +1215,7 @@ class ParameterServer:
1199
1215
  h2d_buffer: torch.Tensor | None = (
1200
1216
  None
1201
1217
  if disable_h2d_buffer
1202
- else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
1218
+ else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
1203
1219
  )
1204
1220
  # p2p store need to register h2d_buffer to let other ranks read
1205
1221
  if ranks:
@@ -1212,7 +1228,9 @@ class ParameterServer:
1212
1228
  continue
1213
1229
  receiver_rank_buckets.append((owner_rank, bucket))
1214
1230
 
1215
- buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
1231
+ buffer = torch.empty(
1232
+ bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
1233
+ )
1216
1234
  handle = reduce_tensor(buffer)
1217
1235
 
1218
1236
  buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -1231,52 +1249,66 @@ class ParameterServer:
1231
1249
  socket.send_pyobj(handle)
1232
1250
 
1233
1251
  gidx = 0
1252
+ ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
1234
1253
  bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
1235
- for i in range(max_len):
1236
- if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
1237
- self._copy_to_buffer(
1238
- checkpoint_name,
1239
- receiver_rank_buckets[i][1],
1240
- h2d_buffer,
1241
- receiver_rank_buckets[i][0] if ranks else None,
1242
- )
1243
- for receiver_rank, _buckets in buckets_by_receiver_rank.items():
1244
- if i >= len(_buckets):
1245
- continue
1246
- bucket = _buckets[i]
1247
- alloc, reserved = (
1248
- torch.cuda.memory_allocated() / 1024 / 1024,
1249
- torch.cuda.memory_reserved() / 1024 / 1024,
1250
- )
1251
- self._logger_rank0(
1252
- f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1253
- f"Current CUDA allocated {alloc:.2f} MB, "
1254
- f"reserved {reserved:.2f} MB."
1255
- )
1256
- start = gidx % 2 * bucket_size
1257
- buffer_b: torch.Tensor = buffer[start : start + bucket.size]
1258
- if receiver_rank == self._rank:
1259
- if disable_h2d_buffer:
1260
- self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1261
- else:
1262
- buffer_b.data.copy_(h2d_buffer[: bucket.size])
1263
- brank = bcast_rank_map[receiver_rank]
1264
- dist.broadcast(buffer_b, src=brank)
1265
- socket.recv()
1266
- dist.barrier()
1267
- socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1268
- gidx += 1
1269
-
1270
- socket.recv()
1271
- socket.send_pyobj(None)
1272
- socket.recv()
1273
- req_thread.join()
1274
- dist.barrier()
1275
- socket.close()
1276
- if ranks and h2d_buffer is not None:
1277
- self._p2p_store.unregister_named_tensors([h2d_buffer_name])
1278
-
1279
- torch.cuda.empty_cache()
1254
+ try:
1255
+ for i in range(max_len):
1256
+ if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
1257
+ self._copy_to_buffer(
1258
+ checkpoint_name,
1259
+ receiver_rank_buckets[i][1],
1260
+ h2d_buffer,
1261
+ receiver_rank_buckets[i][0] if ranks else None,
1262
+ )
1263
+ for receiver_rank, _buckets in buckets_by_receiver_rank.items():
1264
+ if i >= len(_buckets):
1265
+ continue
1266
+ bucket = _buckets[i]
1267
+ alloc, reserved = (
1268
+ self.device_manager.device_module.memory_allocated() / 1024 / 1024,
1269
+ self.device_manager.device_module.memory_reserved() / 1024 / 1024,
1270
+ )
1271
+ self._logger_rank0(
1272
+ f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1273
+ f"Current device allocated {alloc:.2f} MB, "
1274
+ f"reserved {reserved:.2f} MB."
1275
+ )
1276
+ start = gidx % 2 * bucket_size
1277
+ buffer_b: torch.Tensor = buffer[start : start + bucket.size]
1278
+ if receiver_rank == self._rank:
1279
+ if disable_h2d_buffer:
1280
+ self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1281
+ else:
1282
+ buffer_b.data.copy_(h2d_buffer[: bucket.size])
1283
+ brank = bcast_rank_map[receiver_rank]
1284
+ dist.broadcast(buffer_b, src=brank)
1285
+ resp = socket.recv()
1286
+ if resp != b"":
1287
+ exception_obj = pickle.loads(resp)
1288
+ logger.error(
1289
+ f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
1290
+ )
1291
+ ret_code.fill_(1)
1292
+ dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
1293
+ self.device_manager.device_module.synchronize()
1294
+ if ret_code.item() != 0:
1295
+ # quit early if any rank failed
1296
+ socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
1297
+ raise RuntimeError("Failed to update weights due to remote errors")
1298
+ socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1299
+ gidx += 1
1300
+
1301
+ socket.recv()
1302
+ socket.send_pyobj(None)
1303
+ socket.recv()
1304
+ finally:
1305
+ req_thread.join()
1306
+ dist.barrier()
1307
+ socket.close()
1308
+ if ranks and h2d_buffer is not None:
1309
+ self._p2p_store.unregister_named_tensors([h2d_buffer_name])
1310
+
1311
+ self.device_manager.device_module.empty_cache()
1280
1312
 
1281
1313
 
1282
1314
  def _init_api(ps: ParameterServer) -> Any:
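The `ret_code` handling added above converts a single rank's failure into a collective decision: each rank contributes 0 or 1, a `SUM` all-reduce makes the total visible everywhere, and all ranks then abort together instead of hanging in a later collective. A stripped-down sketch of the same idea (assumes an initialized process group; `local_failure` stands in for the real error check):

```python
import torch
import torch.distributed as dist


def any_rank_failed(local_failure: bool, device: str) -> bool:
    """Return True on every rank if at least one rank reports a failure."""
    flag = torch.zeros((), dtype=torch.int64, device=device)
    if local_failure:
        flag.fill_(1)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)  # total failures across all ranks
    return flag.item() != 0
```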
@@ -1294,6 +1326,7 @@ def _init_api(ps: ParameterServer) -> Any:
1294
1326
  update_url: str | None = None
1295
1327
  inference_group_ranks: list[int] = []
1296
1328
  timeout: float = 300.0
1329
+ uds: str | None = None
1297
1330
 
1298
1331
  def wrap_exception(func: Callable[[], None]) -> Response:
1299
1332
  try:
@@ -1326,7 +1359,9 @@ def _init_api(ps: ParameterServer) -> Any:
1326
1359
  return
1327
1360
  if req.inference_group_ranks:
1328
1361
  socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
1329
- request_inference_to_update(req.update_url, dict(socket_paths), timeout=req.timeout)
1362
+ request_inference_to_update(
1363
+ req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
1364
+ )
1330
1365
 
1331
1366
  return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
1332
1367
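The new `uds` field threaded into the update request lets `request_inference_to_update` target an inference server listening on a Unix domain socket rather than TCP. For reference, this is how a UDS target is usually addressed with `httpx` (a general sketch, not the package's helper; the socket path and endpoint are hypothetical):

```python
import httpx

# Route HTTP requests over a Unix domain socket instead of host:port.
transport = httpx.HTTPTransport(uds="/tmp/inference.sock")  # hypothetical socket path
with httpx.Client(transport=transport) as client:
    resp = client.post("http://localhost/update_weights", json={})  # hypothetical endpoint
    resp.raise_for_status()
```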
 
checkpoint_engine/worker.py CHANGED
@@ -5,6 +5,8 @@ from typing import TypedDict
5
5
  import torch
6
6
  import zmq
7
7
 
8
+ from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
9
+
8
10
 
9
11
  def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
10
12
  func, args = handle
@@ -53,51 +55,105 @@ def update_weights_from_ipc(
53
55
  socket = zmq_ctx.socket(zmq.REP)
54
56
  socket.connect(zmq_handle)
55
57
  buffer: torch.Tensor | None = None
56
- while True:
57
- payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
58
- if payload is None:
59
- # means the update is done
60
- if post_hook is not None:
61
- post_hook()
62
- torch.cuda.synchronize()
63
- socket.send(b"")
64
- break
65
- if isinstance(payload, tuple):
66
- # an ipc handle that vLLM can use `func, args = handle`
67
- # and `func(*args)` to rebuild GPU tensor.
68
- buffer = _rebuild_ipc(payload, device_id)
69
- assert buffer.dtype == torch.uint8
70
- socket.send(b"")
71
- continue
72
- assert isinstance(payload, list)
73
- run(_extract_weights(payload, buffer))
74
- torch.cuda.synchronize()
58
+ device_manager = DeviceManager()
59
+ try:
60
+ ipc_handle: tuple[Callable, tuple] = socket.recv_pyobj()
61
+ assert isinstance(ipc_handle, tuple)
62
+ buffer = _rebuild_ipc(ipc_handle, device_id)
63
+ assert buffer.dtype == torch.uint8
75
64
  socket.send(b"")
65
+ except Exception as e:
66
+ socket.send_pyobj(e)
67
+ socket.recv() # wait for ack
68
+ raise
69
+ try:
70
+ while True:
71
+ payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
72
+ if payload is None: # done signal
73
+ if post_hook is not None:
74
+ post_hook()
75
+ device_manager.device_module.synchronize()
76
+ socket.send(b"")
77
+ break
78
+ if isinstance(payload, list): # still updating weights
79
+ try:
80
+ run(_extract_weights(payload, buffer))
81
+ device_manager.device_module.synchronize()
82
+ socket.send(b"")
83
+ except Exception as e: # noqa: BLE001
84
+ # Send exception back to Parameter Server.
85
+ # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
86
+ socket.send_pyobj(e)
87
+ elif isinstance(
88
+ payload, Exception
89
+ ): # error occurred, got force quit signal from Parameter Server
90
+ raise payload
91
+ else:
92
+ raise TypeError(f"Unexpected payload type: {type(payload)}")
76
93
 
77
- socket.close()
78
- del buffer
79
- gc.collect()
80
- torch.cuda.empty_cache()
94
+ finally:
95
+ socket.close()
96
+ del buffer
97
+ gc.collect()
98
+ device_manager.device_module.empty_cache()
81
99
 
82
100
 
83
101
  class VllmColocateWorkerExtension:
84
102
  """
85
- The class for vLLM's worker to inherit from, in the colocate setting.
86
- By defining an extension class, the code can work no matter what is
87
- the underlying worker class. This way, the code can be compatible
88
- with both vLLM V0 and V1.
89
- NOTE: we define this class in a separate module, and the main module
90
- should pass the full qualified name as `worker_extension_cls` argument.
103
+ Worker extension for vLLM to update weights from checkpoint-engine.
104
+
105
+ This class provides a worker extension mechanism that allows vLLM workers to receive
106
+ and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
107
+ The methods in this worker extension will be injected into the vLLM worker class and
108
+ are callable from the `collective_rpc` API, enabling seamless weight updates for both
109
+ vLLM V0 and V1 versions.
110
+
111
+ Note:
112
+ This class is defined in a separate module. The fully qualified name
113
+ `checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
114
+ `worker_extension_cls` argument when initializing the vLLM worker.
91
115
  """
92
116
 
93
117
  def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
118
+ """
119
+ Update model weights from checkpoint-engine via IPC communication.
120
+
121
+ This method establishes a ZMQ connection to the checkpoint-engine and receives
122
+ weight updates through a shared memory buffer. The update process includes:
123
+ 1. Receiving IPC handles to reconstruct shared memory tensors
124
+ 2. Extracting flattened metadata describing tensor weights in the shared memory tensor
125
+ 3. Loading weights into the model
126
+ 4. Post-processing weights after loading
127
+
128
+ Args:
129
+ zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
130
+ The device UUID is platform-specific:
131
+ - For CUDA: UUID from `current_platform.get_device_uuid()`
132
+ - For NPU: Format "NPU-{generated_uuid}"
133
+
134
+ Raises:
135
+ ValueError: If the device type is not supported (not CUDA or NPU).
136
+ AssertionError: If the device is not properly initialized.
137
+
138
+ Note:
139
+ This method is called by vLLM's collective RPC mechanism. The ZMQ context
140
+ is lazily initialized on first call and reused for subsequent updates.
141
+ """
94
142
  from vllm.model_executor.model_loader.utils import process_weights_after_loading
95
143
  from vllm.platforms import current_platform
96
144
 
145
+ # vllm-ascend does not initialize the device
146
+ if current_platform.device_type == "npu" and self.device is None:
147
+ self.device = torch.device(f"npu:{self.local_rank}")
97
148
  assert self.device is not None
98
149
  if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
99
150
  self._zmq_ctx = zmq.Context()
100
- device_uuid = current_platform.get_device_uuid(self.device.index)
151
+ if current_platform.device_type == "cuda":
152
+ device_uuid = current_platform.get_device_uuid(self.device.index)
153
+ elif current_platform.device_type == "npu":
154
+ device_uuid = f"NPU-{npu_generate_uuid()}"
155
+ else:
156
+ raise ValueError(f"Unsupported device type: {current_platform.device_type}")
101
157
  update_weights_from_ipc(
102
158
  self._zmq_ctx,
103
159
  zmq_handles[device_uuid],
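On the vLLM side, the extension above is attached through vLLM's `worker_extension_cls` mechanism and its methods are reached via `collective_rpc`, as in vLLM's RLHF examples; the sketch below rests on those assumptions, and the model path and `zmq_handles` contents are placeholders (in practice the handles come from checkpoint-engine's ParameterServer):

```python
from vllm import LLM

llm = LLM(
    model="/path/to/model",  # placeholder
    worker_extension_cls="checkpoint_engine.worker.VllmColocateWorkerExtension",
)

# device UUID -> ZMQ socket handle, produced and broadcast by the ParameterServer
zmq_handles: dict[str, str] = {}  # placeholder
llm.collective_rpc("update_weights_from_ipc", args=(zmq_handles,))
```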
checkpoint_engine-0.2.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -169,13 +169,63 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM pr
169
169
  Run a simple correctness test for checkpoint_engine
170
170
 
171
171
  ```bash
172
- torchrun --nproc-per-node 8 tests/test_update.py
172
+ pytest tests/test_update.py
173
173
  ```
174
174
 
175
- Other unit tests can be done with pytest.
175
+ `test_update.py` is designed to run only with `pytest`; please don't run it directly with `torchrun`.
176
+
177
+ Other unit tests can also be run with `pytest`. Only `test_update.py` requires GPUs; the remaining tests run on CPUs. To run only the CPU tests, use:
178
+
179
+ ```bash
180
+ pytest tests/ -m "not gpu"
181
+ ```
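The `-m "not gpu"` selection above is a standard pytest marker filter; the marker name follows directly from the command, though how the repository registers it is an assumption. A minimal sketch of a test that would be deselected:

```python
import pytest


@pytest.mark.gpu  # skipped by `pytest tests/ -m "not gpu"`
def test_update_on_gpus() -> None:
    ...
```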
182
+
183
+ ## SGLang Integration
184
+
185
+ Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
186
+
187
+ ### Quick Start
188
+
189
+ **1. Install checkpoint-engine:**
190
+ ```bash
191
+ pip install 'checkpoint-engine[p2p]'
192
+ ```
193
+
194
+ **2. Launch SGLang server:**
195
+ ```bash
196
+ python -m sglang.launch_server \
197
+ --model-path $MODEL_PATH \
198
+ --tp 8 \
199
+ --load-format dummy \
200
+ --wait-for-initial-weights
201
+ ```
202
+
203
+ **3. Run checkpoint engine:**
204
+ ```bash
205
+ python -m sglang.srt.checkpoint_engine.update \
206
+ --update-method broadcast \
207
+ --checkpoint-path $MODEL_PATH \
208
+ --inference-parallel-size 8
209
+ ```
210
+
211
+ ### Multi-Node Setup
212
+
213
+ For a 2-node setup, run the same commands on both nodes with the appropriate `--host` and distributed launch parameters.
214
+
215
+ ### Key Options
216
+
217
+ **SGLang Server:**
218
+ - `--wait-for-initial-weights`: Wait for the initial weights from the checkpoint engine before the server reports ready
219
+ - `--load-format dummy`: Initialize with dummy weights so server startup can overlap with the checkpoint-engine weight transfer
220
+
221
+ **Checkpoint Engine:**
222
+ - `--update-method`: Choose `broadcast`, `p2p`, or `all`
223
+ - `--inference-parallel-size`: Number of parallel processes
224
+ - `--checkpoint-path`: Model checkpoint directory
225
+
176
226
  ## Limitations and Future Work
177
227
 
178
- - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
228
+ - This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
179
229
  - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
180
230
 
181
231
  ## Acknowledgments
checkpoint_engine-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
1
+ checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
+ checkpoint_engine/_version.py,sha256=vYqoJTG51NOUmYyL0xt8asRK8vUT4lGAdal_EZ59mvw,704
3
+ checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
4
+ checkpoint_engine/ps.py,sha256=Mgfr_MXYYZ_6JKqD5kIGDBWCYNWAtIPfAppP_cFu604,57781
5
+ checkpoint_engine/worker.py,sha256=5TzDgTPew6Ts9sMOzecalLCR1p_ZwfeKPdzzr68kAQg,6564
6
+ checkpoint_engine-0.2.1.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
7
+ checkpoint_engine-0.2.1.dist-info/METADATA,sha256=7E7NhehWHpS6QVkLup-oCm350wdbiZX8CY3jmmJP0bU,11315
8
+ checkpoint_engine-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ checkpoint_engine-0.2.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
10
+ checkpoint_engine-0.2.1.dist-info/RECORD,,
checkpoint_engine-0.2.0.dist-info/RECORD REMOVED
@@ -1,9 +0,0 @@
1
- checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
- checkpoint_engine/_version.py,sha256=Dg8AmJomLVpjKL6prJylOONZAPRtB86LOce7dorQS_A,704
3
- checkpoint_engine/ps.py,sha256=OpGocqJv0TfGgVC1cPKARfz6qehfCLMzQ5KpDQNxb0o,55291
4
- checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
5
- checkpoint_engine-0.2.0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
6
- checkpoint_engine-0.2.0.dist-info/METADATA,sha256=tbAq45YlRvRAfQHDB0XV8w4ZP0zmVJ3RMTAx_wTm154,9896
7
- checkpoint_engine-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
- checkpoint_engine-0.2.0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
9
- checkpoint_engine-0.2.0.dist-info/RECORD,,