checkpoint-engine 0.3.1rc0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/pin_memory.py +3 -0
- checkpoint_engine/ps.py +18 -8
- {checkpoint_engine-0.3.1rc0.dist-info → checkpoint_engine-0.3.2.dist-info}/METADATA +1 -1
- checkpoint_engine-0.3.2.dist-info/RECORD +15 -0
- checkpoint_engine-0.3.1rc0.dist-info/RECORD +0 -15
- {checkpoint_engine-0.3.1rc0.dist-info → checkpoint_engine-0.3.2.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.3.1rc0.dist-info → checkpoint_engine-0.3.2.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.1rc0.dist-info → checkpoint_engine-0.3.2.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.3.1rc0'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 3, 1, 'rc0')
|
|
31
|
+
__version__ = version = '0.3.2'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 2)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
checkpoint_engine/pin_memory.py
CHANGED
|
@@ -191,6 +191,8 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
|
|
|
191
191
|
|
|
192
192
|
|
|
193
193
|
def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
|
|
194
|
+
device_index = torch.cuda.current_device()
|
|
195
|
+
|
|
194
196
|
def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
|
|
195
197
|
"""
|
|
196
198
|
safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
|
|
@@ -204,6 +206,7 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[Memor
|
|
|
204
206
|
Pin the memory of tensor in-place.
|
|
205
207
|
See: https://github.com/pytorch/pytorch/issues/32167
|
|
206
208
|
"""
|
|
209
|
+
torch.cuda.set_device(device_index)
|
|
207
210
|
cudart = torch.cuda.cudart()
|
|
208
211
|
r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
|
|
209
212
|
assert r == 0, f"pin memory error, error code: {r}"
|
checkpoint_engine/ps.py
CHANGED
|
@@ -731,6 +731,7 @@ class ParameterServer:
|
|
|
731
731
|
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
|
|
732
732
|
assert dist.is_initialized(), "process group is not initialized"
|
|
733
733
|
|
|
734
|
+
p2p_update = False
|
|
734
735
|
# if both ranks is None or [], it will use fully broadcast to update to all ranks
|
|
735
736
|
if not ranks:
|
|
736
737
|
logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
|
|
@@ -739,6 +740,7 @@ class ParameterServer:
|
|
|
739
740
|
assert self._p2p_store is not None, "p2p store is not initialized"
|
|
740
741
|
assert ranks, "ranks should be set"
|
|
741
742
|
|
|
743
|
+
p2p_update = True
|
|
742
744
|
need_update = self._rank in ranks
|
|
743
745
|
logger.info(
|
|
744
746
|
f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
|
|
@@ -764,11 +766,6 @@ class ParameterServer:
|
|
|
764
766
|
if disable_h2d_buffer
|
|
765
767
|
else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
|
|
766
768
|
)
|
|
767
|
-
# p2p store need to register h2d_buffer to let other ranks read
|
|
768
|
-
if ranks:
|
|
769
|
-
h2d_buffer_name = "__h2d_buffer__"
|
|
770
|
-
if h2d_buffer is not None and self._p2p_store is not None:
|
|
771
|
-
self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
|
|
772
769
|
receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
|
|
773
770
|
for receiver_rank, owner_rank, bucket in buckets:
|
|
774
771
|
if receiver_rank != self._rank:
|
|
@@ -778,6 +775,12 @@ class ParameterServer:
|
|
|
778
775
|
buffer = torch.empty(
|
|
779
776
|
bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
|
|
780
777
|
)
|
|
778
|
+
if p2p_update:
|
|
779
|
+
# p2p store need to register buffer to let other ranks read
|
|
780
|
+
p2p_ipc_buffer_name = "__ipc_buffer__"
|
|
781
|
+
self._p2p_store.register_named_tensors(
|
|
782
|
+
{p2p_ipc_buffer_name: buffer if disable_h2d_buffer else h2d_buffer}
|
|
783
|
+
)
|
|
781
784
|
handle = reduce_tensor(buffer)
|
|
782
785
|
|
|
783
786
|
buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
|
|
@@ -823,7 +826,14 @@ class ParameterServer:
|
|
|
823
826
|
buffer_b: torch.Tensor = buffer[start : start + bucket.size]
|
|
824
827
|
if receiver_rank == self._rank:
|
|
825
828
|
if disable_h2d_buffer:
|
|
826
|
-
|
|
829
|
+
if p2p_update:
|
|
830
|
+
assert bucket == receiver_rank_buckets[i][1]
|
|
831
|
+
self._copy_to_buffer(
|
|
832
|
+
checkpoint_name,
|
|
833
|
+
bucket,
|
|
834
|
+
buffer_b,
|
|
835
|
+
receiver_rank_buckets[i][0] if p2p_update else None,
|
|
836
|
+
)
|
|
827
837
|
else:
|
|
828
838
|
buffer_b.data.copy_(h2d_buffer[: bucket.size])
|
|
829
839
|
dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
|
|
@@ -850,8 +860,8 @@ class ParameterServer:
|
|
|
850
860
|
req_thread.join()
|
|
851
861
|
dist.barrier(group=ranks_group)
|
|
852
862
|
socket.close()
|
|
853
|
-
if
|
|
854
|
-
self._p2p_store.unregister_named_tensors([h2d_buffer_name])
|
|
863
|
+
if p2p_update:
|
|
864
|
+
self._p2p_store.unregister_named_tensors([p2p_ipc_buffer_name])
|
|
855
865
|
|
|
856
866
|
self.device_manager.device_module.empty_cache()
|
|
857
867
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: checkpoint-engine
|
|
3
|
-
Version: 0.3.1rc0
|
|
3
|
+
Version: 0.3.2
|
|
4
4
|
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
5
|
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
6
|
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
|
|
2
|
+
checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
|
|
3
|
+
checkpoint_engine/_version.py,sha256=e8NqPtZ8fggRgk3GPrqZ_U_BDV8aSULw1u_Gn9NNbnk,704
|
|
4
|
+
checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
|
|
5
|
+
checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
|
|
6
|
+
checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
|
|
7
|
+
checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
|
|
8
|
+
checkpoint_engine/pin_memory.py,sha256=9XgE3Tn4XrEjXvA-XG70OgErDmlBU-cUVDP8ysB_9us,16237
|
|
9
|
+
checkpoint_engine/ps.py,sha256=IJiA2zvZucFzFvnaLCYJMK7FHl2M2Z-g1tlDeoeZ-Rs,40689
|
|
10
|
+
checkpoint_engine/worker.py,sha256=ghj9d2u8hY_U2uiOZWIN2CqRNZH6PrzujT22fHUFBWI,6879
|
|
11
|
+
checkpoint_engine-0.3.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
|
|
12
|
+
checkpoint_engine-0.3.2.dist-info/METADATA,sha256=a2BEqlP0yca80Djg9WZD3IWj0DLPv9hfk6j1pgnZiR0,11559
|
|
13
|
+
checkpoint_engine-0.3.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
14
|
+
checkpoint_engine-0.3.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
15
|
+
checkpoint_engine-0.3.2.dist-info/RECORD,,
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
checkpoint_engine/__init__.py,sha256=OeWxe9mxl2sZ6cW-blSTg6JbFlOMpGbBghLZtxGOqXk,942
|
|
2
|
-
checkpoint_engine/__main__.py,sha256=yzQlApuYo6eIOqtqM018RosyxNzXzB5a-stxUvsh-dg,709
|
|
3
|
-
checkpoint_engine/_version.py,sha256=dHrWkv1sAsMQC6tqqCYOmWtbcT9G_s2WFJDo0_AYy_0,714
|
|
4
|
-
checkpoint_engine/api.py,sha256=JDiQ4i3Gb6GoaBhlp8lNuUPaVURoFFdeGJY9ZDDGvPc,3518
|
|
5
|
-
checkpoint_engine/data_types.py,sha256=O9uAXjwB20iwrOHfEEQd8Y9CmaFspNJ9ks9noHqwQKk,2716
|
|
6
|
-
checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
|
|
7
|
-
checkpoint_engine/p2p_store.py,sha256=abiCDVmRISPt9QFfavHB9Jo7ZpBbSjUS1NevGuB-AVA,8721
|
|
8
|
-
checkpoint_engine/pin_memory.py,sha256=gpoe_z5XxbWkCvFLaXXpyUUFetBXUjsOrxBSX-ksZTw,16141
|
|
9
|
-
checkpoint_engine/ps.py,sha256=0d68Sqb_y3H6b5H37exMbghDJ294VKaGqoWkcKE-Ao8,40316
|
|
10
|
-
checkpoint_engine/worker.py,sha256=ghj9d2u8hY_U2uiOZWIN2CqRNZH6PrzujT22fHUFBWI,6879
|
|
11
|
-
checkpoint_engine-0.3.1rc0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
|
|
12
|
-
checkpoint_engine-0.3.1rc0.dist-info/METADATA,sha256=VCgsnIGn1CcO9-ILevego92QDldqyGn-frzs4weGIwQ,11562
|
|
13
|
-
checkpoint_engine-0.3.1rc0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
14
|
-
checkpoint_engine-0.3.1rc0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
15
|
-
checkpoint_engine-0.3.1rc0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|