checkpoint-engine 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine/_version.py +3 -3
  3. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine/ps.py +27 -12
  4. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  5. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/.github/workflows/pre-commit.yaml +0 -0
  6. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/.github/workflows/python-publish.yml +0 -0
  7. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/.gitignore +0 -0
  8. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/.pre-commit-config.yaml +0 -0
  9. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/LICENCE +0 -0
  10. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/README.md +0 -0
  11. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine/__init__.py +0 -0
  12. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine/worker.py +0 -0
  13. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine.egg-info/SOURCES.txt +0 -0
  14. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  15. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine.egg-info/requires.txt +0 -0
  16. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/checkpoint_engine.egg-info/top_level.txt +0 -0
  17. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/examples/update.py +0 -0
  18. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/figures/checkpoint-engine.png +0 -0
  19. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/figures/overlap-update-and-copy.png +0 -0
  20. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/figures/pipeline.png +0 -0
  21. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/patches/vllm_fp8.patch +0 -0
  22. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/pyproject.toml +0 -0
  23. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/setup.cfg +0 -0
  24. {checkpoint_engine-0.1.2 → checkpoint_engine-0.1.3}/tests/test_update.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.2'
32
- __version_tuple__ = version_tuple = (0, 1, 2)
31
+ __version__ = version = '0.1.3'
32
+ __version_tuple__ = version_tuple = (0, 1, 3)
33
33
 
34
- __commit_id__ = commit_id = 'g716c0dad9'
34
+ __commit_id__ = commit_id = 'g8a60e65ba'
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
1
  import argparse
4
2
  import concurrent.futures
5
3
  import ctypes
@@ -10,6 +8,7 @@ import socket
10
8
  import threading
11
9
  import time
12
10
  from collections import defaultdict
11
+ from collections.abc import Callable
13
12
  from datetime import timedelta
14
13
  from functools import lru_cache
15
14
  from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
@@ -26,8 +25,6 @@ from torch.multiprocessing.reductions import reduce_tensor
26
25
 
27
26
 
28
27
  if TYPE_CHECKING:
29
- from collections.abc import Callable
30
-
31
28
  from typing_extensions import TypedDict
32
29
 
33
30
  class FileMeta(TypedDict):
@@ -151,8 +148,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
151
148
  return ret
152
149
 
153
150
 
154
- def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
155
- def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
151
+ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
152
+ def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
156
153
  ret = {}
157
154
  with safe_open(fn, framework="pt") as f:
158
155
  for name in f.keys(): # noqa: SIM118
@@ -168,7 +165,7 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
168
165
  return ret
169
166
 
170
167
  # deprecated, will be removed in the future
171
- def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
168
+ def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
172
169
  """load *.np file and return memmap and related tensor meta"""
173
170
 
174
171
  def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
@@ -595,7 +592,13 @@ class P2PStore:
595
592
 
596
593
  class ParameterServer:
597
594
  def __init__(
598
- self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
595
+ self,
596
+ *,
597
+ rank: int | None = None,
598
+ world_size: int | None = None,
599
+ auto_pg: bool = False,
600
+ gpu_count: int | None = None,
601
+ mem_fraction: float | None = None,
599
602
  ):
600
603
  """
601
604
  Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -603,17 +606,27 @@ class ParameterServer:
603
606
  Args:
604
607
  auto_pg: Whether to automatically initialize the process group.
605
608
  Notice that if auto_pg is True, will destroy the process group after update.
609
+ mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
606
610
  """
607
611
  self._rank = rank or int(os.environ.get("RANK", None))
608
612
  self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
609
- self._gpu_count = torch.cuda.device_count()
613
+ self._gpu_count = gpu_count or torch.cuda.device_count()
610
614
  self._local_rank = self._rank % self._gpu_count
611
615
  self._auto_pg = auto_pg
612
616
  self._all_hosts = []
613
617
  self._global_device_uuids: list[str] = []
618
+ self._mem_fraction = mem_fraction or 0.9
614
619
 
615
620
  assert self._rank is not None and self._rank >= 0, self._rank
616
621
  assert self._world_size and self._world_size > 0, self._world_size
622
+ assert (
623
+ self._gpu_count is not None
624
+ and self._gpu_count > 0
625
+ and self._gpu_count <= torch.cuda.device_count()
626
+ ), self._gpu_count
627
+ assert (
628
+ self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
629
+ ), self._mem_fraction
617
630
 
618
631
  self._zmq_ctx = zmq.Context()
619
632
  self._zmq_addr_counter = 0
@@ -795,13 +808,15 @@ class ParameterServer:
795
808
  self.init_process_group()
796
809
  self._update_per_bucket(checkpoint_name, req_func)
797
810
  else:
798
- if self._rank not in ranks:
811
+ if not self._auto_pg and self._rank not in ranks:
799
812
  return
800
813
  if self._auto_pg:
801
814
  if dist.is_initialized():
802
815
  dist.destroy_process_group()
803
816
  # HACK: wait 2s to ensure destroy is finished
804
817
  time.sleep(2)
818
+ if self._rank not in ranks:
819
+ return
805
820
  self.init_process_group_for_ranks(ranks)
806
821
  self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
807
822
  if self._auto_pg:
@@ -835,8 +850,8 @@ class ParameterServer:
835
850
  # auto detect bucket size
836
851
  tensor = torch.tensor(
837
852
  [
838
- # 90% of current cuda free memory bytes
839
- int(float(torch.cuda.mem_get_info()[0]) * 0.9),
853
+ # proportion of current cuda free memory bytes
854
+ int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
840
855
  # we use negative value to reuse allreduce min operation
841
856
  # for getting the max value of zmq_addr_counter in all ranks
842
857
  -self._zmq_addr_counter,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine