checkpoint-engine 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpoint_engine/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.1'
-__version_tuple__ = version_tuple = (0, 2, 1)
+__version__ = version = '0.2.3'
+__version_tuple__ = version_tuple = (0, 2, 3)
 
 __commit_id__ = commit_id = None
checkpoint_engine/ps.py CHANGED
@@ -1,6 +1,7 @@
 import argparse
 import concurrent.futures
 import ctypes
+import json
 import os
 import pickle
 import random
@@ -18,7 +19,7 @@ import torch.distributed as dist
 import zmq
 from loguru import logger
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import safe_open
+from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
     name: str
     dtype: _TorchDtype
     shape: _TorchSize
+    aligned_size: int
 
 
 class BucketRange(NamedTuple):
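
The new `aligned_size` field caches each parameter's padded byte size so later code (`_to_named_tensor`, `_gen_h2d_buckets`, the max-tensor scan in `ParameterServer`) reads it instead of recomputing `_align_size`, and so buffers parsed directly from safetensors `data_offsets` can carry their exact extent. As a rough sketch only, here is what a round-up-to-alignment helper of this kind looks like; the alignment constant and exact formula are assumptions, not the values used in `ps.py`:

```python
import torch

_ALIGN_SIZE = 256  # assumed alignment granularity, not necessarily the constant in ps.py


def align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    """Round the raw tensor byte size up to the next multiple of _ALIGN_SIZE."""
    nbytes = dtype.itemsize * shape.numel()
    return (nbytes + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE


# A (1024, 1024) bf16 tensor is exactly 2 MiB, already a multiple of 256, so no padding is added.
assert align_size(torch.bfloat16, torch.Size([1024, 1024])) == 2 * 1024 * 1024
```
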
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
 def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     ret = []
     for meta in metas:
-        size = _align_size(meta.dtype, meta.shape)
+        size = meta.aligned_size
         ret.append(
             {
                 "name": meta.name,
@@ -422,6 +424,7 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
                 name=parameter_name,
                 shape=meta["shape"],
                 dtype=meta["dtype"],
+                aligned_size=_align_size(meta["dtype"], meta["shape"]),
             )
             tp_meta = tp_metas[parameter_name]
             if tp_meta.concat_dim != -1:
@@ -431,7 +434,10 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
         shape = list(parameter_metas[name].shape)
         shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
         parameter_metas[name] = ParameterMeta(
-            name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
+            name=name,
+            shape=torch.Size(shape),
+            dtype=parameter_metas[name].dtype,
+            aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
         )
         weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
         # TODO: here concat is serial, which may be slow
@@ -449,17 +455,85 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
     return parameters
 
 
-def _register_checkpoint(
-    *,
+def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
+        """
+        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
+        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
+        The actual tensor data is in the remaining bytes and is naturally aligned.
+        We pin the remaining bytes as the buffer, making pinning faster.
+        """
+
+        def _pin(t: torch.Tensor):
+            """
+            Pin the memory of tensor in-place.
+            See: https://github.com/pytorch/pytorch/issues/32167
+            """
+            cudart = torch.cuda.cudart()
+            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+            assert r == 0, f"pin memory error, error code: {r}"
+
+        # TODO: should only support /dev/shm? but we found files in disk also work?
+        size = os.stat(file_path).st_size
+        flag_size = 8
+        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+        assert t.nbytes > flag_size, (
+            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
+        )
+        start_pos = (
+            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+            + flag_size
+        )
+        header_tensor = t[flag_size:start_pos]
+        header = json.loads(header_tensor.numpy().tobytes())
+        if "__metadata__" in header:
+            header.pop("__metadata__")
+
+        metas: list[ParameterMeta] = []
+        offset = 0
+        try:
+            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                start, end = meta["data_offsets"]
+                # safetensors format ensures offsets are aligned
+                assert offset == start, f"offset {offset} should be equal to start {start}"
+                metas.append(
+                    ParameterMeta(
+                        name=name,
+                        dtype=_getdtype(meta["dtype"]),
+                        shape=torch.Size(meta["shape"]),
+                        aligned_size=end - start,
+                    )
+                )
+                offset = end
+        except Exception as e:
+            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
+            raise
+
+        buffer = t[start_pos:]
+        assert offset == buffer.nbytes, (
+            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+        )
+        # Remove the file after successfully loading. This will avoid doubling the memory usage.
+        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
+        os.remove(file_path)
+        _pin(buffer)
+        logger.info(
+            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
+        )
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+
+    memory_buffers: list[MemoryBuffer] = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
+    return memory_buffers
+
+
+def _normal_pin_memory(
     files: list[str],
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
 ) -> list[MemoryBuffer]:
-    logger.info(
-        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-    )
-    if not files and not named_tensors:
-        return []
     parameters = _load_checkpoint(files)
     if named_tensors:
         parameters.update(named_tensors)
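
The `_parse_and_pin_from_safetensors` helper added above relies on the documented safetensors layout: an 8-byte little-endian header length `N`, an `N`-byte JSON header mapping tensor names to `dtype`, `shape` and `data_offsets`, and then the raw tensor bytes. For orientation, a minimal standalone parse of just the header (no pinning or file removal; the tensor name in the comment is illustrative):

```python
import json
import struct


def read_safetensors_header(path: str) -> tuple[dict, int]:
    """Return (header, data_start_offset) for a .safetensors file."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian uint64 header length
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata block
    return header, 8 + header_len


# Each remaining entry looks roughly like:
#   {"model.embed_tokens.weight": {"dtype": "BF16", "shape": [4096, 4096],
#                                  "data_offsets": [0, 33554432]}}
# data_offsets are relative to the end of the header, which is why the code above sorts
# entries by data_offsets and asserts that consecutive tensors are contiguous.
```
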
@@ -469,13 +543,16 @@ def _register_checkpoint(
         size: int
         metas: list[ParameterMeta]
 
-    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
+    buckets: list[MemoryBucket] = []
+    buckets.append(MemoryBucket(size=0, metas=[]))
     for name, tensor in sorted(parameters.items()):
         size = _align_size(tensor.dtype, tensor.shape)
         if buckets[-1].size + size > bucket_size:
             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
             buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
+        buckets[-1].metas.append(
+            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
+        )
         buckets[-1].size += size
 
     memory_buffers = [
@@ -483,16 +560,34 @@ def _register_checkpoint(
         for bucket in buckets
     ]
 
-    def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
-        buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-        return idx, buffer
+    def register_pin_memory(
+        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
+    ) -> tuple[int, torch.Tensor]:
+        if shared_pin_memory:
+            # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
+            # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
+            assert idx < len(shared_pin_memory), (
+                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
+            )
+            assert shared_pin_memory[idx].size == size, (
+                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
+            )
+            return idx, shared_pin_memory[idx].buffer
+        else:
+            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
+            return idx, buffer
 
     def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
         buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
         futures = [
-            executor.submit(register_pin_memory, idx, bucket.size)
+            executor.submit(
+                register_pin_memory,
+                idx,
+                bucket.size,
+                shared_pin_memory,
+            )
            for idx, bucket in enumerate(buckets)
         ]
         new_futures = []
@@ -518,6 +613,45 @@ def _register_checkpoint(
                 offset += size
         for future in concurrent.futures.as_completed(new_futures):
             future.result()
+    return memory_buffers
+
+
+def _register_checkpoint(
+    *,
+    files: list[str],
+    named_tensors: dict[str, torch.Tensor],
+    rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
+    inplace_pin: bool = False,
+) -> list[MemoryBuffer]:
+    logger.info(
+        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+    )
+    if not files and not named_tensors:
+        return []
+    memory_buffers: list[MemoryBuffer] = []
+    if inplace_pin:
+        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
+        files_to_inplace_pin = [
+            file
+            for file in files
+            if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+        ]
+        files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    else:
+        files_to_normal_pin = files
+        files_to_inplace_pin = []
+    if files_to_normal_pin or named_tensors:
+        memory_buffers.extend(
+            _normal_pin_memory(
+                files=files_to_normal_pin,
+                named_tensors=named_tensors,
+                rank=rank,
+                shared_pin_memory=shared_pin_memory,
+            )
+        )
+    if files_to_inplace_pin:
+        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
     return memory_buffers
 
 
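
To make the routing in the new `_register_checkpoint` concrete: only `.safetensors` files under `/dev/shm/` are eligible for in-place pinning, while everything else, plus any `named_tensors`, still goes through `_normal_pin_memory`. A tiny illustration with made-up paths:

```python
files = [
    "/dev/shm/model-00001-of-00002.safetensors",    # hypothetical: pinned in place (and then removed)
    "/data/ckpt/model-00002-of-00002.safetensors",  # on disk: copied into freshly allocated pinned buffers
    "/dev/shm/extra_state.bin",                     # not .safetensors: normal path as well
]
inplace = [f for f in files if f.startswith("/dev/shm/") and f.endswith(".safetensors")]
normal = [f for f in files if f not in inplace]
assert inplace == ["/dev/shm/model-00001-of-00002.safetensors"]
assert normal == ["/data/ckpt/model-00002-of-00002.safetensors", "/dev/shm/extra_state.bin"]
```
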
@@ -566,7 +700,7 @@ def _gen_h2d_buckets(
     for idx, metas in enumerate(items.memory_buffer_metas_list):
         start_offset, offset = 0, 0
         for meta in metas.metas:
-            s = _align_size(meta.dtype, meta.shape)
+            s = meta.aligned_size
             if buckets[-1][1].size + s > bucket_size:
                 if offset - start_offset > 0:
                     buckets[-1][1].ranges.append(
@@ -747,6 +881,8 @@ class P2PStore:
 
 
 class ParameterServer:
+    shared_memory_pool_name = "__shared_memory_pool__"
+
     def __init__(
         self,
         *,
@@ -790,7 +926,10 @@ class ParameterServer:
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
 
+        # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
+        self._current_shared_memory_pool_user: str = ""
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
+        self._memory_pool[self.shared_memory_pool_name] = []
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
         # NPU transfer engine initialization requires prior set_device.
@@ -805,6 +944,17 @@ class ParameterServer:
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+    def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            assert self._memory_pool[self.shared_memory_pool_name], (
+                f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
+            )
+            return self._memory_pool[self.shared_memory_pool_name]
+        elif checkpoint_name in self._memory_pool:
+            return self._memory_pool[checkpoint_name]
+        else:
+            raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
+
     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
             logger.info(msg)
@@ -828,46 +978,103 @@ class ParameterServer:
         *,
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
+        use_shared_memory_pool: bool = False,
+        use_inplace_pin_memory: bool = False,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
+        Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Please make sure to copy the files to disks if you need to keep them.
 
         Args:
             checkpoint_name: The name of the checkpoint.
            files: The safetensors files to register.
            named_tensors: The named tensors to register.
+            use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
+                Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
+                cannot accommodate checkpoints with different memory requirements.
+                To free the actual memory of the shared pool or to modify its shape,
+                please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
+            use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
+                Currently, this feature is experimental and may crash.
         """
         try:
-            assert checkpoint_name not in self._memory_pool, (
-                f"checkpoint {checkpoint_name} already registered"
-            )
-            self._memory_pool[checkpoint_name] = _register_checkpoint(
-                files=files or [], named_tensors=named_tensors or {}, rank=self._rank
-            )
-            if self._p2p_store is not None:
-                self._register_parameters_to_p2p_store(checkpoint_name)
+            if use_shared_memory_pool:
+                logger.info(
+                    f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
+                )
+                assert self._current_shared_memory_pool_user == "", (
+                    f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
+                    f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
+                    f"This registration may cause unexpected conflicts."
+                )
+                # Since we set the uninitialized shared memory pool to empty list,
+                # we can check whether this is the first time to use shared memory pool
+                _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
+                self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                )
+                self._current_shared_memory_pool_user = checkpoint_name
+                if self._p2p_store is not None and _is_first_time:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
+            else:
+                assert checkpoint_name not in self._memory_pool, (
+                    f"checkpoint {checkpoint_name} already registered"
+                )
+                self._memory_pool[checkpoint_name] = _register_checkpoint(
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    inplace_pin=use_inplace_pin_memory,
+                )
+                if self._p2p_store is not None:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
         except Exception:
             logger.exception(
                 f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
             )
-            if self._p2p_store is not None:
+            if self._p2p_store is not None and not use_shared_memory_pool:
                 self._unregister_parameters_from_p2p_store(checkpoint_name)
             self.unregister_checkpoint(checkpoint_name)
             raise
 
-    def unregister_checkpoint(self, checkpoint_name: str):
+    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
         """
         Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
         from p2p store if p2p store is initialized.
+        Args:
+            checkpoint_name: The name of the checkpoint.
+            force: This flag is designed for shared memory pool user. If True, the memory for shared memory pool itself will be freed.
+                If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
         """
-        if checkpoint_name not in self._memory_pool:
+        if (
+            checkpoint_name not in self._memory_pool
+            and checkpoint_name != self._current_shared_memory_pool_user
+        ):
+            logger.warning(
+                f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
+            )
+            return
+
+        if checkpoint_name == self._current_shared_memory_pool_user and not force:
+            self._current_shared_memory_pool_user = ""
             return
+
         if self._p2p_store is not None:
             num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
             logger.info(
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )
-        del self._memory_pool[checkpoint_name]
+
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            del self._memory_pool[self.shared_memory_pool_name]
+            self._memory_pool[self.shared_memory_pool_name] = []
+        else:
+            del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
         torch._C._host_emptyCache()
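
A hedged usage sketch of the two new `register_checkpoint` options; the checkpoint names and file lists are illustrative, and the `ParameterServer` construction is omitted:

```python
from checkpoint_engine.ps import ParameterServer


def rollout_updates(ps: ParameterServer, step100_files: list[str], step200_files: list[str]) -> None:
    # The first registration sizes the shared pin-memory pool; later checkpoints that reuse
    # the pool must produce the same bucket layout, otherwise the size assertions fail.
    ps.register_checkpoint("step-100", files=step100_files, use_shared_memory_pool=True)
    ps.unregister_checkpoint("step-100")              # releases the name, keeps the pool pinned
    ps.register_checkpoint("step-200", files=step200_files, use_shared_memory_pool=True)
    ps.unregister_checkpoint("step-200", force=True)  # frees the shared pool itself

    # In-place pinning: /dev/shm/*.safetensors files are pinned where they sit and then REMOVED.
    ps.register_checkpoint(
        "step-300", files=["/dev/shm/model.safetensors"], use_inplace_pin_memory=True
    )
```
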
@@ -882,6 +1089,10 @@ class ParameterServer:
             self.init_process_group()
         assert dist.is_initialized(), "process group is not initialized"
         metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
+        try:
+            memory_pool = self._get_memory_pool(checkpoint_name)
+        except RuntimeError:
+            memory_pool = []
         metas = DataToGather(
             memory_buffer_metas_list=[
                 MemoryBufferMetas(
@@ -889,7 +1100,7 @@ class ParameterServer:
                     ptr=x.buffer.data_ptr(),
                     size=x.size,
                 )
-                for x in self._memory_pool.get(checkpoint_name, [])
+                for x in memory_pool
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
             host_ip=get_ip(),
@@ -1050,7 +1261,7 @@ class ParameterServer:
         for items in self._current_global_parameter_metas.values():
             for metas_list in items.memory_buffer_metas_list:
                 for meta in metas_list.metas:
-                    max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
+                    max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
             self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
@@ -1095,7 +1306,7 @@ class ParameterServer:
                 remote_ptrs.append(ptrs[b.idx][0] + b.offset)
                 lens.append(b.size)
             else:
-                pool = self._memory_pool[checkpoint_name][b.idx]
+                pool = self._get_memory_pool(checkpoint_name)[b.idx]
                 buffer[offset : offset + b.size].data.copy_(
                     pool.buffer[b.offset : b.offset + b.size],
                     non_blocking=True,
@@ -1158,22 +1369,32 @@ class ParameterServer:
 
     def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
+        register_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         for idx, memory_buffer in enumerate(pool):
-            named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
+            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
             tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
         self._p2p_store.register_named_tensors(named_tensors)
 
     def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
+        unregister_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         return self._p2p_store.unregister_named_tensors(
-            [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
         )
 
     def _update_per_bucket(
@@ -1284,9 +1505,9 @@ class ParameterServer:
             dist.broadcast(buffer_b, src=brank)
             resp = socket.recv()
             if resp != b"":
-                exception_obj = pickle.loads(resp)
+                msg = resp.decode("utf-8")
                 logger.error(
-                    f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
+                    f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                )
                 ret_code.fill_(1)
         dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
checkpoint_engine/worker.py CHANGED
@@ -1,4 +1,5 @@
 import gc
+import traceback
 from collections.abc import Callable
 from typing import TypedDict
 
@@ -63,7 +64,8 @@ def update_weights_from_ipc(
             assert buffer.dtype == torch.uint8
             socket.send(b"")
         except Exception as e:
-            socket.send_pyobj(e)
+            msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+            socket.send_string(msg)
             socket.recv()  # wait for ack
             raise
         try:
@@ -83,7 +85,8 @@ def update_weights_from_ipc(
             except Exception as e:  # noqa: BLE001
                 # Send exception back to Parameter Server.
                 # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
-                socket.send_pyobj(e)
+                msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+                socket.send_string(msg)
         elif isinstance(
             payload, Exception
         ):  # error occurred, got force quit signal from Parameter Server
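
The worker-side change above pairs with the `_update_per_bucket` change in `ps.py`: instead of pickling the exception object for the parameter server to unpickle, the worker now sends the fully formatted traceback as text and the server simply decodes it, which preserves the remote stack trace and avoids unpickling arbitrary objects. A small sketch of both ends under that convention (illustrative wiring, not the package's own socket setup):

```python
import traceback

import zmq


def report_error(socket: zmq.Socket, exc: Exception) -> None:
    # Worker side: an empty frame means success, otherwise send the whole traceback as a string.
    msg = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    socket.send_string(msg)


def read_status(socket: zmq.Socket) -> str | None:
    # Parameter-server side: None means the worker succeeded, otherwise the worker's traceback text.
    resp = socket.recv()
    return None if resp == b"" else resp.decode("utf-8")
```
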
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.1
+Version: 0.2.3
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer-engine`
 pip install 'checkpoint-engine[p2p]'
 ```
 
-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
-
 ## Getting Started
 
-Prepare an H800 or H20 machine with 8 GPUs with latest vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights.
+Prepare an H800 or H20 machine with 8 GPUs with vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
 
 ```Bash
-cd /opt && git clone https://github.com/vllm-project/vllm && cd vllm
+mkdir -p /opt/vLLM && cd /opt/vLLM
 uv venv --python 3.12 --seed
 source .venv/bin/activate
-VLLM_USE_PRECOMPILED=1 uv pip install --editable .
+uv pip install vllm==0.10.2
 ```
 
 Install checkpoint-engine
@@ -180,6 +178,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
 pytest tests/ -m "not gpu"
 ```
 
+### Environment Variables
+- `PS_MAX_BUCKET_SIZE_GB`: An integer is used to set the maximum bucket size for checkpoint-engine. If not set, 8GB is used as default.
+- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA devices' names for P2P transfer. If not set, checkpoint-engine will fall back to use `NCCL_IB_HCA` to detect RDMA devices.
+- `NCCL_IB_HCA`: Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If also not set, all RDMA devices will be used and divided evenly among the ranks.
+
 ## SGLang Integration
 
 Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
@@ -0,0 +1,10 @@
+checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+checkpoint_engine/_version.py,sha256=kBRz0P2plw1eVdIpt70W6m1LMbEIhLY3RyOfVGdubaI,704
+checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+checkpoint_engine/ps.py,sha256=JQcSDeq7wGLZMPBdAa-9Lb4SymSP1l_oUEMO-X1LfvQ,68360
+checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
+checkpoint_engine-0.2.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.2.3.dist-info/METADATA,sha256=qsTp8s8Z6gz2q12x0gZQKlvViKrvlEB36b_Zpe_nhi4,11559
+checkpoint_engine-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+checkpoint_engine-0.2.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.2.3.dist-info/RECORD,,
@@ -1,10 +0,0 @@
-checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
-checkpoint_engine/_version.py,sha256=vYqoJTG51NOUmYyL0xt8asRK8vUT4lGAdal_EZ59mvw,704
-checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
-checkpoint_engine/ps.py,sha256=Mgfr_MXYYZ_6JKqD5kIGDBWCYNWAtIPfAppP_cFu604,57781
-checkpoint_engine/worker.py,sha256=5TzDgTPew6Ts9sMOzecalLCR1p_ZwfeKPdzzr68kAQg,6564
-checkpoint_engine-0.2.1.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.2.1.dist-info/METADATA,sha256=7E7NhehWHpS6QVkLup-oCm350wdbiZX8CY3jmmJP0bU,11315
-checkpoint_engine-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.2.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.2.1.dist-info/RECORD,,