checkpoint-engine 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpoint_engine/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '0.2.0'
- __version_tuple__ = version_tuple = (0, 2, 0)
+ __version__ = version = '0.2.2'
+ __version_tuple__ = version_tuple = (0, 2, 2)
 
  __commit_id__ = commit_id = None
checkpoint_engine/device_utils.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import re
+ import socket
+ import subprocess
+ from functools import lru_cache
+
+ import torch
+ from loguru import logger
+
+
+ @lru_cache(maxsize=1)
+ def get_ip() -> str:
+     try:
+         # try to get ip from network interface
+         with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+             s.connect(("8.8.8.8", 80))
+             return s.getsockname()[0]
+     except Exception as e:  # noqa: BLE001
+         # fallback to get ip from hostname
+         logger.warning(
+             f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
+         )
+         return socket.gethostbyname(socket.gethostname())
+
+
+ def npu_generate_uuid() -> str:
+     str_pid = str(os.getpid())
+     npu_num = 8
+     try:
+         for npu_id in range(npu_num):
+             cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
+             result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
+             str_result = str(result.stdout)
+             if str_pid in str_result:
+                 # In A3 server, one NPU has two chips.
+                 match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
+                 chip_count = int(match_chip_count.group(1))
+                 search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
+                 match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
+                 chip_id = int(match_chip_id.group(1))
+                 return f"{get_ip()}-{npu_id * chip_count + chip_id}"
+         raise ValueError("The current process is not running on the npu device")
+     except subprocess.CalledProcessError as e:
+         raise ValueError("The current process is not running on the npu device") from e
+
+
+ class DeviceManager:
+     def __init__(self):
+         self.device_type = self._detect_device_type()
+         self._setup_device_module()
+
+     def _is_torch_npu_available(self) -> bool:
+         try:
+             if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
+                 return torch.npu.is_available()
+             else:
+                 return False
+         except ImportError:
+             return False
+
+     def _detect_device_type(self) -> str:
+         if self._is_torch_npu_available():
+             return "npu"
+         elif torch.cuda.is_available():
+             return "cuda"
+         else:
+             raise TypeError("The current device type is not supported")
+
+     def _setup_device_module(self):
+         if self.device_type == "npu":
+             import torch_npu
+
+             self.device_module = torch_npu.npu
+         elif self.device_type == "cuda":
+             self.device_module = torch.cuda
+         else:
+             raise TypeError("The current device type is not supported")
+
+     @property
+     def backend(self) -> str:
+         if self.device_type == "npu":
+             return "hccl"
+         elif self.device_type == "cuda":
+             return "nccl"
+         else:
+             raise TypeError("The current device type is not supported")
checkpoint_engine/ps.py CHANGED
@@ -1,16 +1,15 @@
  import argparse
  import concurrent.futures
  import ctypes
+ import json
  import os
  import pickle
  import random
- import socket
  import threading
  import time
  from collections import defaultdict
  from collections.abc import Callable
  from datetime import timedelta
- from functools import lru_cache
  from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
 
  import httpx
@@ -20,9 +19,11 @@ import torch.distributed as dist
  import zmq
  from loguru import logger
  from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
- from safetensors.torch import safe_open
+ from safetensors.torch import _getdtype, safe_open
  from torch.multiprocessing.reductions import reduce_tensor
 
+ from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
+
 
  if TYPE_CHECKING:
      from typing import TypeVar
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
      name: str
      dtype: _TorchDtype
      shape: _TorchSize
+     aligned_size: int
 
 
  class BucketRange(NamedTuple):
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
  def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
      ret = []
      for meta in metas:
-         size = _align_size(meta.dtype, meta.shape)
+         size = meta.aligned_size
          ret.append(
             {
                 "name": meta.name,
@@ -254,28 +256,16 @@ def _concat_tp_weights(
      return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
 
 
- def _get_physical_gpu_id(device_index: int | None = None) -> str:
+ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
      try:
-         return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
+         if device_manager.device_type == "npu":
+             return f"NPU-{npu_generate_uuid()}"
+         else:
+             return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
      except AssertionError as e:
          raise ValueError(f"fail to get physical gpu id {device_index}") from e
 
 
- @lru_cache(maxsize=1)
- def _get_ip() -> str:
-     try:
-         # try to get ip from network interface
-         with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
-             s.connect(("8.8.8.8", 80))
-             return s.getsockname()[0]
-     except Exception as e:  # noqa: BLE001
-         # fallback to get ip from hostname
-         logger.warning(
-             f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
-         )
-         return socket.gethostbyname(socket.gethostname())
-
-
  def _ibv_get_device_list() -> list[str]:
      lib = ctypes.CDLL("libibverbs.so.1")
      lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
@@ -317,13 +307,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
      """
      if not devices:
          raise RuntimeError("no rdma devices found")
-     assert len(devices) <= gpu_count, (
-         f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
-     )
-     assert gpu_count % len(devices) == 0, (
-         f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
-     )
-     return devices[local_rank // (gpu_count // len(devices))]
+     try:
+         assert len(devices) <= gpu_count, (
+             f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
+         )
+         assert gpu_count % len(devices) == 0, (
+             f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
+         )
+         return devices[local_rank // (gpu_count // len(devices))]
+     except AssertionError:
+         logger.error(
+             "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
+             "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
+             "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
+         )
+         raise
 
 
  def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
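`_get_my_rdma_device` above maps each local rank to one RDMA NIC by integer division, after checking that the device count divides the GPU count. A worked example of that mapping with hypothetical device names:

```python
# devices[local_rank // (gpu_count // len(devices))] with 8 GPUs and 4 RDMA devices
devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"]   # hypothetical NIC names
gpu_count = 8
ranks_per_device = gpu_count // len(devices)         # 2
for local_rank in range(gpu_count):
    print(local_rank, "->", devices[local_rank // ranks_per_device])
# local ranks 0-1 -> mlx5_0, 2-3 -> mlx5_1, 4-5 -> mlx5_2, 6-7 -> mlx5_3
```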
@@ -426,6 +424,7 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
                  name=parameter_name,
                  shape=meta["shape"],
                  dtype=meta["dtype"],
+                 aligned_size=_align_size(meta["dtype"], meta["shape"]),
              )
              tp_meta = tp_metas[parameter_name]
              if tp_meta.concat_dim != -1:
@@ -435,7 +434,10 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
          shape = list(parameter_metas[name].shape)
          shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
          parameter_metas[name] = ParameterMeta(
-             name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
+             name=name,
+             shape=torch.Size(shape),
+             dtype=parameter_metas[name].dtype,
+             aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
          )
          weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
          # TODO: here concat is serial, which may be slow
@@ -453,17 +455,85 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
      return parameters
 
 
- def _register_checkpoint(
-     *,
+ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+     def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
+         """
+         safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
+         We load the safetensors file as bytes, then parse the header manually to get parameter metas.
+         The actual tensor data is in the remaining bytes and is naturally aligned.
+         We pin the remaining bytes as the buffer, making pinning faster.
+         """
+
+         def _pin(t: torch.Tensor):
+             """
+             Pin the memory of tensor in-place.
+             See: https://github.com/pytorch/pytorch/issues/32167
+             """
+             cudart = torch.cuda.cudart()
+             r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+             assert r == 0, f"pin memory error, error code: {r}"
+
+         # TODO: should only support /dev/shm? but we found files in disk also work?
+         size = os.stat(file_path).st_size
+         flag_size = 8
+         t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+         assert t.nbytes > flag_size, (
+             f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
+         )
+         start_pos = (
+             int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+             + flag_size
+         )
+         header_tensor = t[flag_size:start_pos]
+         header = json.loads(header_tensor.numpy().tobytes())
+         if "__metadata__" in header:
+             header.pop("__metadata__")
+
+         metas: list[ParameterMeta] = []
+         offset = 0
+         try:
+             for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                 start, end = meta["data_offsets"]
+                 # safetensors format ensures offsets are aligned
+                 assert offset == start, f"offset {offset} should be equal to start {start}"
+                 metas.append(
+                     ParameterMeta(
+                         name=name,
+                         dtype=_getdtype(meta["dtype"]),
+                         shape=torch.Size(meta["shape"]),
+                         aligned_size=end - start,
+                     )
+                 )
+                 offset = end
+         except Exception as e:
+             logger.error(f"fail to parse safetensors header from {file_path}: {e}")
+             raise
+
+         buffer = t[start_pos:]
+         assert offset == buffer.nbytes, (
+             f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+         )
+         # Remove the file after successfully loading. This will avoid doubling the memory usage.
+         # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
+         os.remove(file_path)
+         _pin(buffer)
+         logger.info(
+             f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
+         )
+         return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+
+     memory_buffers: list[MemoryBuffer] = []
+     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+         memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
+     return memory_buffers
+
+
+ def _normal_pin_memory(
      files: list[str],
      named_tensors: dict[str, torch.Tensor],
      rank: int | None = None,
+     shared_pin_memory: list[MemoryBuffer] | None = None,
  ) -> list[MemoryBuffer]:
-     logger.info(
-         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-     )
-     if not files and not named_tensors:
-         return []
      parameters = _load_checkpoint(files)
      if named_tensors:
          parameters.update(named_tensors)
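`_parse_and_pin_from_safetensors` above relies on the safetensors on-disk layout: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype, shape, and `data_offsets`, then the raw tensor bytes. A standalone sketch of reading just that header (the file path in the usage comment is a hypothetical example):

```python
import json
import struct


def read_safetensors_header(path: str) -> dict:
    """Return the JSON header of a .safetensors file without loading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian uint64 length prefix
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)
    # each entry looks like {"dtype": "F16", "shape": [...], "data_offsets": [start, end]},
    # with offsets relative to the first byte after the header
    return header


# e.g. read_safetensors_header("/dev/shm/model-00001-of-00009.safetensors")
```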
@@ -473,13 +543,16 @@ def _register_checkpoint(
          size: int
          metas: list[ParameterMeta]
 
-     buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
+     buckets: list[MemoryBucket] = []
+     buckets.append(MemoryBucket(size=0, metas=[]))
      for name, tensor in sorted(parameters.items()):
          size = _align_size(tensor.dtype, tensor.shape)
          if buckets[-1].size + size > bucket_size:
              assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
              buckets.append(MemoryBucket(size=0, metas=[]))
-         buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
+         buckets[-1].metas.append(
+             ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
+         )
          buckets[-1].size += size
 
      memory_buffers = [
@@ -487,16 +560,34 @@ def _register_checkpoint(
          for bucket in buckets
      ]
 
-     def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
-         buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-         return idx, buffer
+     def register_pin_memory(
+         idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
+     ) -> tuple[int, torch.Tensor]:
+         if shared_pin_memory:
+             # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
+             # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
+             assert idx < len(shared_pin_memory), (
+                 f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
+             )
+             assert shared_pin_memory[idx].size == size, (
+                 f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
+             )
+             return idx, shared_pin_memory[idx].buffer
+         else:
+             buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
+             return idx, buffer
 
      def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
          buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
 
      with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
          futures = [
-             executor.submit(register_pin_memory, idx, bucket.size)
+             executor.submit(
+                 register_pin_memory,
+                 idx,
+                 bucket.size,
+                 shared_pin_memory,
+             )
              for idx, bucket in enumerate(buckets)
          ]
          new_futures = []
@@ -522,6 +613,39 @@ def _register_checkpoint(
              offset += size
      for future in concurrent.futures.as_completed(new_futures):
          future.result()
+     return memory_buffers
+
+
+ def _register_checkpoint(
+     *,
+     files: list[str],
+     named_tensors: dict[str, torch.Tensor],
+     rank: int | None = None,
+     shared_pin_memory: list[MemoryBuffer] | None = None,
+ ) -> list[MemoryBuffer]:
+     logger.info(
+         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+     )
+     if not files and not named_tensors:
+         return []
+     memory_buffers: list[MemoryBuffer] = []
+     files_to_inplace_pin = [
+         file
+         for file in files
+         if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+     ]
+     files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+     if files_to_normal_pin or named_tensors:
+         memory_buffers.extend(
+             _normal_pin_memory(
+                 files=files_to_normal_pin,
+                 named_tensors=named_tensors,
+                 rank=rank,
+                 shared_pin_memory=shared_pin_memory,
+             )
+         )
+     if files_to_inplace_pin:
+         memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
      return memory_buffers
 
 
@@ -570,7 +694,7 @@ def _gen_h2d_buckets(
      for idx, metas in enumerate(items.memory_buffer_metas_list):
          start_offset, offset = 0, 0
          for meta in metas.metas:
-             s = _align_size(meta.dtype, meta.shape)
+             s = meta.aligned_size
              if buckets[-1][1].size + s > bucket_size:
                  if offset - start_offset > 0:
                      buckets[-1][1].ranges.append(
@@ -677,20 +801,29 @@ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, i
 
 
  class P2PStore:
-     def __init__(self):
+     def __init__(self, device_manager: DeviceManager):
          from mooncake.engine import TransferEngine
 
          self.rank = int(os.getenv("RANK"))
-         gpu_count = torch.cuda.device_count()
+         gpu_count = device_manager.device_module.device_count()
          local_rank = self.rank % gpu_count
-         self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
-         self.ip = _get_ip()
+         device_type = device_manager.device_type
+         if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
+             self.device = ""
+         else:
+             self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
+         self.ip = get_ip()
 
          # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
          retry_count = 8
          for i in range(retry_count):
              self.engine = TransferEngine()
-             ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", self.device)
+             ret = self.engine.initialize(
+                 self.ip,
+                 "P2PHANDSHAKE",
+                 "ascend_direct" if device_type == "npu" else "rdma",
+                 self.device,
+             )
              if ret == 0:
                  break
              # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
@@ -742,6 +875,8 @@ class P2PStore:
 
 
  class ParameterServer:
+     shared_memory_pool_name = "__shared_memory_pool__"
+
      def __init__(
          self,
          *,
@@ -757,11 +892,12 @@ class ParameterServer:
          Args:
              auto_pg: Whether to automatically initialize the process group.
                  Notice that if auto_pg is True, will destroy the process group after update.
-             mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
+             mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
          """
          self._rank = rank or int(os.environ.get("RANK", None))
          self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-         self._gpu_count = gpu_count or torch.cuda.device_count()
+         self.device_manager = DeviceManager()
+         self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
          self._local_rank = self._rank % self._gpu_count
          self._auto_pg = auto_pg
          self._all_hosts = []
@@ -775,7 +911,7 @@ class ParameterServer:
          assert (
              self._gpu_count is not None
              and self._gpu_count > 0
-             and self._gpu_count <= torch.cuda.device_count()
+             and self._gpu_count <= self.device_manager.device_module.device_count()
          ), self._gpu_count
          assert (
              self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -784,20 +920,35 @@ class ParameterServer:
          self._zmq_ctx = zmq.Context()
          self._zmq_addr_counter = 0
 
+         # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
+         self._current_shared_memory_pool_user: str = ""
          self._memory_pool: dict[str, list[MemoryBuffer]] = {}
+         self._memory_pool[self.shared_memory_pool_name] = []
          # dict key is owner_rank, value is a bucket metas list in owner_rank
          self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
+         # NPU transfer engine initialization requires prior set_device.
+         device_index = self._local_rank
+         self.device_manager.device_module.set_device(device_index)
          try:
-             self._p2p_store = P2PStore()
+             self._p2p_store = P2PStore(self.device_manager)
          except ImportError as e:
              logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
              self._p2p_store = None
 
-         device_index = self._local_rank
-         torch.cuda.set_device(device_index)
-         self._device_uuid = _get_physical_gpu_id(device_index)
+         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
          self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+     def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
+         if checkpoint_name == self._current_shared_memory_pool_user:
+             assert self._memory_pool[self.shared_memory_pool_name], (
+                 f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
+             )
+             return self._memory_pool[self.shared_memory_pool_name]
+         elif checkpoint_name in self._memory_pool:
+             return self._memory_pool[checkpoint_name]
+         else:
+             raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
+
      def _logger_rank0(self, msg: str):
          if self._local_rank == 0:
              logger.info(msg)
@@ -821,46 +972,97 @@ class ParameterServer:
          *,
          files: list[str] | None = None,
          named_tensors: dict[str, torch.Tensor] | None = None,
+         use_shared_memory_pool: bool = False,
      ) -> None:
          """
          Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
+         Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+         Please make sure to copy the files to disks if you need to keep them.
 
          Args:
              checkpoint_name: The name of the checkpoint.
             files: The safetensors files to register.
             named_tensors: The named tensors to register.
+             use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
+                 Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
+                 cannot accommodate checkpoints with different memory requirements.
+                 To free the actual memory of the shared pool or to modify its shape,
+                 please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
          """
          try:
-             assert checkpoint_name not in self._memory_pool, (
-                 f"checkpoint {checkpoint_name} already registered"
-             )
-             self._memory_pool[checkpoint_name] = _register_checkpoint(
-                 files=files or [], named_tensors=named_tensors or {}, rank=self._rank
-             )
-             if self._p2p_store is not None:
-                 self._register_parameters_to_p2p_store(checkpoint_name)
+             if use_shared_memory_pool:
+                 logger.info(
+                     f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
+                 )
+                 assert self._current_shared_memory_pool_user == "", (
+                     f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
+                     f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
+                     f"This registration may cause unexpected conflicts."
+                 )
+                 # Since we set the uninitialized shared memory pool to empty list,
+                 # we can check whether this is the first time to use shared memory pool
+                 _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
+                 self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
+                     files=files or [],
+                     named_tensors=named_tensors or {},
+                     rank=self._rank,
+                     shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                 )
+                 self._current_shared_memory_pool_user = checkpoint_name
+                 if self._p2p_store is not None and _is_first_time:
+                     self._register_parameters_to_p2p_store(checkpoint_name)
+             else:
+                 assert checkpoint_name not in self._memory_pool, (
+                     f"checkpoint {checkpoint_name} already registered"
+                 )
+                 self._memory_pool[checkpoint_name] = _register_checkpoint(
+                     files=files or [], named_tensors=named_tensors or {}, rank=self._rank
+                 )
+                 if self._p2p_store is not None:
+                     self._register_parameters_to_p2p_store(checkpoint_name)
          except Exception:
              logger.exception(
                  f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
              )
-             if self._p2p_store is not None:
+             if self._p2p_store is not None and not use_shared_memory_pool:
                  self._unregister_parameters_from_p2p_store(checkpoint_name)
              self.unregister_checkpoint(checkpoint_name)
              raise
 
-     def unregister_checkpoint(self, checkpoint_name: str):
+     def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
          """
          Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
          from p2p store if p2p store is initialized.
+         Args:
+             checkpoint_name: The name of the checkpoint.
+             force: This flag is designed for shared memory pool user. If True, the memory for shared memory pool itself will be freed.
+                 If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
          """
-         if checkpoint_name not in self._memory_pool:
+         if (
+             checkpoint_name not in self._memory_pool
+             and checkpoint_name != self._current_shared_memory_pool_user
+         ):
+             logger.warning(
+                 f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
+             )
              return
+
+         if checkpoint_name == self._current_shared_memory_pool_user and not force:
+             self._current_shared_memory_pool_user = ""
+             return
+
          if self._p2p_store is not None:
              num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
              logger.info(
                  f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
              )
-         del self._memory_pool[checkpoint_name]
+
+         if checkpoint_name == self._current_shared_memory_pool_user:
+             self._current_shared_memory_pool_user = ""
+             del self._memory_pool[self.shared_memory_pool_name]
+             self._memory_pool[self.shared_memory_pool_name] = []
+         else:
+             del self._memory_pool[checkpoint_name]
          # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
          # this works by using torch>=2.5.0
          torch._C._host_emptyCache()
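The `use_shared_memory_pool` and `force` flags documented above are easiest to see as a sequence of calls. A hedged usage sketch for one rank; `step_100_files`, `step_200_files`, and `req_func` (the per-request callable passed to `update`) are hypothetical placeholders, and the constructor arguments depend on your launch environment:

```python
from checkpoint_engine.ps import ParameterServer

ps = ParameterServer(auto_pg=True)

# First registration allocates the shared pool and fixes its bucket layout.
ps.register_checkpoint("ckpt-step-100", files=step_100_files, use_shared_memory_pool=True)
ps.update("ckpt-step-100", req_func)
ps.unregister_checkpoint("ckpt-step-100")              # name released, pinned pool kept

# A later checkpoint with the same memory layout reuses the pinned buffers.
ps.register_checkpoint("ckpt-step-200", files=step_200_files, use_shared_memory_pool=True)
ps.update("ckpt-step-200", req_func)
ps.unregister_checkpoint("ckpt-step-200", force=True)  # force=True frees the pool itself
```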
@@ -875,6 +1077,10 @@ class ParameterServer:
              self.init_process_group()
          assert dist.is_initialized(), "process group is not initialized"
          metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
+         try:
+             memory_pool = self._get_memory_pool(checkpoint_name)
+         except RuntimeError:
+             memory_pool = []
          metas = DataToGather(
              memory_buffer_metas_list=[
                  MemoryBufferMetas(
@@ -882,16 +1088,18 @@ class ParameterServer:
                      ptr=x.buffer.data_ptr(),
                      size=x.size,
                  )
-                 for x in self._memory_pool.get(checkpoint_name, [])
+                 for x in memory_pool
              ],
              p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
-             host_ip=_get_ip(),
+             host_ip=get_ip(),
              device_uuid=self._device_uuid,
              rdma_device=self._rdma_device or "",
          )
 
          dist.all_gather_object(metas_lst, metas)
 
+         self._current_global_parameter_metas = {}
+
          num_parameters = 0
          all_hosts: list[str] = []
          global_device_uuids: list[str] = []
@@ -948,7 +1156,7 @@ class ParameterServer:
              is_master=self._rank == 0,
          )
          dist.init_process_group(
-             backend="nccl",
+             backend=self.device_manager.backend,
              world_size=self._world_size,
              rank=self._rank,
              timeout=timeout,
@@ -991,21 +1199,22 @@ class ParameterServer:
              if self._rank not in ranks:
                  return
              self._update_per_bucket(checkpoint_name, req_func, ranks)
-             if self._auto_pg:
-                 dist.destroy_process_group()
-
-             torch.cuda.empty_cache()
 
-             logger.info(
-                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
-                 f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
-                 f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
-             )
          except Exception as e:
              logger.exception(
                  f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
              )
              raise
+         finally:
+             if self._auto_pg and (not ranks or self._rank in ranks):
+                 dist.destroy_process_group()
+
+             self.device_manager.device_module.empty_cache()
+             logger.info(
+                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
+                 f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
+                 f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
+             )
 
      def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
          def zmq_handle(device_uuid: str) -> str:
@@ -1022,14 +1231,16 @@ class ParameterServer:
          # auto detect bucket size
          tensor = torch.tensor(
              [
-                 # proportion of current cuda free memory bytes
-                 int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
+                 # proportion of current device free memory bytes
+                 int(
+                     float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
+                 ),
                  # we use negative value to reuse allreduce min operation
                  # for getting the max value of zmq_addr_counter in all ranks
                  -self._zmq_addr_counter,
              ],
              dtype=torch.int64,
-             device="cuda",
+             device=self.device_manager.device_type,
          )
          dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
          tensor = tensor.cpu()
@@ -1038,7 +1249,7 @@ class ParameterServer:
          for items in self._current_global_parameter_metas.values():
              for metas_list in items.memory_buffer_metas_list:
                  for meta in metas_list.metas:
-                     max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
+                     max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
          free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
          if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
              self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
@@ -1083,7 +1294,7 @@ class ParameterServer:
                  remote_ptrs.append(ptrs[b.idx][0] + b.offset)
                  lens.append(b.size)
              else:
-                 pool = self._memory_pool[checkpoint_name][b.idx]
+                 pool = self._get_memory_pool(checkpoint_name)[b.idx]
                  buffer[offset : offset + b.size].data.copy_(
                      pool.buffer[b.offset : b.offset + b.size],
                      non_blocking=True,
@@ -1092,7 +1303,7 @@ class ParameterServer:
          assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
          if owner_rank is not None:
              self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
-         torch.cuda.synchronize()
+         self.device_manager.device_module.synchronize()
 
      def init_process_group_for_ranks(
          self,
@@ -1132,7 +1343,11 @@ class ParameterServer:
              master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
          )
          dist.init_process_group(
-             backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
+             backend=self.device_manager.backend,
+             world_size=len(ranks),
+             rank=rank,
+             timeout=timeout,
+             store=store,
          )
 
      def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
@@ -1142,22 +1357,32 @@ class ParameterServer:
 
      def _register_parameters_to_p2p_store(self, checkpoint_name: str):
          assert self._p2p_store is not None, "p2p store is not initialized"
-         pool = self._memory_pool[checkpoint_name]
+         pool = self._get_memory_pool(checkpoint_name)
          if len(pool) == 0:
              return
          named_tensors, tensor_ptrs = {}, []
+         register_name = (
+             checkpoint_name
+             if checkpoint_name != self._current_shared_memory_pool_user
+             else self.shared_memory_pool_name
+         )
          for idx, memory_buffer in enumerate(pool):
-             named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
+             named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
              tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
          self._p2p_store.register_named_tensors(named_tensors)
 
      def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
          assert self._p2p_store is not None, "p2p store is not initialized"
-         pool = self._memory_pool[checkpoint_name]
+         pool = self._get_memory_pool(checkpoint_name)
          if len(pool) == 0:
              return 0
+         unregister_name = (
+             checkpoint_name
+             if checkpoint_name != self._current_shared_memory_pool_user
+             else self.shared_memory_pool_name
+         )
          return self._p2p_store.unregister_named_tensors(
-             [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+             [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
          )
 
      def _update_per_bucket(
@@ -1184,7 +1409,7 @@ class ParameterServer:
 
          if not need_update:
              return
-         # first execute a barrier to avoid subsequent cuda oom
+         # first execute a barrier to avoid subsequent device oom
          dist.barrier()
 
          bucket_size, disable_h2d_buffer = self._detect_bucket_size()
@@ -1199,7 +1424,7 @@ class ParameterServer:
          h2d_buffer: torch.Tensor | None = (
              None
              if disable_h2d_buffer
-             else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
+             else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
          )
          # p2p store need to register h2d_buffer to let other ranks read
          if ranks:
@@ -1212,7 +1437,9 @@ class ParameterServer:
                  continue
              receiver_rank_buckets.append((owner_rank, bucket))
 
-         buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
+         buffer = torch.empty(
+             bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
+         )
          handle = reduce_tensor(buffer)
 
          buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -1231,52 +1458,66 @@ class ParameterServer:
          socket.send_pyobj(handle)
 
          gidx = 0
+         ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
          bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
-         for i in range(max_len):
-             if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
-                 self._copy_to_buffer(
-                     checkpoint_name,
-                     receiver_rank_buckets[i][1],
-                     h2d_buffer,
-                     receiver_rank_buckets[i][0] if ranks else None,
-                 )
-             for receiver_rank, _buckets in buckets_by_receiver_rank.items():
-                 if i >= len(_buckets):
-                     continue
-                 bucket = _buckets[i]
-                 alloc, reserved = (
-                     torch.cuda.memory_allocated() / 1024 / 1024,
-                     torch.cuda.memory_reserved() / 1024 / 1024,
-                 )
-                 self._logger_rank0(
-                     f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
-                     f"Current CUDA allocated {alloc:.2f} MB, "
-                     f"reserved {reserved:.2f} MB."
-                 )
-                 start = gidx % 2 * bucket_size
-                 buffer_b: torch.Tensor = buffer[start : start + bucket.size]
-                 if receiver_rank == self._rank:
-                     if disable_h2d_buffer:
-                         self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
-                     else:
-                         buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                 brank = bcast_rank_map[receiver_rank]
-                 dist.broadcast(buffer_b, src=brank)
-                 socket.recv()
-                 dist.barrier()
-                 socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
-                 gidx += 1
-
-         socket.recv()
-         socket.send_pyobj(None)
-         socket.recv()
-         req_thread.join()
-         dist.barrier()
-         socket.close()
-         if ranks and h2d_buffer is not None:
-             self._p2p_store.unregister_named_tensors([h2d_buffer_name])
-
-         torch.cuda.empty_cache()
+         try:
+             for i in range(max_len):
+                 if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
+                     self._copy_to_buffer(
+                         checkpoint_name,
+                         receiver_rank_buckets[i][1],
+                         h2d_buffer,
+                         receiver_rank_buckets[i][0] if ranks else None,
+                     )
+                 for receiver_rank, _buckets in buckets_by_receiver_rank.items():
+                     if i >= len(_buckets):
+                         continue
+                     bucket = _buckets[i]
+                     alloc, reserved = (
+                         self.device_manager.device_module.memory_allocated() / 1024 / 1024,
+                         self.device_manager.device_module.memory_reserved() / 1024 / 1024,
+                     )
+                     self._logger_rank0(
+                         f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
+                         f"Current device allocated {alloc:.2f} MB, "
+                         f"reserved {reserved:.2f} MB."
+                     )
+                     start = gidx % 2 * bucket_size
+                     buffer_b: torch.Tensor = buffer[start : start + bucket.size]
+                     if receiver_rank == self._rank:
+                         if disable_h2d_buffer:
+                             self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
+                         else:
+                             buffer_b.data.copy_(h2d_buffer[: bucket.size])
+                     brank = bcast_rank_map[receiver_rank]
+                     dist.broadcast(buffer_b, src=brank)
+                     resp = socket.recv()
+                     if resp != b"":
+                         msg = resp.decode("utf-8")
+                         logger.error(
+                             f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
+                         )
+                         ret_code.fill_(1)
+                     dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                     self.device_manager.device_module.synchronize()
+                     if ret_code.item() != 0:
+                         # quit early if any rank failed
+                         socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
+                         raise RuntimeError("Failed to update weights due to remote errors")
+                     socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
+                     gidx += 1
+
+             socket.recv()
+             socket.send_pyobj(None)
+             socket.recv()
+         finally:
+             req_thread.join()
+             dist.barrier()
+             socket.close()
+             if ranks and h2d_buffer is not None:
+                 self._p2p_store.unregister_named_tensors([h2d_buffer_name])
+
+             self.device_manager.device_module.empty_cache()
 
 
  def _init_api(ps: ParameterServer) -> Any:
@@ -1294,6 +1535,7 @@ def _init_api(ps: ParameterServer) -> Any:
          update_url: str | None = None
          inference_group_ranks: list[int] = []
          timeout: float = 300.0
+         uds: str | None = None
 
      def wrap_exception(func: Callable[[], None]) -> Response:
          try:
@@ -1326,7 +1568,9 @@ def _init_api(ps: ParameterServer) -> Any:
                  return
              if req.inference_group_ranks:
                  socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
-             request_inference_to_update(req.update_url, dict(socket_paths), timeout=req.timeout)
+             request_inference_to_update(
+                 req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
+             )
 
          return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
checkpoint_engine/worker.py CHANGED
@@ -1,10 +1,13 @@
  import gc
+ import traceback
  from collections.abc import Callable
  from typing import TypedDict
 
  import torch
  import zmq
 
+ from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
+
 
  def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
      func, args = handle
@@ -53,51 +56,107 @@ def update_weights_from_ipc(
      socket = zmq_ctx.socket(zmq.REP)
      socket.connect(zmq_handle)
      buffer: torch.Tensor | None = None
-     while True:
-         payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
-         if payload is None:
-             # means the update is done
-             if post_hook is not None:
-                 post_hook()
-             torch.cuda.synchronize()
-             socket.send(b"")
-             break
-         if isinstance(payload, tuple):
-             # an ipc handle that vLLM can use `func, args = handle`
-             # and `func(*args)` to rebuild GPU tensor.
-             buffer = _rebuild_ipc(payload, device_id)
-             assert buffer.dtype == torch.uint8
-             socket.send(b"")
-             continue
-         assert isinstance(payload, list)
-         run(_extract_weights(payload, buffer))
-         torch.cuda.synchronize()
+     device_manager = DeviceManager()
+     try:
+         ipc_handle: tuple[Callable, tuple] = socket.recv_pyobj()
+         assert isinstance(ipc_handle, tuple)
+         buffer = _rebuild_ipc(ipc_handle, device_id)
+         assert buffer.dtype == torch.uint8
          socket.send(b"")
+     except Exception as e:
+         msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+         socket.send_string(msg)
+         socket.recv()  # wait for ack
+         raise
+     try:
+         while True:
+             payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
+             if payload is None:  # done signal
+                 if post_hook is not None:
+                     post_hook()
+                 device_manager.device_module.synchronize()
+                 socket.send(b"")
+                 break
+             if isinstance(payload, list):  # still updating weights
+                 try:
+                     run(_extract_weights(payload, buffer))
+                     device_manager.device_module.synchronize()
+                     socket.send(b"")
+                 except Exception as e:  # noqa: BLE001
+                     # Send exception back to Parameter Server.
+                     # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
+                     msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+                     socket.send_string(msg)
+             elif isinstance(
+                 payload, Exception
+             ):  # error occurred, got force quit signal from Parameter Server
+                 raise payload
+             else:
+                 raise TypeError(f"Unexpected payload type: {type(payload)}")
 
-     socket.close()
-     del buffer
-     gc.collect()
-     torch.cuda.empty_cache()
+     finally:
+         socket.close()
+         del buffer
+         gc.collect()
+         device_manager.device_module.empty_cache()
 
 
  class VllmColocateWorkerExtension:
      """
-     The class for vLLM's worker to inherit from, in the colocate setting.
-     By defining an extension class, the code can work no matter what is
-     the underlying worker class. This way, the code can be compatible
-     with both vLLM V0 and V1.
-     NOTE: we define this class in a separate module, and the main module
-     should pass the full qualified name as `worker_extension_cls` argument.
+     Worker extension for vLLM to update weights from checkpoint-engine.
+
+     This class provides a worker extension mechanism that allows vLLM workers to receive
+     and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
+     The methods in this worker extension will be injected into the vLLM worker class and
+     are callable from the `collective_rpc` API, enabling seamless weight updates for both
+     vLLM V0 and V1 versions.
+
+     Note:
+         This class is defined in a separate module. The fully qualified name
+         `checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
+         `worker_extension_cls` argument when initializing the vLLM worker.
      """
 
      def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
+         """
+         Update model weights from checkpoint-engine via IPC communication.
+
+         This method establishes a ZMQ connection to the checkpoint-engine and receives
+         weight updates through a shared memory buffer. The update process includes:
+         1. Receiving IPC handles to reconstruct shared memory tensors
+         2. Extracting flattened metadata describing tensor weights in the shared memory tensor
+         3. Loading weights into the model
+         4. Post-processing weights after loading
+
+         Args:
+             zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
+                 The device UUID is platform-specific:
+                 - For CUDA: UUID from `current_platform.get_device_uuid()`
+                 - For NPU: Format "NPU-{generated_uuid}"
+
+         Raises:
+             ValueError: If the device type is not supported (not CUDA or NPU).
+             AssertionError: If the device is not properly initialized.
+
+         Note:
+             This method is called by vLLM's collective RPC mechanism. The ZMQ context
+             is lazily initialized on first call and reused for subsequent updates.
+         """
          from vllm.model_executor.model_loader.utils import process_weights_after_loading
          from vllm.platforms import current_platform
 
+         # vllm-ascend not init device
+         if current_platform.device_type == "npu" and self.device is None:
+             self.device = torch.device(f"npu:{self.local_rank}")
          assert self.device is not None
          if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
              self._zmq_ctx = zmq.Context()
-         device_uuid = current_platform.get_device_uuid(self.device.index)
+         if current_platform.device_type == "cuda":
+             device_uuid = current_platform.get_device_uuid(self.device.index)
+         elif current_platform.device_type == "npu":
+             device_uuid = f"NPU-{npu_generate_uuid()}"
+         else:
+             raise ValueError(f"Unsupported device type: {current_platform.device_type}")
          update_weights_from_ipc(
              self._zmq_ctx,
              zmq_handles[device_uuid],
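The docstrings above describe how the extension is wired into vLLM. A hedged sketch of that wiring; the exact `LLM` arguments depend on the vLLM version, and the model name and `zmq_handles` mapping are hypothetical placeholders (in practice the parameter server publishes the handles):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B",  # hypothetical model
    worker_extension_cls="checkpoint_engine.worker.VllmColocateWorkerExtension",
)

# zmq_handles maps each worker's device UUID to the ZMQ socket path published by
# the parameter server for that device (NPU UUIDs look like "NPU-<ip>-<chip_id>").
zmq_handles = {"GPU-11111111-2222-3333-4444-555555555555": "ipc:///tmp/ce_update.sock"}
llm.collective_rpc("update_weights_from_ipc", args=(zmq_handles,))
```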
checkpoint_engine-0.2.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: checkpoint-engine
- Version: 0.2.0
+ Version: 0.2.2
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
  pip install 'checkpoint-engine[p2p]'
  ```
 
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
-
  ## Getting Started
 
- Prepare an H800 or H20 machine with 8 GPUs with latest vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights.
+ Prepare an H800 or H20 machine with 8 GPUs and vLLM. Be sure to include the [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in the main branch), since checkpoint-engine uses this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
 
  ```Bash
- cd /opt && git clone https://github.com/vllm-project/vllm && cd vllm
+ mkdir -p /opt/vLLM && cd /opt/vLLM
  uv venv --python 3.12 --seed
  source .venv/bin/activate
- VLLM_USE_PRECOMPILED=1 uv pip install --editable .
+ uv pip install vllm==0.10.2
  ```
 
  Install checkpoint-engine
@@ -169,13 +167,68 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM pr
  Run a simple correctness test for checkpoint_engine
 
  ```bash
- torchrun --nproc-per-node 8 tests/test_update.py
+ pytest tests/test_update.py
  ```
 
- Other unit tests can be done with pytest.
+ `test_update.py` is designed to run only with `pytest`. Please don't run it directly with `torchrun`.
+
+ Other unit tests can also be run with pytest. Only `test_update.py` requires GPUs; the other tests can run on CPUs. To run only the CPU tests, use:
+
+ ```bash
+ pytest tests/ -m "not gpu"
+ ```
+
+ ### Environment Variables
+ - `PS_MAX_BUCKET_SIZE_GB`: An integer that sets the maximum bucket size for checkpoint-engine. Defaults to 8 GB if not set.
+ - `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA device names for P2P transfer. If not set, checkpoint-engine falls back to `NCCL_IB_HCA` to detect RDMA devices.
+ - `NCCL_IB_HCA`: Available patterns are documented in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If this is also not set, all RDMA devices are used and divided evenly among the ranks.
+
+ ## SGLang Integration
+
+ Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
+
+ ### Quick Start
+
+ **1. Install checkpoint-engine:**
+ ```bash
+ pip install 'checkpoint-engine[p2p]'
+ ```
+
+ **2. Launch SGLang server:**
+ ```bash
+ python -m sglang.launch_server \
+     --model-path $MODEL_PATH \
+     --tp 8 \
+     --load-format dummy \
+     --wait-for-initial-weights
+ ```
+
+ **3. Run checkpoint engine:**
+ ```bash
+ python -m sglang.srt.checkpoint_engine.update \
+     --update-method broadcast \
+     --checkpoint-path $MODEL_PATH \
+     --inference-parallel-size 8
+ ```
+
+ ### Multi-Node Setup
+
+ For a 2-node setup, run the same commands on both nodes with the appropriate `--host` and distributed training parameters.
+
+ ### Key Options
+
+ **SGLang Server:**
+ - `--wait-for-initial-weights`: Wait for the checkpoint engine before becoming ready
+ - `--load-format dummy`: Enable overlapping initialization tasks
+
+ **Checkpoint Engine:**
+ - `--update-method`: Choose `broadcast`, `p2p`, or `all`
+ - `--inference-parallel-size`: Number of parallel processes
+ - `--checkpoint-path`: Model checkpoint directory
+
  ## Limitations and Future Work
 
- - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
+ - This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
  - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
 
  ## Acknowledgments
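The Environment Variables section in the README above is the whole configuration surface for bucket sizing and RDMA device selection. A hedged sketch of setting them from the launching Python process before the parameter server is constructed (values are illustrative):

```python
import os

# Cap pinned-memory buckets at 8 GiB (the documented default when unset).
os.environ.setdefault("PS_MAX_BUCKET_SIZE_GB", "8")

# Pin the P2P store to explicit RDMA devices; the device names are hypothetical.
os.environ.setdefault("PS_P2P_STORE_RDMA_DEVICES", "mlx5_0,mlx5_1,mlx5_2,mlx5_3")

# If PS_P2P_STORE_RDMA_DEVICES is unset, NCCL_IB_HCA is consulted next; if that is
# also unset, all detected RDMA devices are divided evenly among the local ranks.
```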
checkpoint_engine-0.2.2.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+ checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+ checkpoint_engine/_version.py,sha256=o3ZTescp-19Z9cvBGq9dQnbppljgzdUYUf98Nov0spY,704
+ checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+ checkpoint_engine/ps.py,sha256=cu8Qp5daY1iL30iN69jXP4grlHoAKILblngcKQPA5Bg,67692
+ checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
+ checkpoint_engine-0.2.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+ checkpoint_engine-0.2.2.dist-info/METADATA,sha256=_bBxy27d0GMc7KzuIBAdw-Lno3-UrVLUhH63YDbY1YA,11559
+ checkpoint_engine-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ checkpoint_engine-0.2.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+ checkpoint_engine-0.2.2.dist-info/RECORD,,
checkpoint_engine-0.2.0.dist-info/RECORD REMOVED
@@ -1,9 +0,0 @@
- checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
- checkpoint_engine/_version.py,sha256=Dg8AmJomLVpjKL6prJylOONZAPRtB86LOce7dorQS_A,704
- checkpoint_engine/ps.py,sha256=OpGocqJv0TfGgVC1cPKARfz6qehfCLMzQ5KpDQNxb0o,55291
- checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
- checkpoint_engine-0.2.0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
- checkpoint_engine-0.2.0.dist-info/METADATA,sha256=tbAq45YlRvRAfQHDB0XV8w4ZP0zmVJ3RMTAx_wTm154,9896
- checkpoint_engine-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- checkpoint_engine-0.2.0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
- checkpoint_engine-0.2.0.dist-info/RECORD,,