checkpoint-engine 0.3.2.tar.gz → 0.3.4.tar.gz

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (36)
  1. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/_version.py +3 -3
  3. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/pin_memory.py +9 -1
  4. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/ps.py +12 -2
  5. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/worker.py +28 -8
  6. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  7. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/.github/workflows/cpu-tests.yml +0 -0
  8. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/.github/workflows/pre-commit.yaml +0 -0
  9. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/.github/workflows/python-publish.yml +0 -0
  10. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/.gitignore +0 -0
  11. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/.pre-commit-config.yaml +0 -0
  12. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/LICENCE +0 -0
  13. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/README.md +0 -0
  14. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/__init__.py +0 -0
  15. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/__main__.py +0 -0
  16. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/api.py +0 -0
  17. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/data_types.py +0 -0
  18. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/device_utils.py +0 -0
  19. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine/p2p_store.py +0 -0
  20. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/SOURCES.txt +0 -0
  21. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  22. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/requires.txt +0 -0
  23. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/checkpoint_engine.egg-info/top_level.txt +0 -0
  24. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/docs/npu_start.md +0 -0
  25. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/examples/update.py +0 -0
  26. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/figures/checkpoint-engine.png +0 -0
  27. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/figures/overlap-update-and-copy.png +0 -0
  28. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/figures/pipeline.png +0 -0
  29. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/patches/vllm_fp8.patch +0 -0
  30. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/pyproject.toml +0 -0
  31. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/setup.cfg +0 -0
  32. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/tests/test_assign_receiver_ranks.py +0 -0
  33. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/tests/test_inplace_unpin.py +0 -0
  34. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/tests/test_rdma_parser.py +0 -0
  35. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/tests/test_reuse_pin_memory.py +0 -0
  36. {checkpoint_engine-0.3.2 → checkpoint_engine-0.3.4}/tests/test_update.py +0 -0
--- checkpoint_engine-0.3.2/PKG-INFO
+++ checkpoint_engine-0.3.4/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.2
+Version: 0.3.4
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
--- checkpoint_engine-0.3.2/checkpoint_engine/_version.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.2'
-__version_tuple__ = version_tuple = (0, 3, 2)
+__version__ = version = '0.3.4'
+__version_tuple__ = version_tuple = (0, 3, 4)
 
-__commit_id__ = commit_id = 'g4a73109a3'
+__commit_id__ = commit_id = 'g15446dd22'
--- checkpoint_engine-0.3.2/checkpoint_engine/pin_memory.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/pin_memory.py
@@ -209,7 +209,9 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[Memor
         torch.cuda.set_device(device_index)
         cudart = torch.cuda.cudart()
         r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
-        assert r == 0, f"pin memory error, error code: {r}"
+        if r != 0:
+            error_msg = cudart.cudaGetErrorString(r)
+            raise RuntimeError(f"pin memory error, error code: {r}, error message: {error_msg}")
 
     # TODO: should only support /dev/shm? but we found files in disk also work?
     size = os.stat(file_path).st_size
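
Both this hunk and the matching unpin hunk in ps.py below swap a bare assert for a RuntimeError that carries the driver's message from cudaGetErrorString, which also keeps the check alive under python -O (asserts are stripped there). A minimal, self-contained sketch of the same register/unregister round trip, using only the torch.cuda.cudart() bindings visible in the diff; the pin_tensor helper name is ours, not the package's, and a CUDA-capable runtime is assumed:

import torch

def pin_tensor(t: torch.Tensor) -> None:
    # Page-lock an existing CPU tensor's allocation in place.
    # Flag 0 is cudaHostRegisterDefault.
    cudart = torch.cuda.cudart()
    r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
    if r != 0:
        # Same pattern as the hunk above: resolve the numeric CUDA error
        # code to a readable message instead of asserting.
        error_msg = cudart.cudaGetErrorString(r)
        raise RuntimeError(f"pin memory error, error code: {r}, error message: {error_msg}")

t = torch.empty(1 << 20, dtype=torch.uint8)  # ordinary pageable CPU tensor
pin_tensor(t)                                 # now page-locked (pinned)
torch.cuda.cudart().cudaHostUnregister(t.data_ptr())  # unpin before release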
--- checkpoint_engine-0.3.2/checkpoint_engine/pin_memory.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/pin_memory.py
@@ -254,6 +256,12 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[Memor
     # Remove the file after successfully loading. This will avoid doubling the memory usage.
     # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
     os.remove(file_path)
+    if not metas:
+        # TODO: should we still return this buffer?
+        assert buffer.nbytes == 0, f"buffer nbytes {buffer.nbytes} should be 0"
+        logger.warning(f"[rank{rank}] no metas found in {file_path}, skip pin memory")
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=[], manually_pinned=False)
+
     _pin(buffer)
     logger.info(
         f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
--- checkpoint_engine-0.3.2/checkpoint_engine/ps.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/ps.py
@@ -391,7 +391,11 @@ class ParameterServer:
            )
            cudart = torch.cuda.cudart()
            r = cudart.cudaHostUnregister(t.data_ptr())
-           assert r == 0, f"unpin memory error, error code: {r}"
+           if r != 0:
+               error_msg = cudart.cudaGetErrorString(r)
+               raise RuntimeError(
+                   f"unpin memory error, error code: {r}, error message: {error_msg}"
+               )
 
        # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
        try:
--- checkpoint_engine-0.3.2/checkpoint_engine/ps.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/ps.py
@@ -407,7 +411,13 @@ class ParameterServer:
        del self._memory_pool[checkpoint_name]
        # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
        # this works by using torch>=2.5.0
-       torch._C._host_emptyCache()
+       if self.device_manager.device_type == "cuda":
+           torch._C._host_emptyCache()
+       else:
+           # torch._C._host_emptyCache() is not supported on NPU, so we call gc.collect() to empty host cache.
+           import gc
+
+           gc.collect()
 
    def gather_metas(self, checkpoint_name: str):
        """
--- checkpoint_engine-0.3.2/checkpoint_engine/worker.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/worker.py
@@ -10,6 +10,9 @@ import zmq
 from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
 
 
+_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]
+
+
 def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
     func, args = handle
     list_args = list(args)
--- checkpoint_engine-0.3.2/checkpoint_engine/worker.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/worker.py
@@ -29,11 +32,9 @@ class FlattenedTensorMetadata(TypedDict):
     offset: int
 
 
-def _extract_weights(
-    payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
-) -> list[tuple[str, torch.Tensor]]:
+def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
     assert buffer is not None
-    weights: list[tuple[str, torch.Tensor]] = []
+    weights: _WEIGHTS_TYPE = []
     for item in payload:
         shape = item["shape"]
         if isinstance(shape, list | tuple):
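
The hunk above is a pure refactor onto the new _WEIGHTS_TYPE alias. For orientation, a hedged sketch of what a function with this signature does, rebuilding named tensors from a flat byte buffer: the offset and shape fields come from the FlattenedTensorMetadata TypedDict shown above, while the name and dtype fields and the slicing logic are our assumptions, not worker.py's actual code:

import torch

_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]

def extract_weights(payload: list[dict], buffer: torch.Tensor) -> _WEIGHTS_TYPE:
    # buffer is a flat uint8 tensor; each payload entry describes the
    # byte range of one tensor inside it.
    weights: _WEIGHTS_TYPE = []
    for item in payload:
        shape = torch.Size(item["shape"])
        nbytes = shape.numel() * item["dtype"].itemsize
        chunk = buffer[item["offset"] : item["offset"] + nbytes]
        weights.append((item["name"], chunk.view(item["dtype"]).view(shape)))
    return weights

flat = torch.arange(6, dtype=torch.float32).view(torch.uint8)  # 24 bytes
meta = [{"name": "w", "shape": (2, 3), "dtype": torch.float32, "offset": 0}]
print(extract_weights(meta, flat))  # [('w', 2x3 float32 view of the buffer)]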
--- checkpoint_engine-0.3.2/checkpoint_engine/worker.py
+++ checkpoint_engine-0.3.4/checkpoint_engine/worker.py
@@ -166,12 +167,31 @@ class VllmColocateWorkerExtension:
            self.device = torch.device(f"npu:{self.local_rank}")
        assert self.device is not None
 
+       def _load_weights(weights: _WEIGHTS_TYPE):
+           # Load main model weights
+           self.model_runner.model.load_weights(weights)
+           # Load drafter model weights if MTP/speculative decoding is enabled
+           if (
+               getattr(self.model_runner, "drafter", None) is not None
+               and getattr(self.model_runner.drafter, "model", None) is not None
+           ):
+               self.model_runner.drafter.model.load_weights(weights=weights)
+
+       def _post_hook():
+           process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
+           # Also trigger drafter model's post processing if MTP is enabled
+           if (
+               getattr(self.model_runner, "drafter", None) is not None
+               and getattr(self.model_runner.drafter, "model", None) is not None
+           ):
+               process_weights_after_loading(
+                   self.model_runner.drafter.model, self.model_config, self.device
+               )
+
        update_weights_from_ipc(
            self._zmq_ctx,
            zmq_handles[self._device_uuid],
            device_id=self.device.index,
-           run=self.model_runner.model.load_weights,
-           post_hook=lambda: process_weights_after_loading(
-               self.model_runner.model, self.model_config, self.device
-           ),
+           run=_load_weights,
+           post_hook=_post_hook,
        )
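
This is the substantive change in 0.3.4: rather than handing model.load_weights straight to update_weights_from_ipc, the worker wraps it so that a drafter model used for MTP/speculative decoding, when one is attached, receives the same weights and the same post-load processing. The wrapping pattern in isolation, with hypothetical stand-ins (EchoModel, make_run, and the SimpleNamespace runner are ours; the drafter attribute convention mirrors the hunk above):

from types import SimpleNamespace

import torch

_WEIGHTS_TYPE = list[tuple[str, torch.Tensor]]

class EchoModel:
    # Stand-in model that just records which weights it received.
    def __init__(self, tag: str) -> None:
        self.tag = tag

    def load_weights(self, weights: _WEIGHTS_TYPE) -> None:
        print(self.tag, "<-", [name for name, _ in weights])

def make_run(model_runner):
    # Same guard as the hunk above: touch the drafter only when both
    # model_runner.drafter and model_runner.drafter.model exist.
    def _load_weights(weights: _WEIGHTS_TYPE) -> None:
        model_runner.model.load_weights(weights)
        drafter = getattr(model_runner, "drafter", None)
        if drafter is not None and getattr(drafter, "model", None) is not None:
            drafter.model.load_weights(weights=weights)

    return _load_weights

runner = SimpleNamespace(
    model=EchoModel("main"),
    drafter=SimpleNamespace(model=EchoModel("drafter")),
)
make_run(runner)([("w", torch.zeros(2))])  # both models receive "w"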
--- checkpoint_engine-0.3.2/checkpoint_engine.egg-info/PKG-INFO
+++ checkpoint_engine-0.3.4/checkpoint_engine.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.3.2
+Version: 0.3.4
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine