checkpoint-engine 0.3.0rc0__py3-none-any.whl → 0.3.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/ps.py +76 -15
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/METADATA +1 -1
- checkpoint_engine-0.3.0rc1.dist-info/RECORD +10 -0
- checkpoint_engine-0.3.0rc0.dist-info/RECORD +0 -10
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.0rc0'
-__version_tuple__ = version_tuple = (0, 3, 0, 'rc0')
+__version__ = version = '0.3.0rc1'
+__version_tuple__ = version_tuple = (0, 3, 0, 'rc1')
 
 __commit_id__ = commit_id = None
```
checkpoint_engine/ps.py
CHANGED
```diff
@@ -118,6 +118,7 @@ class MemoryBuffer(BaseModel):
     buffer: _TorchTensor
     size: int
     metas: list[ParameterMeta]
+    manually_pinned: bool = False
 
 
 class MemoryBufferMetaList(BaseModel):
@@ -520,7 +521,7 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
         logger.info(
             f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
         )
-        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)
 
     memory_buffers: list[MemoryBuffer] = []
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
@@ -622,6 +623,7 @@ def _register_checkpoint(
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
+    inplace_pin: bool = False,
 ) -> list[MemoryBuffer]:
     logger.info(
         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -629,12 +631,17 @@ def _register_checkpoint(
     if not files and not named_tensors:
         return []
     memory_buffers: list[MemoryBuffer] = []
-    files_to_inplace_pin = [
-        file
-        for file in files
-        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
-    ]
-    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    if inplace_pin:
+        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
+        files_to_inplace_pin = [
+            file
+            for file in files
+            if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+        ]
+        files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    else:
+        files_to_normal_pin = files
+        files_to_inplace_pin = []
     if files_to_normal_pin or named_tensors:
         memory_buffers.extend(
             _normal_pin_memory(
```
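The guarded branch above only changes *when* in-place pinning happens; the mechanism itself predates this release. As a hedged sketch of that mechanism, assuming an mmap-plus-cudaHostRegister approach with an illustrative file path (not the package's exact implementation):

```python
# Hedged sketch: pin a /dev/shm safetensors file "in place" by mapping it and
# registering the mapping with CUDA, so no extra host copy is made.
import mmap
import os

import torch

path = "/dev/shm/demo.safetensors"  # illustrative path
fd = os.open(path, os.O_RDWR)
mm = mmap.mmap(fd, os.fstat(fd).st_size)
buf = torch.frombuffer(mm, dtype=torch.uint8)  # zero-copy view of the file

cudart = torch.cuda.cudart()
# 0x02 == cudaHostRegisterMapped, the flag the new _unpin helper asserts on
assert cudart.cudaHostRegister(buf.data_ptr(), buf.numel(), 0x02) == 0
assert buf.is_pinned()  # CUDA now treats the mapping as pinned host memory
```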
```diff
@@ -868,7 +875,7 @@ class ParameterServer:
         *,
         rank: int | None = None,
         world_size: int | None = None,
-        auto_pg: bool = False,
+        auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
     ):
@@ -877,7 +884,7 @@ class ParameterServer:
 
         Args:
             auto_pg: Whether to automatically initialize the process group.
-                Notice that if auto_pg is True, will destroy the process group after update.
+                Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
             mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
@@ -959,11 +966,12 @@ class ParameterServer:
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
         use_shared_memory_pool: bool = False,
+        use_inplace_pin_memory: bool = True,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
-        Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
-        Please make sure to copy the files to disks if you need to keep them.
+        Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Please make sure to copy the files to disks if you need to keep them. NPU does not support inplace pin memory.
 
         Args:
             checkpoint_name: The name of the checkpoint.
@@ -974,7 +982,14 @@ class ParameterServer:
                 cannot accommodate checkpoints with different memory requirements.
                 To free the actual memory of the shared pool or to modify its shape,
                 please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
+            use_inplace_pin_memory: If True (default), allows inplace pin memory for /dev/shm/ safetensors files.
+                This option is ignored when ``use_shared_memory_pool`` is True.
         """
+        if self.device_manager.device_type != "cuda" and use_inplace_pin_memory:
+            logger.warning(
+                f"[rank{self._rank}] Only cuda devices support in-place pin memory, set use_inplace_pin_memory to False"
+            )
+            use_inplace_pin_memory = False
         try:
             if use_shared_memory_pool:
                 logger.info(
@@ -993,6 +1008,7 @@ class ParameterServer:
                     named_tensors=named_tensors or {},
                     rank=self._rank,
                     shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                    inplace_pin=False,  # inplace pin memory is not compatible with shared memory pool
                 )
                 self._current_shared_memory_pool_user = checkpoint_name
                 if self._p2p_store is not None and _is_first_time:
@@ -1002,7 +1018,10 @@ class ParameterServer:
                     f"checkpoint {checkpoint_name} already registered"
                 )
                 self._memory_pool[checkpoint_name] = _register_checkpoint(
-                    files=files or [],
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    inplace_pin=use_inplace_pin_memory,
                 )
                 if self._p2p_store is not None:
                     self._register_parameters_to_p2p_store(checkpoint_name)
```
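Taken together, the new keyword threads through to `_register_checkpoint` as `inplace_pin`. A hypothetical usage sketch (checkpoint name and file path are made up; assumes torchrun-style `RANK`/`WORLD_SIZE` environment variables):

```python
from checkpoint_engine.ps import ParameterServer

ps = ParameterServer(auto_pg=True)  # True is the new default in 0.3.0rc1
ps.register_checkpoint(
    checkpoint_name="demo-ckpt",
    files=["/dev/shm/model.safetensors"],  # pinned in place, then REMOVED
    use_inplace_pin_memory=True,  # default; ignored if use_shared_memory_pool=True
)
```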
```diff
@@ -1048,6 +1067,46 @@ class ParameterServer:
                 del self._memory_pool[self.shared_memory_pool_name]
                 self._memory_pool[self.shared_memory_pool_name] = []
             else:
+
+                def _unpin(t: torch.Tensor):
+                    """
+                    Un-pin the pinned memory.
+                    """
+                    p_flags = ctypes.c_uint()
+                    try:
+                        libc = ctypes.CDLL(None)  # get all symbols from the current process
+                        cuda_host_get_flags = libc.cudaHostGetFlags
+                        cuda_host_get_flags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
+                        cuda_host_get_flags.restype = ctypes.c_int
+                    except AttributeError:
+                        logger.error("cudaHostGetFlags not found in libc, cannot unpin memory manually")
+                        raise
+                    r = cuda_host_get_flags(ctypes.byref(p_flags), ctypes.c_void_p(t.data_ptr()))
+                    assert r == 0, f"get pin flags error, error code: {r}"
+                    # p_flags value meaning from cuda/include/driver_types.h
+                    # cudaHostRegisterDefault   0x00  /**< Default host memory registration flag */
+                    # cudaHostRegisterPortable  0x01  /**< Pinned memory accessible by all CUDA contexts */
+                    # cudaHostRegisterMapped    0x02  /**< Map registered memory into device space */
+                    # cudaHostRegisterIoMemory  0x04  /**< Memory-mapped I/O space */
+                    # cudaHostRegisterReadOnly  0x08  /**< Memory-mapped read-only */
+                    assert p_flags.value == 0x02, (
+                        f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
+                    )
+                    cudart = torch.cuda.cudart()
+                    r = cudart.cudaHostUnregister(t.data_ptr())
+                    assert r == 0, f"unpin memory error, error code: {r}"
+
+                # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
+                try:
+                    for memory_buffer in self._memory_pool.get(checkpoint_name, []):
+                        if memory_buffer.manually_pinned:
+                            _unpin(memory_buffer.buffer)
+                except Exception as e:
+                    logger.error(
+                        f"[rank{self._rank}] fail to unpin memory for checkpoint {checkpoint_name}: {e}"
+                    )
+                    raise
+                # we won't delete the memory pool if unpinning fails.
                 del self._memory_pool[checkpoint_name]
                 # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
                 # this works by using torch>=2.5.0
```
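A hedged standalone round trip of the same unpin logic, assuming (as the helper above does) that the `cudaHostGetFlags` symbol is resolvable in-process via `ctypes.CDLL(None)` and that manual registration succeeds on the buffer:

```python
import ctypes

import torch

t = torch.empty(1 << 20, dtype=torch.uint8)  # ordinary pageable CPU buffer
cudart = torch.cuda.cudart()
# register with cudaHostRegisterMapped (0x02), mirroring the in-place pin path
assert cudart.cudaHostRegister(t.data_ptr(), t.numel(), 0x02) == 0

get_flags = ctypes.CDLL(None).cudaHostGetFlags
get_flags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
get_flags.restype = ctypes.c_int
flags = ctypes.c_uint()
assert get_flags(ctypes.byref(flags), ctypes.c_void_p(t.data_ptr())) == 0
assert flags.value == 0x02  # cudaHostRegisterMapped

assert cudart.cudaHostUnregister(t.data_ptr()) == 0  # back to pageable memory
```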
```diff
@@ -1183,6 +1242,8 @@ class ParameterServer:
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
+        Warning: if _auto_pg is False when initializing ParameterServer, please make sure ALL ranks in the WORLD_SIZE call `update` function,
+        otherwise, it will hang.
 
         Args:
             checkpoint_name: The name of the checkpoint.
```
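A hedged sketch of that warning, assuming `update(checkpoint_name, req_func)` mirrors the internal `_update_per_bucket` signature shown further down and that `gather_metas` takes the checkpoint name, as the docstring's ordering suggests; `req_func` here is a placeholder:

```python
from checkpoint_engine.ps import ParameterServer

def req_func(socket_paths: list[tuple[str, str]]) -> None:
    ...  # placeholder: notify inference workers of the new weights

# With auto_pg=False the caller owns process-group setup, and EVERY rank in
# WORLD_SIZE must reach ps.update(...), otherwise the collectives hang.
ps = ParameterServer(auto_pg=False)
ps.gather_metas("demo-ckpt")
ps.update("demo-ckpt", req_func)
```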
```diff
@@ -1217,7 +1278,7 @@ class ParameterServer:
                 is_master=self._rank == 0,
             )
             # if ranks is None or [], it will use fully broadcast to update to all ranks
-            ranks_group = dist.new_group(ranks if ranks else None)
+            ranks_group = dist.new_group(ranks) if ranks else None
             self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
             self.store_based_barrier(manager_store)
         except Exception as e:
```
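This small change fixes a real bug: `dist.new_group(None)` still eagerly creates a fresh group spanning all ranks, whereas the "broadcast to everyone" case the comment describes should reuse the default group, i.e. `group=None`. A minimal runnable illustration (launched under torchrun; gloo backend for portability):

```python
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # torchrun supplies rank and world size

tensor = torch.zeros(4)
ranks = None  # None or [] means "fully broadcast to all ranks"
ranks_group = dist.new_group(ranks) if ranks else None  # the fixed expression
dist.broadcast(tensor, src=0, group=ranks_group)  # group=None -> default group
```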
```diff
@@ -1248,7 +1309,7 @@ class ParameterServer:
         return socket, socket_paths
 
     def _detect_bucket_size(
-        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+        self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
     ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
@@ -1367,7 +1428,7 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
-        ranks_group: dist.ProcessGroup,
+        ranks_group: dist.ProcessGroup | None,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
```
{checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.0rc0
+Version: 0.3.0rc1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
```
checkpoint_engine-0.3.0rc1.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,10 @@
+checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+checkpoint_engine/_version.py,sha256=Ctme-brbITV9k9eCj361Q_klPsndHOTci7ZqCb_3Wk8,714
+checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+checkpoint_engine/ps.py,sha256=xGoiy4bfRl_USj9ws9g7yUos0Gw513oouV0QbChQ3rk,70668
+checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
+checkpoint_engine-0.3.0rc1.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.3.0rc1.dist-info/METADATA,sha256=1KjhSfes8NyRV7mF6bLmr1uGgNDqTUakS3QduK95OJY,11562
+checkpoint_engine-0.3.0rc1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+checkpoint_engine-0.3.0rc1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.3.0rc1.dist-info/RECORD,,
```
checkpoint_engine-0.3.0rc0.dist-info/RECORD
REMOVED
```diff
@@ -1,10 +0,0 @@
-checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
-checkpoint_engine/_version.py,sha256=v0iyeXv9HxMc4JmYu_bJTIGKXRQVfpijACyjq2P_sk0,714
-checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
-checkpoint_engine/ps.py,sha256=eIvg_eI7HMedacoQQer62NRnGDjANtxsHVxgM93ccXQ,66977
-checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
-checkpoint_engine-0.3.0rc0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.3.0rc0.dist-info/METADATA,sha256=iVd2qPdNyTPPX3XIEiuM0ASk8As72zSGfFIYicpZG3E,11562
-checkpoint_engine-0.3.0rc0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.3.0rc0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.3.0rc0.dist-info/RECORD,,
```
{checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/WHEEL
RENAMED
File without changes

{checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/licenses/LICENCE
RENAMED
File without changes

{checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.0rc1.dist-info}/top_level.txt
RENAMED
File without changes