checkpoint-engine 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/worker.py +28 -8
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.3.4.dist-info}/METADATA +1 -1
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.3.4.dist-info}/RECORD +7 -7
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.3.4.dist-info}/WHEEL +1 -1
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.3.4.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.3.dist-info → checkpoint_engine-0.3.4.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.3.3'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 3, 3)
|
|
31
|
+
__version__ = version = '0.3.4'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 4)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
checkpoint_engine/worker.py
CHANGED
|
@@ -10,6 +10,9 @@ import zmq
|
|
|
10
10
|
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
|
|
11
11
|
|
|
12
12
|
|
|
13
|
+
_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
|
|
14
|
+
|
|
15
|
+
|
|
13
16
|
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
|
|
14
17
|
func, args = handle
|
|
15
18
|
list_args = list(args)
|
|
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
|
|
|
29
32
|
offset: int
|
|
30
33
|
|
|
31
34
|
|
|
32
|
-
def _extract_weights(
|
|
33
|
-
payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
|
|
34
|
-
) -> list[tuple[str, torch.Tensor]]:
|
|
35
|
+
def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
|
|
35
36
|
assert buffer is not None
|
|
36
|
-
weights: list[tuple[str, torch.Tensor]] = []
|
|
37
|
+
weights: _WEIGHTS_TYPE = []
|
|
37
38
|
for item in payload:
|
|
38
39
|
shape = item["shape"]
|
|
39
40
|
if isinstance(shape, list | tuple):
|
|
@@ -166,12 +167,31 @@ class VllmColocateWorkerExtension:
|
|
|
166
167
|
self.device = torch.device(f"npu:{self.local_rank}")
|
|
167
168
|
assert self.device is not None
|
|
168
169
|
|
|
170
|
+
def _load_weights(weights: _WEIGHTS_TYPE):
|
|
171
|
+
# Load main model weights
|
|
172
|
+
self.model_runner.model.load_weights(weights)
|
|
173
|
+
# Load drafter model weights if MTP/speculative decoding is enabled
|
|
174
|
+
if (
|
|
175
|
+
getattr(self.model_runner, "drafter", None) is not None
|
|
176
|
+
and getattr(self.model_runner.drafter, "model", None) is not None
|
|
177
|
+
):
|
|
178
|
+
self.model_runner.drafter.model.load_weights(weights=weights)
|
|
179
|
+
|
|
180
|
+
def _post_hook():
|
|
181
|
+
process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
|
|
182
|
+
# Also trigger drafter model's post processing if MTP is enabled
|
|
183
|
+
if (
|
|
184
|
+
getattr(self.model_runner, "drafter", None) is not None
|
|
185
|
+
and getattr(self.model_runner.drafter, "model", None) is not None
|
|
186
|
+
):
|
|
187
|
+
process_weights_after_loading(
|
|
188
|
+
self.model_runner.drafter.model, self.model_config, self.device
|
|
189
|
+
)
|
|
190
|
+
|
|
169
191
|
update_weights_from_ipc(
|
|
170
192
|
self._zmq_ctx,
|
|
171
193
|
zmq_handles[self._device_uuid],
|
|
172
194
|
device_id=self.device.index,
|
|
173
|
-
run=self.model_runner.model.load_weights,
|
|
174
|
-
post_hook=lambda: process_weights_after_loading(
|
|
175
|
-
self.model_runner.model, self.model_config, self.device
|
|
176
|
-
),
|
|
195
|
+
run=_load_weights,
|
|
196
|
+
post_hook=_post_hook,
|
|
177
197
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: checkpoint-engine
|
|
3
|
-
Version: 0.3.3
|
|
3
|
+
Version: 0.3.4
|
|
4
4
|
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
5
|
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
6
|
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
|
|
2
2
|
checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
|
|
3
|
-
checkpoint_engine/_version.py,sha256=
|
|
3
|
+
checkpoint_engine/_version.py,sha256=3nDaC5e0d_scBB1bUEKPlItbvbY0PmXNNyyOTNFNWNI,704
|
|
4
4
|
checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
|
|
5
5
|
checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
|
|
6
6
|
checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
|
|
7
7
|
checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
|
|
8
8
|
checkpoint_engine/pin_memory.py,sha256=b7nABKJV2bSIsOfX2YTHzUk1OkOze6AQjCaOIFaQnbA,16708
|
|
9
9
|
checkpoint_engine/ps.py,sha256=wBsHu2qWy5oRBrvLc7aEOroG_j58UJoWT6lFH4ylMRk,41092
|
|
10
|
-
checkpoint_engine/worker.py,sha256=
|
|
11
|
-
checkpoint_engine-0.3.
|
|
12
|
-
checkpoint_engine-0.3.
|
|
13
|
-
checkpoint_engine-0.3.
|
|
14
|
-
checkpoint_engine-0.3.
|
|
15
|
-
checkpoint_engine-0.3.
|
|
10
|
+
checkpoint_engine/worker.py,sha256=CDWbxwvMpid19yriuwAsyZLUZtqfkh9Lybn8KpiuKCw,7781
|
|
11
|
+
checkpoint_engine-0.3.4.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
|
|
12
|
+
checkpoint_engine-0.3.4.dist-info/METADATA,sha256=P23Txz8z5WvM3km3EHFtKBEc5299c5UZcd0UTABN-u8,11559
|
|
13
|
+
checkpoint_engine-0.3.4.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
14
|
+
checkpoint_engine-0.3.4.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
15
|
+
checkpoint_engine-0.3.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|