checkpoint-engine 0.1.2-py3-none-any.whl → 0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/ps.py +272 -116
- {checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/METADATA +18 -11
- checkpoint_engine-0.2.0.dist-info/RECORD +9 -0
- checkpoint_engine-0.1.2.dist-info/RECORD +0 -9
- {checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.1.2'
-__version_tuple__ = version_tuple = (0, 1, 2)
+__version__ = version = '0.2.0'
+__version_tuple__ = version_tuple = (0, 2, 0)
 
 __commit_id__ = commit_id = None
checkpoint_engine/ps.py
CHANGED

@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import argparse
 import concurrent.futures
 import ctypes
@@ -10,6 +8,7 @@ import socket
 import threading
 import time
 from collections import defaultdict
+from collections.abc import Callable
 from datetime import timedelta
 from functools import lru_cache
 from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
@@ -26,7 +25,7 @@ from torch.multiprocessing.reductions import reduce_tensor
 
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from typing import TypeVar
 
     from typing_extensions import TypedDict
 
@@ -37,6 +36,8 @@ if TYPE_CHECKING:
         type: type
         tp_concat_dim: int
 
+    T = TypeVar("T")
+
 
 def _dt_validate(value: Any) -> torch.dtype:
     if isinstance(value, str):
@@ -120,6 +121,7 @@ class MemoryBuffer(BaseModel):
 class MemoryBufferMetaList(BaseModel):
     p2p_store_addr: str | None
     memory_buffer_metas_list: list[MemoryBufferMetas]
+    rdma_device: str
 
 
 class DataToGather(MemoryBufferMetaList):
@@ -151,8 +153,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     return ret
 
 
-def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
-    def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
+def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
+    def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
         ret = {}
         with safe_open(fn, framework="pt") as f:
             for name in f.keys():  # noqa: SIM118
@@ -168,7 +170,7 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
         return ret
 
     # deprecated, will be removed in the future
-    def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
+    def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
         """load *.np file and return memmap and related tensor meta"""
 
         def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
@@ -306,14 +308,7 @@ def _get_rdma_devices() -> list[str]:
         return devices_str.split(",")
     # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
     hca = os.getenv("NCCL_IB_HCA", None)
-
-    hca_list = hca.split(",")
-    if len(hca_list) > 1:
-        # if NCCL_IB_HCA has multiple values, just return
-        return hca_list
-    else:
-        hca = hca_list[0]
-    return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
+    return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
 
 
 def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
@@ -331,6 +326,75 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
     return devices[local_rank // (gpu_count // len(devices))]
 
 
+def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
+    """
+    The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
+    The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
+
+    The list is comma-separated; port numbers are NOT supported yet.
+    An optional prefix '^' indicates the list is an exclude list.
+    A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
+    Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
+
+    Examples:
+    - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
+    - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
+    - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
+    - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
+    """
+    max_hcas = 32
+    if not value or value.strip() == "":
+        return available_devices[:max_hcas]
+
+    value = value.strip()
+    result = []
+    is_exclude = value.startswith("^")
+    if is_exclude:
+        value = value.removeprefix("^")
+    is_exact_match = value.startswith("=")
+    if is_exact_match:
+        value = value.removeprefix("=")
+
+    device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
+
+    result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
+    if is_exclude:
+        result = [dev for dev in available_devices if dev not in result]
+    if len(result) > max_hcas:
+        result = result[:max_hcas]
+
+    logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
+
+    return result
+
+
+def _resolve_device_specs(
+    device_specs: list[str], is_exact_match: bool, available_devices: list[str]
+) -> list[str]:
+    devices = set()
+    for spec in device_specs:
+        parts = spec.split(":", 1)
+        device_name = parts[0].strip()
+        # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
+        # port = parts[1].strip() if len(parts) > 1 else None
+        base_devices = (
+            [device_name]
+            if device_name in available_devices
+            else []
+            if is_exact_match
+            else [dev for dev in available_devices if dev.startswith(device_name)]
+        )
+
+        if not base_devices:
+            logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
+            continue
+
+        for base_dev in base_devices:
+            devices.add(base_dev)
+
+    return sorted(devices)
+
+
 def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
     class TPMeta(BaseModel):
         concat_dim: int
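To make the parsing rules above concrete, here is a minimal, self-contained sketch of the same semantics (device names are invented; the authoritative implementation is `_parse_NCCL_IB_HCA` / `_resolve_device_specs` above):

```python
# Illustrative re-implementation of the NCCL_IB_HCA parsing rules,
# against a hypothetical device list. Port suffixes ("mlx5_0:1") are ignored.
available = ["mlx5_0", "mlx5_1", "mlx5_2", "bond_0"]

def parse(value: str, devices: list[str]) -> list[str]:
    value = value.strip()
    if not value:
        return devices[:32]  # same max_hcas cap as above
    exclude = value.startswith("^")
    value = value.removeprefix("^")
    exact = value.startswith("=")
    value = value.removeprefix("=")
    tokens = [t.split(":", 1)[0].strip() for t in value.split(",") if t.strip()]
    picked = sorted(
        d for d in devices if any(d == t if exact else d.startswith(t) for t in tokens)
    )
    # an exclude list keeps everything that did NOT match
    return [d for d in devices if d not in picked] if exclude else picked

assert parse("mlx5", available) == ["mlx5_0", "mlx5_1", "mlx5_2"]
assert parse("=mlx5_0,mlx5_1", available) == ["mlx5_0", "mlx5_1"]
assert parse("^mlx5", available) == ["bond_0"]
assert parse("^=mlx5_0,mlx5_1", available) == ["mlx5_2", "bond_0"]
```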
@@ -493,8 +557,12 @@ def request_inference_to_update(
 
 
 def _gen_h2d_buckets(
-    global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
-) -> list[tuple[int, H2DBucket]]:
+    global_metas: dict[int, MemoryBufferMetaList],
+    bucket_size: int,
+    local_topo: dict[str, set[int]],
+    remote_topo: dict[str, set[int]],
+    ranks: list[int] | None = None,
+) -> list[tuple[int, int, H2DBucket]]:
     buckets: list[tuple[int, H2DBucket]] = []
 
     for owner_rank, items in global_metas.items():
@@ -517,7 +585,73 @@ def _gen_h2d_buckets(
     assert buckets[-1][1].size > 0, (
         f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
    )
-    return buckets
+    ranks_set = set(ranks) if ranks else set()
+    actual_local_topo = (
+        {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
+    )
+    # if ranks is empty, assign the owner_rank as receiver_rank, this is used for colocate architecture
+    if not ranks:
+        return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
+    else:
+        return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
+
+
+def _assign_receiver_ranks(
+    buckets: list[tuple[int, "T"]],
+    local_topo: dict[str, set[int]],
+    remote_topo: dict[str, set[int]],
+) -> list[tuple[int, int, "T"]]:
+    """
+    (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
+
+    Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
+    GPU-rdma_device topology will be considered to make full use of the bandwidth.
+    """
+    if not buckets:
+        logger.warning("bucket list is empty, no need to assign receiver ranks")
+        return []
+    rank_to_rdma_device = {
+        rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
+    }
+
+    # group buckets by owner RDMA devices
+    buckets_by_rdma_device = defaultdict(list)
+    for owner_rank, bucket in buckets:
+        owner_rdma_device = rank_to_rdma_device[owner_rank]
+        buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
+
+    buckets_matrix = list(buckets_by_rdma_device.values())
+    assert buckets_matrix, "buckets_matrix should not be empty"
+
+    # Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
+    num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
+    receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
+
+    flattened_buckets = [
+        buckets_matrix[row][col]
+        for col in range(
+            max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
+        )
+        for row in range(len(buckets_matrix))
+        if col < len(buckets_matrix[row])
+    ]
+
+    buckets_with_receiver = []
+    assigned_cnt = 0
+    while assigned_cnt < len(flattened_buckets):
+        occupied_devices = set()
+        for receiver_rank in receiver_list:
+            if assigned_cnt >= len(flattened_buckets):
+                break
+            owner_rank, bucket = flattened_buckets[assigned_cnt]
+            rdma_device = rank_to_rdma_device[owner_rank]
+            if rdma_device in occupied_devices:
+                break
+            buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
+            occupied_devices.add(rdma_device)
+            assigned_cnt += 1
+
+    return buckets_with_receiver
 
 
 def _get_master_port(master_port: int | None = None) -> int:
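A hypothetical walk-through of the assignment above (every NIC name, IP, and rank is made up): buckets are grouped by the sender's NIC, flattened column-major so consecutive buckets come from different sender NICs, then handed round-robin to one receiver per local NIC group.

```python
from collections import defaultdict

# Two sender NICs and two receiver NICs; all values are illustrative.
remote_topo = {"mlx5_0@10.0.0.1": {0, 1}, "mlx5_1@10.0.0.1": {2, 3}}
local_topo = {"mlx5_0@10.0.0.2": {8, 9}, "mlx5_1@10.0.0.2": {10, 11}}
buckets = [(0, "bucket_a"), (0, "bucket_b"), (2, "bucket_c")]

# Group buckets by the sender's NIC, as the function above does.
rank_to_nic = {r: nic for nic, rs in remote_topo.items() for r in rs}
by_nic = defaultdict(list)
for owner, bucket in buckets:
    by_nic[rank_to_nic[owner]].append((owner, bucket))

# One receiver per local NIC group: the minimum rank of each group.
receivers = [min(rs) for rs in local_topo.values()]
assert receivers == [8, 10]

# Column-major flattening interleaves the sender NICs:
#   [(0, "bucket_a"), (2, "bucket_c"), (0, "bucket_b")]
# and the round-robin loop emits (receiver, owner, bucket) triples that keep
# distinct sender NICs and distinct receivers busy within each round:
#   (8, 0, "bucket_a"), (10, 2, "bucket_c"), (8, 0, "bucket_b")
```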
@@ -528,6 +662,20 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port
 
 
+def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
+    """
+    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
+    which are generated in self.init_process_group_for_ranks
+    """
+    bcast_rank_map: dict[int, int] = {}
+    if not ranks:
+        bcast_rank_map = {r: r for r in range(world_size)}
+    else:
+        for i, r in enumerate(ranks):
+            bcast_rank_map[r] = i
+    return bcast_rank_map
+
+
 class P2PStore:
     def __init__(self):
         from mooncake.engine import TransferEngine
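A small sketch of what this mapping does (ranks are illustrative): when a p2p update runs on a subset of ranks, the temporary process group numbers its members densely from 0, so broadcasts must use the remapped rank rather than the global one.

```python
def get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
    # mirrors _get_bcast_rank_map above, for illustration only
    if not ranks:
        return {r: r for r in range(world_size)}
    return {r: i for i, r in enumerate(ranks)}

assert get_bcast_rank_map(8, None) == {r: r for r in range(8)}
# updating only ranks 4..7: global rank 6 is rank 2 inside the new group
assert get_bcast_rank_map(8, [4, 5, 6, 7])[6] == 2
```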
@@ -535,14 +683,14 @@ class P2PStore:
         self.rank = int(os.getenv("RANK"))
         gpu_count = torch.cuda.device_count()
         local_rank = self.rank % gpu_count
-        device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
+        self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
         self.ip = _get_ip()
 
         # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
         retry_count = 8
         for i in range(retry_count):
             self.engine = TransferEngine()
-            ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", device)
+            ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", self.device)
             if ret == 0:
                 break
             # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
@@ -556,7 +704,7 @@ class P2PStore:
         self.port = self.engine.get_rpc_port()
         self.named_tensors: dict[str, torch.Tensor] = {}
         logger.info(
-            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
+            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
         )
 
     @property
@@ -595,7 +743,13 @@ class P2PStore:
 
 class ParameterServer:
     def __init__(
-        self, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
+        self,
+        *,
+        rank: int | None = None,
+        world_size: int | None = None,
+        auto_pg: bool = False,
+        gpu_count: int | None = None,
+        mem_fraction: float | None = None,
     ):
         """
         Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -603,17 +757,29 @@ class ParameterServer:
         Args:
             auto_pg: Whether to automatically initialize the process group.
                 Notice that if auto_pg is True, will destroy the process group after update.
+            mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-        self._gpu_count = torch.cuda.device_count()
+        self._gpu_count = gpu_count or torch.cuda.device_count()
         self._local_rank = self._rank % self._gpu_count
         self._auto_pg = auto_pg
         self._all_hosts = []
         self._global_device_uuids: list[str] = []
+        self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
+        self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
+        self._mem_fraction = mem_fraction or 0.9
 
         assert self._rank is not None and self._rank >= 0, self._rank
         assert self._world_size and self._world_size > 0, self._world_size
+        assert (
+            self._gpu_count is not None
+            and self._gpu_count > 0
+            and self._gpu_count <= torch.cuda.device_count()
+        ), self._gpu_count
+        assert (
+            self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
+        ), self._mem_fraction
 
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
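The constructor arguments are now keyword-only; a usage sketch (environment setup as before, values illustrative):

```python
from checkpoint_engine.ps import ParameterServer

# RANK, WORLD_SIZE and MASTER_ADDR must still be provided via the
# environment, e.g. by torchrun. mem_fraction caps the share of currently
# free CUDA memory used when auto-detecting the transfer bucket size
# (defaults to 0.9).
ps = ParameterServer(auto_pg=True, mem_fraction=0.5)
```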
@@ -630,6 +796,7 @@ class ParameterServer:
         device_index = self._local_rank
         torch.cuda.set_device(device_index)
         self._device_uuid = _get_physical_gpu_id(device_index)
+        self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
@@ -640,6 +807,13 @@ class ParameterServer:
 
     def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
         self._current_global_parameter_metas = metas
+        self._remote_rdma_devices = defaultdict(set)
+        for i, meta in self._current_global_parameter_metas.items():
+            assert meta.rdma_device is not None, "meta.rdma_device should not be None"
+            assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
+            self._remote_rdma_devices[
+                meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
+            ].add(i)
 
     def register_checkpoint(
         self,
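The topology keys built here have the form `<rdma_device>@<p2p_store_ip>`, grouping ranks that share a NIC on a host; a toy illustration (addresses hypothetical):

```python
from collections import defaultdict

metas = {  # rank -> (rdma_device, p2p_store_addr); values hypothetical
    0: ("mlx5_0", "10.0.0.1:12001"),
    1: ("mlx5_0", "10.0.0.1:12002"),
    2: ("mlx5_1", "10.0.0.1:12003"),
}
remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
for rank, (dev, addr) in metas.items():
    remote_rdma_devices[dev + "@" + addr.split(":")[0]].add(rank)
assert dict(remote_rdma_devices) == {
    "mlx5_0@10.0.0.1": {0, 1},
    "mlx5_1@10.0.0.1": {2},
}
```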
@@ -713,11 +887,11 @@ class ParameterServer:
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
             host_ip=_get_ip(),
             device_uuid=self._device_uuid,
+            rdma_device=self._rdma_device or "",
         )
 
         dist.all_gather_object(metas_lst, metas)
 
-        self._current_global_parameter_metas = {}
         num_parameters = 0
         all_hosts: list[str] = []
         global_device_uuids: list[str] = []
@@ -728,12 +902,24 @@ class ParameterServer:
             if not self._global_device_uuids:
                 global_device_uuids.append(metas_buckets.device_uuid)
             if metas_buckets.memory_buffer_metas_list:
-                self._current_global_parameter_metas[i] = metas_buckets
+                self._current_global_parameter_metas[i] = MemoryBufferMetaList(
+                    memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
+                    p2p_store_addr=metas_buckets.p2p_store_addr,
+                    rdma_device=metas_buckets.rdma_device,
+                )
                 num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
+            self._local_rdma_devices[
+                metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
+                if metas_buckets.p2p_store_addr
+                else metas_buckets.host_ip
+            ].add(i)
         if not self._all_hosts:
             self._all_hosts = all_hosts
         if not self._global_device_uuids:
             self._global_device_uuids = global_device_uuids
+        # Sender node and Receiver node have the same GPU-rdma_device topology is considered as default.
+        # Rewrite the sender's topology (_remote_rdma_devices) by calling load_metas.
+        self._remote_rdma_devices = self._local_rdma_devices.copy()
         logger.info(
             f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
         )
@@ -788,6 +974,7 @@ class ParameterServer:
         If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
             which is useful in disaggregated architecture.
         """
+        assert req_func is not None, "req_func is required"
         try:
             # if both ranks is None or [], it will use fully broadcast to update to all ranks
             if not ranks:
@@ -795,15 +982,15 @@ class ParameterServer:
                 self.init_process_group()
                 self._update_per_bucket(checkpoint_name, req_func)
             else:
-                if self._rank not in ranks:
-                    return
                 if self._auto_pg:
                     if dist.is_initialized():
                         dist.destroy_process_group()
                         # HACK: wait 2s to ensure destroy is finished
                         time.sleep(2)
                     self.init_process_group_for_ranks(ranks)
-                self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
+                if self._rank not in ranks:
+                    return
+                self._update_per_bucket(checkpoint_name, req_func, ranks)
             if self._auto_pg:
                 dist.destroy_process_group()
 
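Both code paths are now reached through the same entry point. A usage sketch, assuming the `update(checkpoint_name, req_func, ranks=...)` signature visible in the hunks above, with a placeholder `req_func`:

```python
# `req_func` is whatever callable drives the inference engine's update RPC.
ps.update("my-checkpoint", req_func)  # ranks omitted: broadcast path
ps.update("my-checkpoint", req_func, ranks=list(range(16, 32)))  # p2p path
```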
@@ -835,8 +1022,8 @@ class ParameterServer:
             # auto detect bucket size
             tensor = torch.tensor(
                 [
-                    # 90% of current cuda free memory bytes
-                    int(float(torch.cuda.mem_get_info()[0]) * 0.9),
+                    # proportion of current cuda free memory bytes
+                    int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
                     # we use negative value to reuse allreduce min operation
                     # for getting the max value of zmq_addr_counter in all ranks
                     -self._zmq_addr_counter,
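The bucket budget is negotiated across ranks; a sketch of the arithmetic (free-memory numbers invented):

```python
# Each rank proposes mem_fraction of its own free CUDA memory; the
# all-reduce MIN then picks the smallest proposal, so the shared bucket
# fits on the most memory-constrained rank.
free_bytes_per_rank = [60 << 30, 48 << 30, 52 << 30]  # hypothetical
mem_fraction = 0.9
bucket_budget = min(int(b * mem_fraction) for b in free_bytes_per_rank)
assert bucket_budget == int((48 << 30) * 0.9)
```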
@@ -948,71 +1135,6 @@ class ParameterServer:
             backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
         )
 
-    def _update_per_bucket_p2p(
-        self,
-        checkpoint_name: str,
-        req_func: Callable[[list[tuple[str, str]]], None],
-        ranks: list[int],
-    ):
-        assert self._p2p_store is not None, "p2p store is not initialized"
-        assert ranks, "ranks should be set"
-        if len(self._current_global_parameter_metas) == 0:
-            raise ValueError("parameter metas is empty")
-        assert dist.is_initialized(), (
-            "process group is not initialized when update model per bucket p2p"
-        )
-
-        need_update = self._rank in ranks
-        logger.info(
-            f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
-            f"gpu_count {self._gpu_count}, world_size {self._world_size}"
-        )
-
-        if not need_update:
-            return
-
-        # first execute a barrier to avoid subsequent cuda oom
-        dist.barrier()
-
-        bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
-        buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
-        ipc_buffer_name = "__ipc_buffer___"
-        self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
-        logger.info(
-            f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
-        )
-        handle = reduce_tensor(buffer)
-
-        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
-        socket, socket_paths = self._bind_zmq_socket()
-        req_thread = threading.Thread(
-            target=req_func,
-            args=(socket_paths,),
-        )
-        req_thread.start()
-        socket.send_pyobj(handle)
-        for gidx, (owner_rank, bucket) in enumerate(buckets):
-            self._logger_rank0(
-                f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
-            )
-            _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
-            if dist.get_rank() == 0:
-                self._copy_to_buffer(checkpoint_name, bucket, _buffer, owner_rank)
-            # broadcast the collected data to all ranks
-            dist.broadcast(_buffer, src=0)
-            socket.recv()
-            dist.barrier()
-            socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
-
-        socket.recv()
-        socket.send_pyobj(None)
-        socket.recv()
-        req_thread.join()
-        dist.barrier()
-        socket.close()
-        self._p2p_store.unregister_named_tensors([ipc_buffer_name])
-        torch.cuda.empty_cache()
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1042,38 +1164,63 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks: list[int] | None = None,
     ):
-        if len(self._current_global_parameter_metas) == 0:
-            raise ValueError("parameter metas is empty")
-
+        assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+        # if both ranks is None or [], it will use fully broadcast to update to all ranks
+        if not ranks:
+            logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
+        # if ranks is set, it will use p2p to update to the ranks
+        else:
+            assert self._p2p_store is not None, "p2p store is not initialized"
+            assert ranks, "ranks should be set"
+
+            need_update = self._rank in ranks
+            logger.info(
+                f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
+                f"gpu_count {self._gpu_count}, world_size {self._world_size}"
+            )
 
-        logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
+            if not need_update:
+                return
+            # first execute a barrier to avoid subsequent cuda oom
+            dist.barrier()
 
         bucket_size, disable_h2d_buffer = self._detect_bucket_size()
-        buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
+        buckets = _gen_h2d_buckets(
+            self._current_global_parameter_metas,
+            bucket_size,
+            self._local_rdma_devices,
+            self._remote_rdma_devices,
+            ranks,
+        )
 
         h2d_buffer: torch.Tensor | None = (
             None
             if disable_h2d_buffer
             else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
         )
-
-        my_buckets: list[H2DBucket] = []
-        for owner_rank, bucket in buckets:
-            if owner_rank != self._rank:
+        # p2p store need to register h2d_buffer to let other ranks read
+        if ranks:
+            h2d_buffer_name = "__h2d_buffer__"
+            if h2d_buffer is not None and self._p2p_store is not None:
+                self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
+        receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
+        for receiver_rank, owner_rank, bucket in buckets:
+            if receiver_rank != self._rank:
                 continue
-            my_buckets.append(bucket)
+            receiver_rank_buckets.append((owner_rank, bucket))
 
         buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
         handle = reduce_tensor(buffer)
 
-        buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)
+        buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
         max_len = 0
-        for owner_rank, bucket in buckets:
-            buckets_by_owner_rank[owner_rank].append(bucket)
-            if len(buckets_by_owner_rank[owner_rank]) > max_len:
-                max_len = len(buckets_by_owner_rank[owner_rank])
+        for receiver_rank, _, bucket in buckets:
+            buckets_by_receiver_rank[receiver_rank].append(bucket)
+            if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
+                max_len = len(buckets_by_receiver_rank[receiver_rank])
 
         socket, socket_paths = self._bind_zmq_socket()
         req_thread = threading.Thread(
@@ -1084,11 +1231,16 @@ class ParameterServer:
         socket.send_pyobj(handle)
 
         gidx = 0
+        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
         for i in range(max_len):
-            if i < len(my_buckets) and not disable_h2d_buffer:
-                self._copy_to_buffer(checkpoint_name, my_buckets[i], h2d_buffer)
-
-            for owner_rank, _buckets in buckets_by_owner_rank.items():
+            if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
+                self._copy_to_buffer(
+                    checkpoint_name,
+                    receiver_rank_buckets[i][1],
+                    h2d_buffer,
+                    receiver_rank_buckets[i][0] if ranks else None,
+                )
+            for receiver_rank, _buckets in buckets_by_receiver_rank.items():
                 if i >= len(_buckets):
                     continue
                 bucket = _buckets[i]
@@ -1097,18 +1249,19 @@ class ParameterServer:
                     torch.cuda.memory_reserved() / 1024 / 1024,
                 )
                 self._logger_rank0(
-                    f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
+                    f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
                     f"Current CUDA allocated {alloc:.2f} MB, "
                     f"reserved {reserved:.2f} MB."
                 )
                 start = gidx % 2 * bucket_size
                 buffer_b: torch.Tensor = buffer[start : start + bucket.size]
-                if owner_rank == self._rank:
+                if receiver_rank == self._rank:
                     if disable_h2d_buffer:
                         self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                     else:
                         buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                dist.broadcast(buffer_b, src=owner_rank)
+                brank = bcast_rank_map[receiver_rank]
+                dist.broadcast(buffer_b, src=brank)
                 socket.recv()
                 dist.barrier()
                 socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
@@ -1120,6 +1273,9 @@ class ParameterServer:
         req_thread.join()
         dist.barrier()
         socket.close()
+        if ranks and h2d_buffer is not None:
+            self._p2p_store.unregister_named_tensors([h2d_buffer_name])
+
         torch.cuda.empty_cache()
 
 
{checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.1.2
+Version: 0.2.0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -38,8 +38,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
 
 The core weight update logic is in `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
 
-- **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
-- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
+- **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
+- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
 
 ### Optimized Weight Broadcast
 In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -60,16 +60,22 @@ It then executes the transfer, where it controls the inference engine through a
 
 Pipelining naturally requires more GPU memory. When memory is not enough, checkpoint-engine will fallback to serial execution.
 
+### Optimized P2P Bucket Assignment
+In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
+To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
+The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
+See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25)
+
 ## Benchmark
 
 | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
 | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
-| GLM-4.5-Air (BF16) | 8xH800 TP8 | …
-| Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | …
-| DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.…
-| Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.…
-| DeepSeek-V3.1 (FP8) | 256xH20 TP16 | …
-| Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.…
+| GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
+| Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
+| DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
+| Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
+| DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
+| Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
 
 All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
 
@@ -77,6 +83,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
 * Device Info: we tested various combination of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
 * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
 * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
+* We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
 
 ## Installation
 
@@ -92,7 +99,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
 pip install 'checkpoint-engine[p2p]'
 ```
 
-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
+If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
 
 ## Getting Started
 
@@ -165,11 +172,11 @@ Run a simple correctness test for checkpoint_engine
 torchrun --nproc-per-node 8 tests/test_update.py
 ```
 
+Other unit tests can be done with pytest.
 ## Limitations and Future Work
 
 - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
 - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
-- The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
 
 ## Acknowledgments
 
checkpoint_engine-0.2.0.dist-info/RECORD
ADDED

@@ -0,0 +1,9 @@
+checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+checkpoint_engine/_version.py,sha256=Dg8AmJomLVpjKL6prJylOONZAPRtB86LOce7dorQS_A,704
+checkpoint_engine/ps.py,sha256=OpGocqJv0TfGgVC1cPKARfz6qehfCLMzQ5KpDQNxb0o,55291
+checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
+checkpoint_engine-0.2.0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.2.0.dist-info/METADATA,sha256=tbAq45YlRvRAfQHDB0XV8w4ZP0zmVJ3RMTAx_wTm154,9896
+checkpoint_engine-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+checkpoint_engine-0.2.0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.2.0.dist-info/RECORD,,
checkpoint_engine-0.1.2.dist-info/RECORD
REMOVED

@@ -1,9 +0,0 @@
-checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
-checkpoint_engine/_version.py,sha256=Ok5oAXdWgR9aghaFXTafTeDW6sYO3uVe6d2Nket57R4,704
-checkpoint_engine/ps.py,sha256=ckM2vdLg3aeOKmM_vTcbIPKcT-r-E4s73yPaCESKdwg,48439
-checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
-checkpoint_engine-0.1.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.1.2.dist-info/METADATA,sha256=9FUb4s1KMSzrDOOV-C18q3gqcVWD_qZ4t_DaLuai_M4,9322
-checkpoint_engine-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.1.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.1.2.dist-info/RECORD,,
{checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/WHEEL
File without changes

{checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/licenses/LICENCE
File without changes

{checkpoint_engine-0.1.2.dist-info → checkpoint_engine-0.2.0.dist-info}/top_level.txt
File without changes