checkpoint-engine 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/pin_memory.py +9 -1
- checkpoint_engine/ps.py +12 -2
- checkpoint_engine/worker.py +28 -8
- {checkpoint_engine-0.3.2.dist-info → checkpoint_engine-0.3.4.dist-info}/METADATA +1 -1
- checkpoint_engine-0.3.4.dist-info/RECORD +15 -0
- {checkpoint_engine-0.3.2.dist-info → checkpoint_engine-0.3.4.dist-info}/WHEEL +1 -1
- checkpoint_engine-0.3.2.dist-info/RECORD +0 -15
- {checkpoint_engine-0.3.2.dist-info → checkpoint_engine-0.3.4.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.2.dist-info → checkpoint_engine-0.3.4.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.3.2'
-__version_tuple__ = version_tuple = (0, 3, 2)
+__version__ = version = '0.3.4'
+__version_tuple__ = version_tuple = (0, 3, 4)

 __commit_id__ = commit_id = None
checkpoint_engine/pin_memory.py
CHANGED
@@ -209,7 +209,9 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
     torch.cuda.set_device(device_index)
     cudart = torch.cuda.cudart()
     r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
-
+    if r != 0:
+        error_msg = cudart.cudaGetErrorString(r)
+        raise RuntimeError(f"pin memory error, error code: {r}, error message: {error_msg}")

     # TODO: should only support /dev/shm? but we found files in disk also work?
     size = os.stat(file_path).st_size
@@ -254,6 +256,12 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
     # Remove the file after successfully loading. This will avoid doubling the memory usage.
     # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
     os.remove(file_path)
+    if not metas:
+        # TODO: should we still return this buffer?
+        assert buffer.nbytes == 0, f"buffer nbytes {buffer.nbytes} should be 0"
+        logger.warning(f"[rank{rank}] no metas found in {file_path}, skip pin memory")
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=[], manually_pinned=False)
+
     _pin(buffer)
     logger.info(
         f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
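The first hunk adds an explicit check of the `cudaHostRegister` return code instead of letting a failed registration pass silently, and the second hunk short-circuits when a checkpoint file carries no tensor metadata, returning an empty, unpinned `MemoryBuffer`. A minimal sketch of the same return-code pattern, using a hypothetical `register_pinned` helper (not part of checkpoint-engine's API):

```python
# Sketch only: mirrors the error handling added above.
# Assumes a CUDA build of PyTorch and a CPU tensor backed by pageable memory.
import torch

def register_pinned(t: torch.Tensor) -> None:
    """Page-lock (pin) the host memory backing `t` via cudaHostRegister."""
    cudart = torch.cuda.cudart()
    r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
    if r != 0:
        # Raise instead of continuing with memory that is silently left unpinned.
        error_msg = cudart.cudaGetErrorString(r)
        raise RuntimeError(f"pin memory error, error code: {r}, error message: {error_msg}")
```

The ps.py change below applies the same pattern to the matching `cudaHostUnregister` call.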
checkpoint_engine/ps.py
CHANGED
@@ -391,7 +391,11 @@ class ParameterServer:
         )
         cudart = torch.cuda.cudart()
         r = cudart.cudaHostUnregister(t.data_ptr())
-
+        if r != 0:
+            error_msg = cudart.cudaGetErrorString(r)
+            raise RuntimeError(
+                f"unpin memory error, error code: {r}, error message: {error_msg}"
+            )

         # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
         try:
@@ -407,7 +411,13 @@
         del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
-        torch._C._host_emptyCache()
+        if self.device_manager.device_type == "cuda":
+            torch._C._host_emptyCache()
+        else:
+            # torch._C._host_emptyCache() is not supported on NPU, so we call gc.collect() to empty host cache.
+            import gc
+
+            gc.collect()

     def gather_metas(self, checkpoint_name: str):
         """
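Both hunks harden the checkpoint-unregister path: the first raises on a non-zero `cudaHostUnregister` status, and the second calls the CUDA-specific `torch._C._host_emptyCache()` only when the device type is `"cuda"`, falling back to a plain `gc.collect()` on other backends such as NPU. A minimal sketch of that branch as a standalone function (the `empty_host_cache` name is illustrative; the real code lives inside `ParameterServer` and reads `self.device_manager.device_type`):

```python
# Sketch only: device-aware host-cache cleanup after a checkpoint is dropped.
import gc
import torch

def empty_host_cache(device_type: str) -> None:
    """Free cached pinned-host allocations once a checkpoint's buffers are released."""
    if device_type == "cuda":
        # Private PyTorch hook (torch>=2.5); see the Module.cpp link in the diff above.
        torch._C._host_emptyCache()
    else:
        # Not available on NPU builds, so fall back to a garbage-collection pass.
        gc.collect()
```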
checkpoint_engine/worker.py
CHANGED
@@ -10,6 +10,9 @@ import zmq
 from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid


+_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
+
+
 def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
     func, args = handle
     list_args = list(args)
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
     offset: int


-def _extract_weights(
-    payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
-) -> list[tuple[str, torch.Tensor]]:
+def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
     assert buffer is not None
-    weights: list[tuple[str, torch.Tensor]] = []
+    weights: _WEIGHTS_TYPE = []
     for item in payload:
         shape = item["shape"]
         if isinstance(shape, list | tuple):
@@ -166,12 +167,31 @@ class VllmColocateWorkerExtension:
             self.device = torch.device(f"npu:{self.local_rank}")
         assert self.device is not None

+        def _load_weights(weights: _WEIGHTS_TYPE):
+            # Load main model weights
+            self.model_runner.model.load_weights(weights)
+            # Load drafter model weights if MTP/speculative decoding is enabled
+            if (
+                getattr(self.model_runner, "drafter", None) is not None
+                and getattr(self.model_runner.drafter, "model", None) is not None
+            ):
+                self.model_runner.drafter.model.load_weights(weights=weights)
+
+        def _post_hook():
+            process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
+            # Also trigger drafter model's post processing if MTP is enabled
+            if (
+                getattr(self.model_runner, "drafter", None) is not None
+                and getattr(self.model_runner.drafter, "model", None) is not None
+            ):
+                process_weights_after_loading(
+                    self.model_runner.drafter.model, self.model_config, self.device
+                )
+
         update_weights_from_ipc(
             self._zmq_ctx,
             zmq_handles[self._device_uuid],
             device_id=self.device.index,
-            run=lambda weights: self.model_runner.model.load_weights(weights),
-            post_hook=lambda: process_weights_after_loading(
-                self.model_runner.model, self.model_config, self.device
-            ),
+            run=_load_weights,
+            post_hook=_post_hook,
         )
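The worker change promotes the former inline callbacks into named `_load_weights` / `_post_hook` closures so that a speculative-decoding drafter model (MTP), when present on the model runner, receives the same weight load and post-processing as the main model. A self-contained sketch of that dispatch shape; the `_Model`, `_Drafter`, and `_ModelRunner` stubs are illustrative stand-ins, not vLLM's actual classes:

```python
from typing import Any, Callable

Weights = list[tuple[str, Any]]  # stands in for list[tuple[str, torch.Tensor]]

class _Model:
    def load_weights(self, weights: Weights) -> None:
        print(f"loaded {len(weights)} tensors")

class _Drafter:
    def __init__(self) -> None:
        self.model = _Model()

class _ModelRunner:
    def __init__(self, with_drafter: bool) -> None:
        self.model = _Model()
        # drafter is only present when MTP / speculative decoding is enabled
        self.drafter = _Drafter() if with_drafter else None

def make_run_callback(model_runner: _ModelRunner) -> Callable[[Weights], None]:
    """Build a `run=` callback that updates the main model and, if set, the drafter."""
    def _load_weights(weights: Weights) -> None:
        model_runner.model.load_weights(weights)
        drafter = getattr(model_runner, "drafter", None)
        if drafter is not None and getattr(drafter, "model", None) is not None:
            drafter.model.load_weights(weights)
    return _load_weights

run = make_run_callback(_ModelRunner(with_drafter=True))
run([("layers.0.weight", object())])  # loads into both the main model and the drafter
```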
{checkpoint_engine-0.3.2.dist-info → checkpoint_engine-0.3.4.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.2
+Version: 0.3.4
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
checkpoint_engine-0.3.4.dist-info/RECORD
ADDED
@@ -0,0 +1,15 @@
+checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
+checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
+checkpoint_engine/_version.py,sha256=3nDaC5e0d_scBB1bUEKPlItbvbY0PmXNNyyOTNFNWNI,704
+checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
+checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
+checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
+checkpoint_engine/pin_memory.py,sha256=b7nABKJV2bSIsOfX2YTHzUk1OkOze6AQjCaOIFaQnbA,16708
+checkpoint_engine/ps.py,sha256=wBsHu2qWy5oRBrvLc7aEOroG_j58UJoWT6lFH4ylMRk,41092
+checkpoint_engine/worker.py,sha256=CDWbxwvMpid19yriuwAsyZLUZtqfkh9Lybn8KpiuKCw,7781
+checkpoint_engine-0.3.4.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.3.4.dist-info/METADATA,sha256=P23Txz8z5WvM3km3EHFtKBEc5299c5UZcd0UTABN-u8,11559
+checkpoint_engine-0.3.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+checkpoint_engine-0.3.4.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.3.4.dist-info/RECORD,,
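Each RECORD row follows the standard wheel layout `path,sha256=<digest>,<size in bytes>`, where the digest is an unpadded URL-safe base64 SHA-256 of the file (generic wheel convention per PEP 376/427, not specific to checkpoint-engine). A small sketch for recomputing that token when verifying an entry:

```python
import base64
import hashlib
from pathlib import Path

def record_hash(path: str) -> str:
    """Return the 'sha256=...' token a wheel RECORD stores for this file."""
    digest = hashlib.sha256(Path(path).read_bytes()).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

# e.g. record_hash("checkpoint_engine/_version.py") should match the row above
```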
checkpoint_engine-0.3.2.dist-info/RECORD
REMOVED
@@ -1,15 +0,0 @@
-checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
-checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
-checkpoint_engine/_version.py,sha256=e8NqPtZ8fggRgk3GPrqZ_U_BDV8aSULw1u_Gn9NNbnk,704
-checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
-checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
-checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
-checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
-checkpoint_engine/pin_memory.py,sha256=9XgE3Tn4XrEjXvA-XG70OgErDmlBU-cUVDP8ysB_9us,16237
-checkpoint_engine/ps.py,sha256=IJiA2zvZucFzFvnaLCYJMK7FHl2M2Z-g1tlDeoeZ-Rs,40689
-checkpoint_engine/worker.py,sha256=ghj9d2u8hY_U2uiOZWIN2CqRNZH6PrzujT22fHUFBWI,6879
-checkpoint_engine-0.3.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.3.2.dist-info/METADATA,sha256=a2BEqlP0yca80Djg9WZD3IWj0DLPv9hfk6j1pgnZiR0,11559
-checkpoint_engine-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.3.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.3.2.dist-info/RECORD,,
{checkpoint_engine-0.3.2.dist-info → checkpoint_engine-0.3.4.dist-info}/licenses/LICENCE
File without changes

{checkpoint_engine-0.3.2.dist-info → checkpoint_engine-0.3.4.dist-info}/top_level.txt
File without changes