checkpoint-engine 0.2.1-py3-none-any.whl → 0.2.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpoint_engine/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.1'
-__version_tuple__ = version_tuple = (0, 2, 1)
+__version__ = version = '0.2.2'
+__version_tuple__ = version_tuple = (0, 2, 2)
 
 __commit_id__ = commit_id = None
checkpoint_engine/ps.py CHANGED
@@ -1,6 +1,7 @@
 import argparse
 import concurrent.futures
 import ctypes
+import json
 import os
 import pickle
 import random
@@ -18,7 +19,7 @@ import torch.distributed as dist
 import zmq
 from loguru import logger
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import safe_open
+from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
     name: str
     dtype: _TorchDtype
     shape: _TorchSize
+    aligned_size: int
 
 
 class BucketRange(NamedTuple):
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
 def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     ret = []
     for meta in metas:
-        size = _align_size(meta.dtype, meta.shape)
+        size = meta.aligned_size
         ret.append(
             {
                 "name": meta.name,
@@ -422,6 +424,7 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
                 name=parameter_name,
                 shape=meta["shape"],
                 dtype=meta["dtype"],
+                aligned_size=_align_size(meta["dtype"], meta["shape"]),
             )
             tp_meta = tp_metas[parameter_name]
             if tp_meta.concat_dim != -1:
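Several call sites in this release switch from recomputing `_align_size(meta.dtype, meta.shape)` to reading the new cached `ParameterMeta.aligned_size` field. The body of `_align_size` is not part of this diff; a plausible sketch of its contract, assuming it rounds the raw byte size up to the `_ALIGN_SIZE` constant referenced later in the diff:

```python
import torch

_ALIGN_SIZE = 256  # assumed value; the real constant is defined in checkpoint_engine/ps.py

def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    """Round a tensor's raw byte size up to the next _ALIGN_SIZE boundary."""
    nbytes = dtype.itemsize * shape.numel()
    return (nbytes + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE

# A (1024, 1024) bf16 tensor is exactly 2 MiB, already aligned, so no padding is added:
assert _align_size(torch.bfloat16, torch.Size([1024, 1024])) == 2 * 1024 * 1024
```

Caching the value on the meta avoids repeating this arithmetic in the hot loops that build buckets and H2D transfer plans.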
@@ -431,7 +434,10 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
         shape = list(parameter_metas[name].shape)
         shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
         parameter_metas[name] = ParameterMeta(
-            name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
+            name=name,
+            shape=torch.Size(shape),
+            dtype=parameter_metas[name].dtype,
+            aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
         )
         weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
         # TODO: here concat is serial, which may be slow
@@ -449,17 +455,85 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
     return parameters
 
 
-def _register_checkpoint(
-    *,
+def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
+        """
+        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
+        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
+        The actual tensor data is in the remaining bytes and is naturally aligned.
+        We pin the remaining bytes as the buffer, making pinning faster.
+        """
+
+        def _pin(t: torch.Tensor):
+            """
+            Pin the memory of tensor in-place.
+            See: https://github.com/pytorch/pytorch/issues/32167
+            """
+            cudart = torch.cuda.cudart()
+            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+            assert r == 0, f"pin memory error, error code: {r}"
+
+        # TODO: should only support /dev/shm? but we found files in disk also work?
+        size = os.stat(file_path).st_size
+        flag_size = 8
+        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+        assert t.nbytes > flag_size, (
+            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
+        )
+        start_pos = (
+            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+            + flag_size
+        )
+        header_tensor = t[flag_size:start_pos]
+        header = json.loads(header_tensor.numpy().tobytes())
+        if "__metadata__" in header:
+            header.pop("__metadata__")
+
+        metas: list[ParameterMeta] = []
+        offset = 0
+        try:
+            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                start, end = meta["data_offsets"]
+                # safetensors format ensures offsets are aligned
+                assert offset == start, f"offset {offset} should be equal to start {start}"
+                metas.append(
+                    ParameterMeta(
+                        name=name,
+                        dtype=_getdtype(meta["dtype"]),
+                        shape=torch.Size(meta["shape"]),
+                        aligned_size=end - start,
+                    )
+                )
+                offset = end
+        except Exception as e:
+            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
+            raise
+
+        buffer = t[start_pos:]
+        assert offset == buffer.nbytes, (
+            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+        )
+        # Remove the file after successfully loading. This will avoid doubling the memory usage.
+        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
+        os.remove(file_path)
+        _pin(buffer)
+        logger.info(
+            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
+        )
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+
+    memory_buffers: list[MemoryBuffer] = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
+    return memory_buffers
+
+
+def _normal_pin_memory(
     files: list[str],
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
 ) -> list[MemoryBuffer]:
-    logger.info(
-        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-    )
-    if not files and not named_tensors:
-        return []
     parameters = _load_checkpoint(files)
     if named_tensors:
         parameters.update(named_tensors)
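`_parse_and_pin_from_safetensors` above relies on the safetensors on-disk layout: an 8-byte little-endian header length `N`, an `N`-byte JSON header mapping tensor names to `dtype`, `shape`, and `data_offsets` (byte ranges relative to the data section), then the raw tensor bytes. A minimal standalone sketch of just the header walk, independent of the pinning logic (the file name in the usage comment is illustrative):

```python
import json
import struct

def read_safetensors_header(path: str) -> dict:
    """Parse only the JSON header of a safetensors file."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian length
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor entry
    return header

# Illustrative usage: walk tensors in data-section order, as the diff does.
# for name, meta in sorted(read_safetensors_header("model.safetensors").items(),
#                          key=lambda kv: kv[1]["data_offsets"]):
#     start, end = meta["data_offsets"]  # contiguous byte range in the data section
#     print(name, meta["dtype"], meta["shape"], end - start)
```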
@@ -469,13 +543,16 @@ def _register_checkpoint(
         size: int
         metas: list[ParameterMeta]
 
-    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
+    buckets: list[MemoryBucket] = []
+    buckets.append(MemoryBucket(size=0, metas=[]))
     for name, tensor in sorted(parameters.items()):
         size = _align_size(tensor.dtype, tensor.shape)
         if buckets[-1].size + size > bucket_size:
             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
             buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
+        buckets[-1].metas.append(
+            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
+        )
         buckets[-1].size += size
 
     memory_buffers = [
@@ -483,16 +560,34 @@
         for bucket in buckets
     ]
 
-    def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
-        buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-        return idx, buffer
+    def register_pin_memory(
+        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
+    ) -> tuple[int, torch.Tensor]:
+        if shared_pin_memory:
+            # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
+            # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
+            assert idx < len(shared_pin_memory), (
+                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
+            )
+            assert shared_pin_memory[idx].size == size, (
+                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
+            )
+            return idx, shared_pin_memory[idx].buffer
+        else:
+            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
+            return idx, buffer
 
     def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
         buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
         futures = [
-            executor.submit(register_pin_memory, idx, bucket.size)
+            executor.submit(
+                register_pin_memory,
+                idx,
+                bucket.size,
+                shared_pin_memory,
+            )
             for idx, bucket in enumerate(buckets)
         ]
         new_futures = []
@@ -518,6 +613,39 @@ def _register_checkpoint(
             offset += size
         for future in concurrent.futures.as_completed(new_futures):
             future.result()
+    return memory_buffers
+
+
+def _register_checkpoint(
+    *,
+    files: list[str],
+    named_tensors: dict[str, torch.Tensor],
+    rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
+) -> list[MemoryBuffer]:
+    logger.info(
+        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+    )
+    if not files and not named_tensors:
+        return []
+    memory_buffers: list[MemoryBuffer] = []
+    files_to_inplace_pin = [
+        file
+        for file in files
+        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+    ]
+    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    if files_to_normal_pin or named_tensors:
+        memory_buffers.extend(
+            _normal_pin_memory(
+                files=files_to_normal_pin,
+                named_tensors=named_tensors,
+                rank=rank,
+                shared_pin_memory=shared_pin_memory,
+            )
+        )
+    if files_to_inplace_pin:
+        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
     return memory_buffers
 
 
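The `/dev/shm` fast path exists because `_pin` page-locks memory that is already mapped, instead of allocating a fresh pinned buffer with `torch.empty(..., pin_memory=True)` and copying the checkpoint into it. A minimal sketch of the same trick in isolation (the zero-filled tensor stands in for the `torch.from_file` view; requires a CUDA-enabled PyTorch):

```python
import torch

def pin_inplace(t: torch.Tensor) -> None:
    """Page-lock an existing CPU tensor's backing memory without copying it."""
    cudart = torch.cuda.cudart()
    # Flag 0 is the default registration mode; see pytorch/pytorch#32167.
    ret = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
    assert ret == 0, f"cudaHostRegister failed with error code {ret}"

buf = torch.zeros(1 << 20, dtype=torch.uint8)  # stand-in for the mmap'ed file bytes
pin_inplace(buf)
assert buf.is_pinned()  # CUDA now treats these pages as page-locked host memory
```

Registration only changes the pages' pinned status, which is why `_register_checkpoint` can route `/dev/shm` safetensors files through this path and everything else through `_normal_pin_memory`.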
@@ -566,7 +694,7 @@ def _gen_h2d_buckets(
     for idx, metas in enumerate(items.memory_buffer_metas_list):
         start_offset, offset = 0, 0
         for meta in metas.metas:
-            s = _align_size(meta.dtype, meta.shape)
+            s = meta.aligned_size
             if buckets[-1][1].size + s > bucket_size:
                 if offset - start_offset > 0:
                     buckets[-1][1].ranges.append(
@@ -747,6 +875,8 @@ class P2PStore:
 
 
 class ParameterServer:
+    shared_memory_pool_name = "__shared_memory_pool__"
+
     def __init__(
         self,
         *,
@@ -790,7 +920,10 @@ class ParameterServer:
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
 
+        # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
+        self._current_shared_memory_pool_user: str = ""
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
+        self._memory_pool[self.shared_memory_pool_name] = []
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
         # NPU transfer engine initialization requires prior set_device.
@@ -805,6 +938,17 @@ class ParameterServer:
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+    def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            assert self._memory_pool[self.shared_memory_pool_name], (
+                f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
+            )
+            return self._memory_pool[self.shared_memory_pool_name]
+        elif checkpoint_name in self._memory_pool:
+            return self._memory_pool[checkpoint_name]
+        else:
+            raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
+
     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
             logger.info(msg)
@@ -828,46 +972,97 @@ class ParameterServer:
         *,
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
+        use_shared_memory_pool: bool = False,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
+        Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Please make sure to copy the files to disks if you need to keep them.
 
         Args:
             checkpoint_name: The name of the checkpoint.
             files: The safetensors files to register.
             named_tensors: The named tensors to register.
+            use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
+                Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
+                cannot accommodate checkpoints with different memory requirements.
+                To free the actual memory of the shared pool or to modify its shape,
+                please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
         """
         try:
-            assert checkpoint_name not in self._memory_pool, (
-                f"checkpoint {checkpoint_name} already registered"
-            )
-            self._memory_pool[checkpoint_name] = _register_checkpoint(
-                files=files or [], named_tensors=named_tensors or {}, rank=self._rank
-            )
-            if self._p2p_store is not None:
-                self._register_parameters_to_p2p_store(checkpoint_name)
+            if use_shared_memory_pool:
+                logger.info(
+                    f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
+                )
+                assert self._current_shared_memory_pool_user == "", (
+                    f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
+                    f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
+                    f"This registration may cause unexpected conflicts."
+                )
+                # Since we set the uninitialized shared memory pool to empty list,
+                # we can check whether this is the first time to use shared memory pool
+                _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
+                self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                )
+                self._current_shared_memory_pool_user = checkpoint_name
+                if self._p2p_store is not None and _is_first_time:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
+            else:
+                assert checkpoint_name not in self._memory_pool, (
+                    f"checkpoint {checkpoint_name} already registered"
+                )
+                self._memory_pool[checkpoint_name] = _register_checkpoint(
+                    files=files or [], named_tensors=named_tensors or {}, rank=self._rank
+                )
+                if self._p2p_store is not None:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
         except Exception:
             logger.exception(
                 f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
             )
-            if self._p2p_store is not None:
+            if self._p2p_store is not None and not use_shared_memory_pool:
                 self._unregister_parameters_from_p2p_store(checkpoint_name)
             self.unregister_checkpoint(checkpoint_name)
             raise
 
-    def unregister_checkpoint(self, checkpoint_name: str):
+    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
         """
         Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
         from p2p store if p2p store is initialized.
+        Args:
+            checkpoint_name: The name of the checkpoint.
+            force: This flag is designed for shared memory pool user. If True, the memory for shared memory pool itself will be freed.
+                If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
         """
-        if checkpoint_name not in self._memory_pool:
+        if (
+            checkpoint_name not in self._memory_pool
+            and checkpoint_name != self._current_shared_memory_pool_user
+        ):
+            logger.warning(
+                f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
+            )
+            return
+
+        if checkpoint_name == self._current_shared_memory_pool_user and not force:
+            self._current_shared_memory_pool_user = ""
             return
+
         if self._p2p_store is not None:
             num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
             logger.info(
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )
-        del self._memory_pool[checkpoint_name]
+
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            del self._memory_pool[self.shared_memory_pool_name]
+            self._memory_pool[self.shared_memory_pool_name] = []
+        else:
+            del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
         torch._C._host_emptyCache()
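Taken together, `use_shared_memory_pool` and the new `force` flag on `unregister_checkpoint` enable a reuse cycle like the following hypothetical driver code (`ps`, the checkpoint names, and `files_v1`/`files_v2` are illustrative; successive checkpoints must have the same memory layout, as the docstring warns):

```python
# Hypothetical reuse cycle for the shared pin-memory pool.
ps.register_checkpoint("ckpt-step-100", files=files_v1, use_shared_memory_pool=True)
# ... broadcast / update weights from "ckpt-step-100" ...
ps.unregister_checkpoint("ckpt-step-100")  # releases the name, keeps the pinned pool

ps.register_checkpoint("ckpt-step-200", files=files_v2, use_shared_memory_pool=True)
# ... broadcast / update weights from "ckpt-step-200" ...
ps.unregister_checkpoint("ckpt-step-200", force=True)  # also frees the pool's memory
```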
@@ -882,6 +1077,10 @@ class ParameterServer:
             self.init_process_group()
         assert dist.is_initialized(), "process group is not initialized"
         metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
+        try:
+            memory_pool = self._get_memory_pool(checkpoint_name)
+        except RuntimeError:
+            memory_pool = []
         metas = DataToGather(
             memory_buffer_metas_list=[
                 MemoryBufferMetas(
@@ -889,7 +1088,7 @@ class ParameterServer:
                     ptr=x.buffer.data_ptr(),
                     size=x.size,
                 )
-                for x in self._memory_pool.get(checkpoint_name, [])
+                for x in memory_pool
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
             host_ip=get_ip(),
@@ -1050,7 +1249,7 @@ class ParameterServer:
         for items in self._current_global_parameter_metas.values():
             for metas_list in items.memory_buffer_metas_list:
                 for meta in metas_list.metas:
-                    max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
+                    max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
             self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
@@ -1095,7 +1294,7 @@ class ParameterServer:
                 remote_ptrs.append(ptrs[b.idx][0] + b.offset)
                 lens.append(b.size)
             else:
-                pool = self._memory_pool[checkpoint_name][b.idx]
+                pool = self._get_memory_pool(checkpoint_name)[b.idx]
                 buffer[offset : offset + b.size].data.copy_(
                     pool.buffer[b.offset : b.offset + b.size],
                     non_blocking=True,
@@ -1158,22 +1357,32 @@ class ParameterServer:
 
     def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
+        register_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         for idx, memory_buffer in enumerate(pool):
-            named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
+            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
             tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
         self._p2p_store.register_named_tensors(named_tensors)
 
     def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
+        unregister_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         return self._p2p_store.unregister_named_tensors(
-            [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
        )
 
     def _update_per_bucket(
@@ -1284,9 +1493,9 @@
                 dist.broadcast(buffer_b, src=brank)
             resp = socket.recv()
             if resp != b"":
-                exception_obj = pickle.loads(resp)
+                msg = resp.decode("utf-8")
                 logger.error(
-                    f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
+                    f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                 )
                 ret_code.fill_(1)
         dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
checkpoint_engine/worker.py CHANGED
@@ -1,4 +1,5 @@
 import gc
+import traceback
 from collections.abc import Callable
 from typing import TypedDict
 
@@ -63,7 +64,8 @@ def update_weights_from_ipc(
             assert buffer.dtype == torch.uint8
             socket.send(b"")
         except Exception as e:
-            socket.send_pyobj(e)
+            msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+            socket.send_string(msg)
             socket.recv()  # wait for ack
             raise
         try:
@@ -83,7 +85,8 @@ def update_weights_from_ipc(
         except Exception as e:  # noqa: BLE001
             # Send exception back to Parameter Server.
             # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
-            socket.send_pyobj(e)
+            msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+            socket.send_string(msg)
     elif isinstance(
         payload, Exception
     ):  # error occurred, got force quit signal from Parameter Server
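On both ends this release replaces pickled exception objects on the ZMQ socket with plain UTF-8 traceback strings, which keeps the remote stack trace readable in logs and avoids unpickling bytes from another process. A minimal sketch of the new round-trip (the socket wiring is illustrative, not the package's actual setup):

```python
import traceback
import zmq

def send_error(socket: zmq.Socket, exc: Exception) -> None:
    """Worker side: report a failure as a formatted traceback string."""
    msg = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    socket.send_string(msg)  # UTF-8 text instead of socket.send_pyobj(exc)

def check_reply(socket: zmq.Socket) -> None:
    """Parameter-server side: an empty reply means success; anything else is a traceback."""
    resp = socket.recv()
    if resp != b"":
        raise RuntimeError(f"worker reported:\n{resp.decode('utf-8')}")
```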
checkpoint_engine-0.2.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.1
+Version: 0.2.2
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer-engine`
 pip install 'checkpoint-engine[p2p]'
 ```
 
-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
-
 ## Getting Started
 
-Prepare an H800 or H20 machine with 8 GPUs with latest vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights.
+Prepare an H800 or H20 machine with 8 GPUs and vLLM. Be sure to include the [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in the main branch), since checkpoint-engine uses this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
 
 ```Bash
-cd /opt && git clone https://github.com/vllm-project/vllm && cd vllm
+mkdir -p /opt/vLLM && cd /opt/vLLM
 uv venv --python 3.12 --seed
 source .venv/bin/activate
-VLLM_USE_PRECOMPILED=1 uv pip install --editable .
+uv pip install vllm==0.10.2
 ```
 
 Install checkpoint-engine
@@ -180,6 +178,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
 pytest tests/ -m "not gpu"
 ```
 
+### Environment Variables
+- `PS_MAX_BUCKET_SIZE_GB`: An integer setting checkpoint-engine's maximum bucket size, in GB. Defaults to 8 GB if not set.
+- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA device names for P2P transfer. If not set, checkpoint-engine falls back to `NCCL_IB_HCA` to detect RDMA devices.
+- `NCCL_IB_HCA`: Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If this is also not set, all RDMA devices will be used and divided evenly among the ranks.
+
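A hypothetical way to set these from Python before launching checkpoint-engine (the values and the `mlx5_0,mlx5_1` device list are made up for illustration):

```python
import os

# Cap transfer buckets at 4 GB instead of the default 8 GB.
os.environ["PS_MAX_BUCKET_SIZE_GB"] = "4"
# Pin P2P transfers to two hypothetical RDMA devices; when unset,
# checkpoint-engine falls back to NCCL_IB_HCA, then to all detected devices.
os.environ["PS_P2P_STORE_RDMA_DEVICES"] = "mlx5_0,mlx5_1"
```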
 ## SGLang Integration
 
 Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
checkpoint_engine-0.2.2.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
+checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+checkpoint_engine/_version.py,sha256=o3ZTescp-19Z9cvBGq9dQnbppljgzdUYUf98Nov0spY,704
+checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+checkpoint_engine/ps.py,sha256=cu8Qp5daY1iL30iN69jXP4grlHoAKILblngcKQPA5Bg,67692
+checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
+checkpoint_engine-0.2.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.2.2.dist-info/METADATA,sha256=_bBxy27d0GMc7KzuIBAdw-Lno3-UrVLUhH63YDbY1YA,11559
+checkpoint_engine-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+checkpoint_engine-0.2.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.2.2.dist-info/RECORD,,
checkpoint_engine-0.2.1.dist-info/RECORD REMOVED
@@ -1,10 +0,0 @@
-checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
-checkpoint_engine/_version.py,sha256=vYqoJTG51NOUmYyL0xt8asRK8vUT4lGAdal_EZ59mvw,704
-checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
-checkpoint_engine/ps.py,sha256=Mgfr_MXYYZ_6JKqD5kIGDBWCYNWAtIPfAppP_cFu604,57781
-checkpoint_engine/worker.py,sha256=5TzDgTPew6Ts9sMOzecalLCR1p_ZwfeKPdzzr68kAQg,6564
-checkpoint_engine-0.2.1.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.2.1.dist-info/METADATA,sha256=7E7NhehWHpS6QVkLup-oCm350wdbiZX8CY3jmmJP0bU,11315
-checkpoint_engine-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.2.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.2.1.dist-info/RECORD,,