checkpoint-engine 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.2'
32
- __version_tuple__ = version_tuple = (0, 2, 2)
31
+ __version__ = version = '0.2.3'
32
+ __version_tuple__ = version_tuple = (0, 2, 3)
33
33
 
34
34
  __commit_id__ = commit_id = None
checkpoint_engine/ps.py CHANGED
@@ -622,6 +622,7 @@ def _register_checkpoint(
622
622
  named_tensors: dict[str, torch.Tensor],
623
623
  rank: int | None = None,
624
624
  shared_pin_memory: list[MemoryBuffer] | None = None,
625
+ inplace_pin: bool = False,
625
626
  ) -> list[MemoryBuffer]:
626
627
  logger.info(
627
628
  f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -629,12 +630,17 @@ def _register_checkpoint(
629
630
  if not files and not named_tensors:
630
631
  return []
631
632
  memory_buffers: list[MemoryBuffer] = []
632
- files_to_inplace_pin = [
633
- file
634
- for file in files
635
- if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
636
- ]
637
- files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
633
+ if inplace_pin:
634
+ logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
635
+ files_to_inplace_pin = [
636
+ file
637
+ for file in files
638
+ if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
639
+ ]
640
+ files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
641
+ else:
642
+ files_to_normal_pin = files
643
+ files_to_inplace_pin = []
638
644
  if files_to_normal_pin or named_tensors:
639
645
  memory_buffers.extend(
640
646
  _normal_pin_memory(
@@ -973,10 +979,11 @@ class ParameterServer:
973
979
  files: list[str] | None = None,
974
980
  named_tensors: dict[str, torch.Tensor] | None = None,
975
981
  use_shared_memory_pool: bool = False,
982
+ use_inplace_pin_memory: bool = False,
976
983
  ) -> None:
977
984
  """
978
985
  Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
979
- Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
986
+ Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
980
987
  Please make sure to copy the files to disks if you need to keep them.
981
988
 
982
989
  Args:
@@ -988,6 +995,8 @@ class ParameterServer:
988
995
  cannot accommodate checkpoints with different memory requirements.
989
996
  To free the actual memory of the shared pool or to modify its shape,
990
997
  please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
998
+ use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
999
+ Currently, this feature is experimental and may crash.
991
1000
  """
992
1001
  try:
993
1002
  if use_shared_memory_pool:
@@ -1016,7 +1025,10 @@ class ParameterServer:
1016
1025
  f"checkpoint {checkpoint_name} already registered"
1017
1026
  )
1018
1027
  self._memory_pool[checkpoint_name] = _register_checkpoint(
1019
- files=files or [], named_tensors=named_tensors or {}, rank=self._rank
1028
+ files=files or [],
1029
+ named_tensors=named_tensors or {},
1030
+ rank=self._rank,
1031
+ inplace_pin=use_inplace_pin_memory,
1020
1032
  )
1021
1033
  if self._p2p_store is not None:
1022
1034
  self._register_parameters_to_p2p_store(checkpoint_name)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -0,0 +1,10 @@
1
+ checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
+ checkpoint_engine/_version.py,sha256=kBRz0P2plw1eVdIpt70W6m1LMbEIhLY3RyOfVGdubaI,704
3
+ checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
4
+ checkpoint_engine/ps.py,sha256=JQcSDeq7wGLZMPBdAa-9Lb4SymSP1l_oUEMO-X1LfvQ,68360
5
+ checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
6
+ checkpoint_engine-0.2.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
7
+ checkpoint_engine-0.2.3.dist-info/METADATA,sha256=qsTp8s8Z6gz2q12x0gZQKlvViKrvlEB36b_Zpe_nhi4,11559
8
+ checkpoint_engine-0.2.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ checkpoint_engine-0.2.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
10
+ checkpoint_engine-0.2.3.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
- checkpoint_engine/_version.py,sha256=o3ZTescp-19Z9cvBGq9dQnbppljgzdUYUf98Nov0spY,704
3
- checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
4
- checkpoint_engine/ps.py,sha256=cu8Qp5daY1iL30iN69jXP4grlHoAKILblngcKQPA5Bg,67692
5
- checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
6
- checkpoint_engine-0.2.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
7
- checkpoint_engine-0.2.2.dist-info/METADATA,sha256=_bBxy27d0GMc7KzuIBAdw-Lno3-UrVLUhH63YDbY1YA,11559
8
- checkpoint_engine-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
- checkpoint_engine-0.2.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
10
- checkpoint_engine-0.2.2.dist-info/RECORD,,