checkpoint-engine 0.3.0rc0.tar.gz → 0.3.0rc1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/_version.py +3 -3
  3. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/ps.py +76 -15
  4. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  5. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/SOURCES.txt +2 -1
  6. checkpoint_engine-0.3.0rc1/tests/test_inplace_unpin.py +81 -0
  7. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/tests/test_update.py +1 -2
  8. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/.github/workflows/cpu-tests.yml +0 -0
  9. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/.github/workflows/pre-commit.yaml +0 -0
  10. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/.github/workflows/python-publish.yml +0 -0
  11. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/.gitignore +0 -0
  12. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/.pre-commit-config.yaml +0 -0
  13. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/LICENCE +0 -0
  14. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/README.md +0 -0
  15. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/__init__.py +0 -0
  16. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/device_utils.py +0 -0
  17. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/worker.py +0 -0
  18. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  19. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/requires.txt +0 -0
  20. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/top_level.txt +0 -0
  21. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/docs/npu_start.md +0 -0
  22. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/examples/update.py +0 -0
  23. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/figures/checkpoint-engine.png +0 -0
  24. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/figures/overlap-update-and-copy.png +0 -0
  25. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/figures/pipeline.png +0 -0
  26. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/patches/vllm_fp8.patch +0 -0
  27. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/pyproject.toml +0 -0
  28. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/setup.cfg +0 -0
  29. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/tests/test_assign_receiver_ranks.py +0 -0
  30. {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.0rc1}/tests/test_rdma_parser.py +0 -0
  31. /checkpoint_engine-0.3.0rc0/tests/test_pin_memory.py → /checkpoint_engine-0.3.0rc1/tests/test_reuse_pin_memory.py +0 -0
--- checkpoint_engine-0.3.0rc0/PKG-INFO
+++ checkpoint_engine-0.3.0rc1/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.0rc0
+Version: 0.3.0rc1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
--- checkpoint_engine-0.3.0rc0/checkpoint_engine/_version.py
+++ checkpoint_engine-0.3.0rc1/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.0rc0'
-__version_tuple__ = version_tuple = (0, 3, 0, 'rc0')
+__version__ = version = '0.3.0rc1'
+__version_tuple__ = version_tuple = (0, 3, 0, 'rc1')
 
-__commit_id__ = commit_id = 'gbaf6f6196'
+__commit_id__ = commit_id = 'g88370e267'
--- checkpoint_engine-0.3.0rc0/checkpoint_engine/ps.py
+++ checkpoint_engine-0.3.0rc1/checkpoint_engine/ps.py
@@ -118,6 +118,7 @@ class MemoryBuffer(BaseModel):
     buffer: _TorchTensor
     size: int
     metas: list[ParameterMeta]
+    manually_pinned: bool = False
 
 
 class MemoryBufferMetaList(BaseModel):
@@ -520,7 +521,7 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[Memor
         logger.info(
             f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
         )
-        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)
 
     memory_buffers: list[MemoryBuffer] = []
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
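A note for readers unfamiliar with the mechanism: `_inplace_pin_memory` page-locks safetensors files that already live on tmpfs, so no copy into a freshly allocated pinned buffer is needed. The sketch below shows the general technique, not the library's exact code; the helper name is hypothetical, and the 0x02 flag mirrors the `cudaHostRegisterMapped` value that `_unpin` asserts on later in this diff.

```python
import os

import torch


def pin_shm_file_inplace(path: str) -> torch.Tensor:
    """Page-lock a tmpfs-backed file without copying it (hypothetical helper)."""
    nbytes = os.path.getsize(path)
    # Memory-map the file; on /dev/shm the pages already live in RAM,
    # so no read or copy happens here.
    storage = torch.UntypedStorage.from_file(path, shared=True, nbytes=nbytes)
    buffer = torch.empty(0, dtype=torch.uint8)
    buffer.set_(storage)
    # Register the mapped pages with the CUDA runtime so H2D copies can DMA
    # directly from them. 0x02 = cudaHostRegisterMapped (an assumption here,
    # consistent with the flag _unpin checks).
    r = torch.cuda.cudart().cudaHostRegister(buffer.data_ptr(), nbytes, 0x02)
    assert r == 0, f"cudaHostRegister failed, error code: {r}"
    return buffer
```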
@@ -622,6 +623,7 @@ def _register_checkpoint(
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
     shared_pin_memory: list[MemoryBuffer] | None = None,
+    inplace_pin: bool = False,
 ) -> list[MemoryBuffer]:
     logger.info(
         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -629,12 +631,17 @@ def _register_checkpoint(
     if not files and not named_tensors:
         return []
     memory_buffers: list[MemoryBuffer] = []
-    files_to_inplace_pin = [
-        file
-        for file in files
-        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
-    ]
-    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    if inplace_pin:
+        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
+        files_to_inplace_pin = [
+            file
+            for file in files
+            if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+        ]
+        files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    else:
+        files_to_normal_pin = files
+        files_to_inplace_pin = []
     if files_to_normal_pin or named_tensors:
         memory_buffers.extend(
             _normal_pin_memory(
@@ -868,7 +875,7 @@ class ParameterServer:
         *,
         rank: int | None = None,
         world_size: int | None = None,
-        auto_pg: bool = False,
+        auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
     ):
@@ -877,7 +884,7 @@ class ParameterServer:
 
         Args:
             auto_pg: Whether to automatically initialize the process group.
-                Notice that if auto_pg is True, will destroy the process group after update.
+                Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
             mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
@@ -959,11 +966,12 @@ class ParameterServer:
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
         use_shared_memory_pool: bool = False,
+        use_inplace_pin_memory: bool = True,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
-        Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
-        Please make sure to copy the files to disks if you need to keep them.
+        Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Please make sure to copy the files to disks if you need to keep them. NPU does not support inplace pin memory.
 
         Args:
             checkpoint_name: The name of the checkpoint.
@@ -974,7 +982,14 @@ class ParameterServer:
                 cannot accommodate checkpoints with different memory requirements.
                 To free the actual memory of the shared pool or to modify its shape,
                 please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
+            use_inplace_pin_memory: If True (default), allows inplace pin memory for /dev/shm/ safetensors files.
+                This option is ignored when ``use_shared_memory_pool`` is True.
         """
+        if self.device_manager.device_type != "cuda" and use_inplace_pin_memory:
+            logger.warning(
+                f"[rank{self._rank}] Only cuda devices support in-place pin memory, set use_inplace_pin_memory to False"
+            )
+            use_inplace_pin_memory = False
         try:
             if use_shared_memory_pool:
                 logger.info(
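A usage note: the new flag is opt-out and only affects `/dev/shm/*.safetensors` inputs. Below is a minimal sketch of disabling it so the files survive registration; the checkpoint path and name are hypothetical, and a torchrun-style launch is assumed so the RANK/WORLD_SIZE environment variables exist.

```python
from checkpoint_engine.ps import ParameterServer

# auto_pg now defaults to True, so the process group is set up automatically.
# Assumes a torchrun launch (RANK / WORLD_SIZE set in the environment).
ps = ParameterServer()

# Hypothetical shard path: only /dev/shm/*.safetensors files are candidates
# for in-place pinning; everything else always takes the normal pin path.
files = ["/dev/shm/model/rank0.safetensors"]

# Opt out to keep the files on /dev/shm after registration; the default
# (True) would pin them in place and REMOVE them.
ps.register_checkpoint("demo-ckpt", files=files, use_inplace_pin_memory=False)
```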
@@ -993,6 +1008,7 @@ class ParameterServer:
                     named_tensors=named_tensors or {},
                     rank=self._rank,
                     shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                    inplace_pin=False,  # inplace pin memory is not compatible with shared memory pool
                 )
                 self._current_shared_memory_pool_user = checkpoint_name
                 if self._p2p_store is not None and _is_first_time:
@@ -1002,7 +1018,10 @@ class ParameterServer:
                     f"checkpoint {checkpoint_name} already registered"
                 )
                 self._memory_pool[checkpoint_name] = _register_checkpoint(
-                    files=files or [], named_tensors=named_tensors or {}, rank=self._rank
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    inplace_pin=use_inplace_pin_memory,
                 )
                 if self._p2p_store is not None:
                     self._register_parameters_to_p2p_store(checkpoint_name)
@@ -1048,6 +1067,46 @@ class ParameterServer:
             del self._memory_pool[self.shared_memory_pool_name]
             self._memory_pool[self.shared_memory_pool_name] = []
         else:
+
+            def _unpin(t: torch.Tensor):
+                """
+                Un-pin the pinned memory.
+                """
+                p_flags = ctypes.c_uint()
+                try:
+                    libc = ctypes.CDLL(None)  # get all symbols from the current process
+                    cuda_host_get_flags = libc.cudaHostGetFlags
+                    cuda_host_get_flags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
+                    cuda_host_get_flags.restype = ctypes.c_int
+                except AttributeError:
+                    logger.error("cudaHostGetFlags not found in libc, cannot unpin memory manually")
+                    raise
+                r = cuda_host_get_flags(ctypes.byref(p_flags), ctypes.c_void_p(t.data_ptr()))
+                assert r == 0, f"get pin flags error, error code: {r}"
+                # p_flags value meaning from cuda/include/driver_types.h
+                # cudaHostRegisterDefault  0x00  /**< Default host memory registration flag */
+                # cudaHostRegisterPortable 0x01  /**< Pinned memory accessible by all CUDA contexts */
+                # cudaHostRegisterMapped   0x02  /**< Map registered memory into device space */
+                # cudaHostRegisterIoMemory 0x04  /**< Memory-mapped I/O space */
+                # cudaHostRegisterReadOnly 0x08  /**< Memory-mapped read-only */
+                assert p_flags.value == 0x02, (
+                    f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
+                )
+                cudart = torch.cuda.cudart()
+                r = cudart.cudaHostUnregister(t.data_ptr())
+                assert r == 0, f"unpin memory error, error code: {r}"
+
+            # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
+            try:
+                for memory_buffer in self._memory_pool.get(checkpoint_name, []):
+                    if memory_buffer.manually_pinned:
+                        _unpin(memory_buffer.buffer)
+            except Exception as e:
+                logger.error(
+                    f"[rank{self._rank}] fail to unpin memory for checkpoint {checkpoint_name}: {e}"
+                )
+                raise
+            # we won't delete the memory pool if unpinning fails.
             del self._memory_pool[checkpoint_name]
             # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
             # this works by using torch>=2.5.0
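The `ctypes.CDLL(None)` call works because torch already loads the CUDA runtime into the process, so `cudaHostGetFlags` resolves from the process's global symbol table rather than from libc proper. Here is a standalone sketch of the same query on any pinned tensor; the helper name is illustrative, and the printed value depends on how the memory was pinned.

```python
import ctypes

import torch


def host_register_flags(t: torch.Tensor) -> int:
    """Return the CUDA host-registration flags for a pinned tensor's memory."""
    # CDLL(None) searches symbols already loaded into this process; torch
    # links the CUDA runtime, so cudaHostGetFlags is reachable without
    # naming a specific .so file.
    libc = ctypes.CDLL(None)
    libc.cudaHostGetFlags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
    libc.cudaHostGetFlags.restype = ctypes.c_int
    flags = ctypes.c_uint()
    r = libc.cudaHostGetFlags(ctypes.byref(flags), ctypes.c_void_p(t.data_ptr()))
    assert r == 0, f"cudaHostGetFlags failed, error code: {r}"
    return flags.value


t = torch.empty(1 << 20, dtype=torch.uint8, pin_memory=True)
# Value depends on the pinning path: 0x02 for the in-place cudaHostRegister
# buffers above; cudaHostAlloc'd memory reports its allocation flags instead.
print(hex(host_register_flags(t)))
```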
@@ -1183,6 +1242,8 @@ class ParameterServer:
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
+        Warning: if _auto_pg is False when initializing ParameterServer, please make sure ALL ranks in the WORLD_SIZE call `update` function,
+        otherwise, it will hang.
 
         Args:
             checkpoint_name: The name of the checkpoint.
@@ -1217,7 +1278,7 @@ class ParameterServer:
                 is_master=self._rank == 0,
             )
             # if ranks is None or [], it will use fully broadcast to update to all ranks
-            ranks_group = dist.new_group(ranks if ranks else None)
+            ranks_group = dist.new_group(ranks) if ranks else None
             self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
             self.store_based_barrier(manager_store)
         except Exception as e:
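Context for the one-line fix above: `dist.new_group(None)` still constructs a brand-new process group spanning every rank, whereas passing `group=None` to a collective reuses the default WORLD group. A short sketch of the corrected selection logic, assuming the process group is already initialized:

```python
import torch.distributed as dist


def make_update_group(ranks: list[int] | None) -> dist.ProcessGroup | None:
    # None (or []) means "update all ranks": returning None lets collectives
    # fall back to the default WORLD group instead of building a redundant
    # duplicate group on every update.
    return dist.new_group(ranks) if ranks else None
```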
@@ -1248,7 +1309,7 @@ class ParameterServer:
         return socket, socket_paths
 
     def _detect_bucket_size(
-        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+        self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
     ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
@@ -1367,7 +1428,7 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
-        ranks_group: dist.ProcessGroup,
+        ranks_group: dist.ProcessGroup | None,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
--- checkpoint_engine-0.3.0rc0/checkpoint_engine.egg-info/PKG-INFO
+++ checkpoint_engine-0.3.0rc1/checkpoint_engine.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.0rc0
+Version: 0.3.0rc1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
--- checkpoint_engine-0.3.0rc0/checkpoint_engine.egg-info/SOURCES.txt
+++ checkpoint_engine-0.3.0rc1/checkpoint_engine.egg-info/SOURCES.txt
@@ -23,6 +23,7 @@ figures/overlap-update-and-copy.png
 figures/pipeline.png
 patches/vllm_fp8.patch
 tests/test_assign_receiver_ranks.py
-tests/test_pin_memory.py
+tests/test_inplace_unpin.py
 tests/test_rdma_parser.py
+tests/test_reuse_pin_memory.py
 tests/test_update.py
--- /dev/null
+++ checkpoint_engine-0.3.0rc1/tests/test_inplace_unpin.py
@@ -0,0 +1,81 @@
+import os
+import subprocess
+import time
+
+import pytest
+import torch.distributed as dist
+from test_update import device_manager, gen_test_tensors, get_world_size
+
+from checkpoint_engine.ps import ParameterServer
+
+
+dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
+
+
+def get_files() -> list[str]:
+    rank = int(os.getenv("RANK"))
+    named_tensors = dict(gen_test_tensors(rank))
+    import safetensors.torch
+
+    files = []
+    os.makedirs(dev_shm_dir, exist_ok=True)
+    tensors_in_dev_shm = named_tensors
+    time.sleep(1)
+    dev_shm_files = [
+        os.path.join(dev_shm_dir, f"rank{rank}_checkpoint.safetensors")
+        for _ in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
+    time.sleep(1)
+    files.append(dev_shm_files[rank])
+    return files
+
+
+def run_pin_and_unpin(num_runs: int):
+    ps = ParameterServer(auto_pg=True)
+    checkpoint_name = "test_with_files"
+    for _ in range(num_runs):
+        files = get_files()
+        ps.register_checkpoint(checkpoint_name, files=files)
+        ps.gather_metas(checkpoint_name)
+        dist.barrier()
+        ps.unregister_checkpoint(checkpoint_name)
+        if ps._rank == 0:
+            import shutil
+
+            shutil.rmtree(dev_shm_dir)
+
+    dist.destroy_process_group()
+
+
+@pytest.mark.gpu
+def test_unpin_files():
+    world_size = device_manager.device_module.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
+    master_addr = "localhost"
+    master_port = 25400
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        __file__,
+    ]
+
+    result = subprocess.run(  # noqa: S603
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
+if __name__ == "__main__":
+    run_pin_and_unpin(3)
--- checkpoint_engine-0.3.0rc0/tests/test_update.py
+++ checkpoint_engine-0.3.0rc1/tests/test_update.py
@@ -185,7 +185,6 @@ def run_with_files(
     os.makedirs(dev_shm_dir, exist_ok=True)
     os.makedirs(disk_dir, exist_ok=True)
     tensors_items = list(named_tensors.items())
-    tensors_in_dev_shm = named_tensors
     tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 2])
     tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3])
     tensors_in_memory = dict(tensors_items[1 * len(tensors_items) // 2 :])
@@ -218,7 +217,7 @@ def run_with_files(
     if rank == 0:
         import shutil
 
-        os.removedirs(dev_shm_dir)
+        shutil.rmtree(dev_shm_dir)
         shutil.rmtree(disk_dir)
     assert proc.exitcode == 0
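The last fix matters because `os.removedirs` only removes empty directories and raises `OSError` when files remain, which is exactly the state of the test's `/dev/shm` directory; `shutil.rmtree` removes the tree unconditionally. A quick self-contained illustration:

```python
import os
import shutil
import tempfile

d = tempfile.mkdtemp()
open(os.path.join(d, "leftover.safetensors"), "w").close()

try:
    os.removedirs(d)  # fails: the directory is not empty
except OSError:
    shutil.rmtree(d)  # succeeds regardless of contents

assert not os.path.exists(d)
```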