checkpoint-engine 0.3.3__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/_version.py +3 -3
  3. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/worker.py +28 -8
  4. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  5. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/.github/workflows/cpu-tests.yml +0 -0
  6. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/.github/workflows/pre-commit.yaml +0 -0
  7. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/.github/workflows/python-publish.yml +0 -0
  8. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/.gitignore +0 -0
  9. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/.pre-commit-config.yaml +0 -0
  10. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/LICENCE +0 -0
  11. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/README.md +0 -0
  12. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/__init__.py +0 -0
  13. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/__main__.py +0 -0
  14. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/api.py +0 -0
  15. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/data_types.py +0 -0
  16. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/device_utils.py +0 -0
  17. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/p2p_store.py +0 -0
  18. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/pin_memory.py +0 -0
  19. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine/ps.py +0 -0
  20. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/SOURCES.txt +0 -0
  21. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  22. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/requires.txt +0 -0
  23. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/top_level.txt +0 -0
  24. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/docs/npu_start.md +0 -0
  25. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/examples/update.py +0 -0
  26. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/figures/checkpoint-engine.png +0 -0
  27. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/figures/overlap-update-and-copy.png +0 -0
  28. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/figures/pipeline.png +0 -0
  29. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/patches/vllm_fp8.patch +0 -0
  30. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/pyproject.toml +0 -0
  31. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/setup.cfg +0 -0
  32. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/tests/test_assign_receiver_ranks.py +0 -0
  33. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/tests/test_inplace_unpin.py +0 -0
  34. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/tests/test_rdma_parser.py +0 -0
  35. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/tests/test_reuse_pin_memory.py +0 -0
  36. {checkpoint_engine-0.3.3 → checkpoint_engine-0.3.4}/tests/test_update.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.3'
32
- __version_tuple__ = version_tuple = (0, 3, 3)
31
+ __version__ = version = '0.3.4'
32
+ __version_tuple__ = version_tuple = (0, 3, 4)
33
33
 
34
- __commit_id__ = commit_id = 'gf6910d646'
34
+ __commit_id__ = commit_id = 'g15446dd22'
@@ -10,6 +10,9 @@ import zmq
10
10
  from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
11
11
 
12
12
 
13
+ _WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
14
+
15
+
13
16
  def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
14
17
  func, args = handle
15
18
  list_args = list(args)
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
29
32
  offset: int
30
33
 
31
34
 
32
- def _extract_weights(
33
- payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
34
- ) -> list[tuple[str, torch.Tensor]]:
35
+ def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
35
36
  assert buffer is not None
36
- weights: list[tuple[str, torch.Tensor]] = []
37
+ weights: _WEIGHTS_TYPE = []
37
38
  for item in payload:
38
39
  shape = item["shape"]
39
40
  if isinstance(shape, list | tuple):
@@ -166,12 +167,31 @@ class VllmColocateWorkerExtension:
166
167
  self.device = torch.device(f"npu:{self.local_rank}")
167
168
  assert self.device is not None
168
169
 
170
+ def _load_weights(weights: _WEIGHTS_TYPE):
171
+ # Load main model weights
172
+ self.model_runner.model.load_weights(weights)
173
+ # Load drafter model weights if MTP/speculative decoding is enabled
174
+ if (
175
+ getattr(self.model_runner, "drafter", None) is not None
176
+ and getattr(self.model_runner.drafter, "model", None) is not None
177
+ ):
178
+ self.model_runner.drafter.model.load_weights(weights=weights)
179
+
180
+ def _post_hook():
181
+ process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
182
+ # Also trigger drafter model's post processing if MTP is enabled
183
+ if (
184
+ getattr(self.model_runner, "drafter", None) is not None
185
+ and getattr(self.model_runner.drafter, "model", None) is not None
186
+ ):
187
+ process_weights_after_loading(
188
+ self.model_runner.drafter.model, self.model_config, self.device
189
+ )
190
+
169
191
  update_weights_from_ipc(
170
192
  self._zmq_ctx,
171
193
  zmq_handles[self._device_uuid],
172
194
  device_id=self.device.index,
173
- run=self.model_runner.model.load_weights,
174
- post_hook=lambda: process_weights_after_loading(
175
- self.model_runner.model, self.model_config, self.device
176
- ),
195
+ run=_load_weights,
196
+ post_hook=_post_hook,
177
197
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine