checkpoint-engine 0.3.1rc0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.1rc0'
32
- __version_tuple__ = version_tuple = (0, 3, 1, 'rc0')
31
+ __version__ = version = '0.3.2'
32
+ __version_tuple__ = version_tuple = (0, 3, 2)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -191,6 +191,8 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
191
191
 
192
192
 
193
193
  def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
194
+ device_index = torch.cuda.current_device()
195
+
194
196
  def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
195
197
  """
196
198
  safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
@@ -204,6 +206,7 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[Memor
204
206
  Pin the memory of tensor in-place.
205
207
  See: https://github.com/pytorch/pytorch/issues/32167
206
208
  """
209
+ torch.cuda.set_device(device_index)
207
210
  cudart = torch.cuda.cudart()
208
211
  r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
209
212
  assert r == 0, f"pin memory error, error code: {r}"
checkpoint_engine/ps.py CHANGED
@@ -731,6 +731,7 @@ class ParameterServer:
731
731
  assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
732
732
  assert dist.is_initialized(), "process group is not initialized"
733
733
 
734
+ p2p_update = False
734
735
  # if ranks is None or [], a full broadcast is used to update all ranks
735
736
  if not ranks:
736
737
  logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -739,6 +740,7 @@ class ParameterServer:
739
740
  assert self._p2p_store is not None, "p2p store is not initialized"
740
741
  assert ranks, "ranks should be set"
741
742
 
743
+ p2p_update = True
742
744
  need_update = self._rank in ranks
743
745
  logger.info(
744
746
  f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
@@ -764,11 +766,6 @@ class ParameterServer:
764
766
  if disable_h2d_buffer
765
767
  else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
766
768
  )
767
- # p2p store need to register h2d_buffer to let other ranks read
768
- if ranks:
769
- h2d_buffer_name = "__h2d_buffer__"
770
- if h2d_buffer is not None and self._p2p_store is not None:
771
- self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
772
769
  receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
773
770
  for receiver_rank, owner_rank, bucket in buckets:
774
771
  if receiver_rank != self._rank:
@@ -778,6 +775,12 @@ class ParameterServer:
778
775
  buffer = torch.empty(
779
776
  bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
780
777
  )
778
+ if p2p_update:
779
+ # the p2p store needs to register the buffer so that other ranks can read it
780
+ p2p_ipc_buffer_name = "__ipc_buffer__"
781
+ self._p2p_store.register_named_tensors(
782
+ {p2p_ipc_buffer_name: buffer if disable_h2d_buffer else h2d_buffer}
783
+ )
781
784
  handle = reduce_tensor(buffer)
782
785
 
783
786
  buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -823,7 +826,14 @@ class ParameterServer:
823
826
  buffer_b: torch.Tensor = buffer[start : start + bucket.size]
824
827
  if receiver_rank == self._rank:
825
828
  if disable_h2d_buffer:
826
- self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
829
+ if p2p_update:
830
+ assert bucket == receiver_rank_buckets[i][1]
831
+ self._copy_to_buffer(
832
+ checkpoint_name,
833
+ bucket,
834
+ buffer_b,
835
+ receiver_rank_buckets[i][0] if p2p_update else None,
836
+ )
827
837
  else:
828
838
  buffer_b.data.copy_(h2d_buffer[: bucket.size])
829
839
  dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
@@ -850,8 +860,8 @@ class ParameterServer:
850
860
  req_thread.join()
851
861
  dist.barrier(group=ranks_group)
852
862
  socket.close()
853
- if ranks and h2d_buffer is not None:
854
- self._p2p_store.unregister_named_tensors([h2d_buffer_name])
863
+ if p2p_update:
864
+ self._p2p_store.unregister_named_tensors([p2p_ipc_buffer_name])
855
865
 
856
866
  self.device_manager.device_module.empty_cache()
857
867
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.3.1rc0
3
+ Version: 0.3.2
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -0,0 +1,15 @@
1
+ checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
2
+ checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
3
+ checkpoint_engine/_version.py,sha256=e8NqPtZ8fggRgk3GPrqZ_U_BDV8aSULw1u_Gn9NNbnk,704
4
+ checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
5
+ checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
6
+ checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
7
+ checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
8
+ checkpoint_engine/pin_memory.py,sha256=9XgE3Tn4XrEjXvA-XG70OgErDmlBU-cUVDP8ysB_9us,16237
9
+ checkpoint_engine/ps.py,sha256=IJiA2zvZucFzFvnaLCYJMK7FHl2M2Z-g1tlDeoeZ-Rs,40689
10
+ checkpoint_engine/worker.py,sha256=ghj9d2u8hY_U2uiOZWIN2CqRNZH6PrzujT22fHUFBWI,6879
11
+ checkpoint_engine-0.3.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
12
+ checkpoint_engine-0.3.2.dist-info/METADATA,sha256=a2BEqlP0yca80Djg9WZD3IWj0DLPv9hfk6j1pgnZiR0,11559
13
+ checkpoint_engine-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
14
+ checkpoint_engine-0.3.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
15
+ checkpoint_engine-0.3.2.dist-info/RECORD,,
@@ -1,15 +0,0 @@
1
- checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
2
- checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
3
- checkpoint_engine/_version.py,sha256=dHrWkv1sAsMQC6tqqCYOmWtbcT9G_s2WFJDo0_AYy_0,714
4
- checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
5
- checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
6
- checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
7
- checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
8
- checkpoint_engine/pin_memory.py,sha256=gpoe_z5XxbWkCvFLaXXpyUUFetBXUjsOrxBSX-ksZTw,16141
9
- checkpoint_engine/ps.py,sha256=0d68Sqb_y3H6b5H37exMbghDJ294VKaGqoWkcKE-Ao8,40316
10
- checkpoint_engine/worker.py,sha256=ghj9d2u8hY_U2uiOZWIN2CqRNZH6PrzujT22fHUFBWI,6879
11
- checkpoint_engine-0.3.1rc0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
12
- checkpoint_engine-0.3.1rc0.dist-info/METADATA,sha256=VCgsnIGn1CcO9-ILevego92QDldqyGn-frzs4weGIwQ,11562
13
- checkpoint_engine-0.3.1rc0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
14
- checkpoint_engine-0.3.1rc0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
15
- checkpoint_engine-0.3.1rc0.dist-info/RECORD,,