checkpoint-engine 0.1.3__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpoint_engine/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.3'
32
- __version_tuple__ = version_tuple = (0, 1, 3)
31
+ __version__ = version = '0.2.0'
32
+ __version_tuple__ = version_tuple = (0, 2, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
checkpoint_engine/ps.py CHANGED
@@ -25,6 +25,8 @@ from torch.multiprocessing.reductions import reduce_tensor
25
25
 
26
26
 
27
27
  if TYPE_CHECKING:
28
+ from typing import TypeVar
29
+
28
30
  from typing_extensions import TypedDict
29
31
 
30
32
  class FileMeta(TypedDict):
@@ -34,6 +36,8 @@ if TYPE_CHECKING:
34
36
  type: type
35
37
  tp_concat_dim: int
36
38
 
39
+ T = TypeVar("T")
40
+
37
41
 
38
42
  def _dt_validate(value: Any) -> torch.dtype:
39
43
  if isinstance(value, str):
@@ -117,6 +121,7 @@ class MemoryBuffer(BaseModel):
117
121
  class MemoryBufferMetaList(BaseModel):
118
122
  p2p_store_addr: str | None
119
123
  memory_buffer_metas_list: list[MemoryBufferMetas]
124
+ rdma_device: str
120
125
 
121
126
 
122
127
  class DataToGather(MemoryBufferMetaList):
@@ -303,14 +308,7 @@ def _get_rdma_devices() -> list[str]:
303
308
  return devices_str.split(",")
304
309
  # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
305
310
  hca = os.getenv("NCCL_IB_HCA", None)
306
- if hca:
307
- hca_list = hca.split(",")
308
- if len(hca_list) > 1:
309
- # if NCCL_IB_HCA has multiple values, just return
310
- return hca_list
311
- else:
312
- hca = hca_list[0]
313
- return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
311
+ return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
314
312
 
315
313
 
316
314
  def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
@@ -328,6 +326,75 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
328
326
  return devices[local_rank // (gpu_count // len(devices))]
329
327
 
330
328
 
329
+ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
330
+ """
331
+ The values accepted by NCCL_IB_HCA are documented at https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
332
+ This Python parser follows the C++ parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
333
+
334
+ The list is comma-separated; port numbers are NOT supported yet.
335
+ An optional prefix '^' indicates the list is an exclude list.
336
+ A second optional prefix '=' indicates that the tokens are exact device names; otherwise each token is treated as a prefix.
337
+ Note that when '^' and '=' appear together, only '^=' is allowed; '=^' is not supported.
338
+
339
+ Examples:
340
+ - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
341
+ - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
342
+ - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
343
+ - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
344
+ """
345
+ max_hcas = 32
346
+ if not value or value.strip() == "":
347
+ return available_devices[:max_hcas]
348
+
349
+ value = value.strip()
350
+ result = []
351
+ is_exclude = value.startswith("^")
352
+ if is_exclude:
353
+ value = value.removeprefix("^")
354
+ is_exact_match = value.startswith("=")
355
+ if is_exact_match:
356
+ value = value.removeprefix("=")
357
+
358
+ device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
359
+
360
+ result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
361
+ if is_exclude:
362
+ result = [dev for dev in available_devices if dev not in result]
363
+ if len(result) > max_hcas:
364
+ result = result[:max_hcas]
365
+
366
+ logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
367
+
368
+ return result
369
+
370
+
371
+ def _resolve_device_specs(
372
+ device_specs: list[str], is_exact_match: bool, available_devices: list[str]
373
+ ) -> list[str]:
374
+ devices = set()
375
+ for spec in device_specs:
376
+ parts = spec.split(":", 1)
377
+ device_name = parts[0].strip()
378
+ # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
379
+ # port = parts[1].strip() if len(parts) > 1 else None
380
+ base_devices = (
381
+ [device_name]
382
+ if device_name in available_devices
383
+ else []
384
+ if is_exact_match
385
+ else [dev for dev in available_devices if dev.startswith(device_name)]
386
+ )
387
+
388
+ if not base_devices:
389
+ logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
390
+ continue
391
+
392
+ for base_dev in base_devices:
393
+ devices.add(base_dev)
394
+
395
+ return sorted(devices)
396
+
397
+
331
398
  def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
332
399
  class TPMeta(BaseModel):
333
400
  concat_dim: int
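
The new `_parse_NCCL_IB_HCA` / `_resolve_device_specs` helpers above implement the matching rules described in their docstring. As a rough, self-contained illustration of those rules (device names are invented; the real device list comes from `_ibv_get_device_list()`), the following sketch mirrors the documented prefix / exact / exclude semantics. Unlike the real helpers, it ignores the 32-device cap and the unsupported port suffixes:

```python
# Standalone sketch of the NCCL_IB_HCA matching rules documented above.
# Device names are hypothetical; the real parser is _parse_NCCL_IB_HCA in ps.py.
def match_hca(value: str, available: list[str]) -> list[str]:
    value = value.strip()
    exclude = value.startswith("^")   # '^' -> treat the list as an exclude list
    value = value.removeprefix("^")
    exact = value.startswith("=")     # '=' -> exact names, otherwise prefixes
    value = value.removeprefix("=")
    names = [t.strip() for t in value.split(",") if t.strip()]
    picked = sorted(
        {d for d in available for n in names if (d == n if exact else d.startswith(n))}
    )
    return [d for d in available if d not in picked] if exclude else picked


available = ["bnxt_re0", "mlx5_0", "mlx5_1", "mlx5_2"]
assert match_hca("mlx5", available) == ["mlx5_0", "mlx5_1", "mlx5_2"]        # prefix match
assert match_hca("=mlx5_0,mlx5_1", available) == ["mlx5_0", "mlx5_1"]        # exact names
assert match_hca("^mlx5", available) == ["bnxt_re0"]                         # exclude by prefix
assert match_hca("^=mlx5_0", available) == ["bnxt_re0", "mlx5_1", "mlx5_2"]  # exclude exact name
```
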
@@ -490,8 +557,12 @@ def request_inference_to_update(
490
557
 
491
558
 
492
559
  def _gen_h2d_buckets(
493
- global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
494
- ) -> list[tuple[int, H2DBucket]]:
560
+ global_metas: dict[int, MemoryBufferMetaList],
561
+ bucket_size: int,
562
+ local_topo: dict[str, set[int]],
563
+ remote_topo: dict[str, set[int]],
564
+ ranks: list[int] | None = None,
565
+ ) -> list[tuple[int, int, H2DBucket]]:
495
566
  buckets: list[tuple[int, H2DBucket]] = []
496
567
 
497
568
  for owner_rank, items in global_metas.items():
@@ -514,7 +585,73 @@ def _gen_h2d_buckets(
514
585
  assert buckets[-1][1].size > 0, (
515
586
  f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
516
587
  )
517
- return buckets
588
+ ranks_set = set(ranks) if ranks else set()
589
+ actual_local_topo = (
590
+ {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
591
+ )
592
+ # if ranks is empty, assign owner_rank as the receiver_rank; this is used for the colocated architecture
593
+ if not ranks:
594
+ return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
595
+ else:
596
+ return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
597
+
598
+
599
+ def _assign_receiver_ranks(
600
+ buckets: list[tuple[int, "T"]],
601
+ local_topo: dict[str, set[int]],
602
+ remote_topo: dict[str, set[int]],
603
+ ) -> list[tuple[int, int, "T"]]:
604
+ """
605
+ (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
606
+
607
+ Assign a receiver rank to each bucket. If ranks is empty, the owner_rank is used as the receiver_rank.
608
+ The GPU to rdma_device topology is taken into account to make full use of the available bandwidth.
609
+ """
610
+ if not buckets:
611
+ logger.warning("bucket list is empty, no need to assign receiver ranks")
612
+ return []
613
+ rank_to_rdma_device = {
614
+ rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
615
+ }
616
+
617
+ # group buckets by owner RDMA devices
618
+ buckets_by_rdma_device = defaultdict(list)
619
+ for owner_rank, bucket in buckets:
620
+ owner_rdma_device = rank_to_rdma_device[owner_rank]
621
+ buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
622
+
623
+ buckets_matrix = list(buckets_by_rdma_device.values())
624
+ assert buckets_matrix, "buckets_matrix should not be empty"
625
+
626
+ # Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
627
+ num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
628
+ receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
629
+
630
+ flattened_buckets = [
631
+ buckets_matrix[row][col]
632
+ for col in range(
633
+ max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
634
+ )
635
+ for row in range(len(buckets_matrix))
636
+ if col < len(buckets_matrix[row])
637
+ ]
638
+
639
+ buckets_with_receiver = []
640
+ assigned_cnt = 0
641
+ while assigned_cnt < len(flattened_buckets):
642
+ occupied_devices = set()
643
+ for receiver_rank in receiver_list:
644
+ if assigned_cnt >= len(flattened_buckets):
645
+ break
646
+ owner_rank, bucket = flattened_buckets[assigned_cnt]
647
+ rdma_device = rank_to_rdma_device[owner_rank]
648
+ if rdma_device in occupied_devices:
649
+ break
650
+ buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
651
+ occupied_devices.add(rdma_device)
652
+ assigned_cnt += 1
653
+
654
+ return buckets_with_receiver
518
655
 
519
656
 
520
657
  def _get_master_port(master_port: int | None = None) -> int:
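
For orientation, the topology dictionaries consumed by `_gen_h2d_buckets` and `_assign_receiver_ranks` map an RDMA NIC, keyed as `<rdma_device>@<host_ip>`, to the set of ranks behind it. A hypothetical shape (all device names, addresses, and ranks below are invented):

```python
# Sender (owner) side topology, as rebuilt by load_metas from the gathered metas.
remote_topo = {"mlx5_0@10.0.0.1": {0, 1}, "mlx5_1@10.0.0.1": {2, 3}}
# Receiver side topology, as built locally while gathering metas.
local_topo = {"mlx5_0@10.0.0.2": {8, 9}, "mlx5_1@10.0.0.2": {10, 11}}

# _assign_receiver_ranks picks the minimum rank of each local NIC group as a
# receiver (here 8 and 10), groups the buckets by the owner's NIC, and then
# interleaves the assignment so that, within each round, different receivers
# pull from different sender NICs instead of funneling all traffic through rank 0.
```
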
@@ -525,6 +662,20 @@ def _get_master_port(master_port: int | None = None) -> int:
525
662
  return master_port
526
663
 
527
664
 
665
+ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
666
+ """
667
+ Map the real ranks (receiver_rank) to the broadcast ranks (0 ~ len(ranks) - 1)
668
+ that are generated in self.init_process_group_for_ranks.
669
+ """
670
+ bcast_rank_map: dict[int, int] = {}
671
+ if not ranks:
672
+ bcast_rank_map = {r: r for r in range(world_size)}
673
+ else:
674
+ for i, r in enumerate(ranks):
675
+ bcast_rank_map[r] = i
676
+ return bcast_rank_map
677
+
678
+
528
679
  class P2PStore:
529
680
  def __init__(self):
530
681
  from mooncake.engine import TransferEngine
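
A small usage sketch of `_get_bcast_rank_map` defined above (world size and rank values are illustrative): with no `ranks` subset the map is the identity over the world, otherwise each real rank maps to its index in the sub-group created by `init_process_group_for_ranks`.

```python
from checkpoint_engine.ps import _get_bcast_rank_map

# Broadcast path: ranks is None, so broadcast ranks equal the real ranks.
assert _get_bcast_rank_map(4, None) == {0: 0, 1: 1, 2: 2, 3: 3}
# P2P path: only ranks 4..7 join the new process group, so real rank 4 becomes
# broadcast rank 0, rank 5 becomes 1, and so on.
assert _get_bcast_rank_map(8, [4, 5, 6, 7]) == {4: 0, 5: 1, 6: 2, 7: 3}
```
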
@@ -532,14 +683,14 @@ class P2PStore:
532
683
  self.rank = int(os.getenv("RANK"))
533
684
  gpu_count = torch.cuda.device_count()
534
685
  local_rank = self.rank % gpu_count
535
- device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
686
+ self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
536
687
  self.ip = _get_ip()
537
688
 
538
689
  # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
539
690
  retry_count = 8
540
691
  for i in range(retry_count):
541
692
  self.engine = TransferEngine()
542
- ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", device)
693
+ ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", self.device)
543
694
  if ret == 0:
544
695
  break
545
696
  # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
@@ -553,7 +704,7 @@ class P2PStore:
553
704
  self.port = self.engine.get_rpc_port()
554
705
  self.named_tensors: dict[str, torch.Tensor] = {}
555
706
  logger.info(
556
- f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
707
+ f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
557
708
  )
558
709
 
559
710
  @property
@@ -615,6 +766,8 @@ class ParameterServer:
615
766
  self._auto_pg = auto_pg
616
767
  self._all_hosts = []
617
768
  self._global_device_uuids: list[str] = []
769
+ self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
770
+ self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
618
771
  self._mem_fraction = mem_fraction or 0.9
619
772
 
620
773
  assert self._rank is not None and self._rank >= 0, self._rank
@@ -643,6 +796,7 @@ class ParameterServer:
643
796
  device_index = self._local_rank
644
797
  torch.cuda.set_device(device_index)
645
798
  self._device_uuid = _get_physical_gpu_id(device_index)
799
+ self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
646
800
 
647
801
  def _logger_rank0(self, msg: str):
648
802
  if self._local_rank == 0:
@@ -653,6 +807,13 @@ class ParameterServer:
653
807
 
654
808
  def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
655
809
  self._current_global_parameter_metas = metas
810
+ self._remote_rdma_devices = defaultdict(set)
811
+ for i, meta in self._current_global_parameter_metas.items():
812
+ assert meta.rdma_device is not None, "meta.rdma_device should not be None"
813
+ assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
814
+ self._remote_rdma_devices[
815
+ meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
816
+ ].add(i)
656
817
 
657
818
  def register_checkpoint(
658
819
  self,
@@ -726,11 +887,11 @@ class ParameterServer:
726
887
  p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
727
888
  host_ip=_get_ip(),
728
889
  device_uuid=self._device_uuid,
890
+ rdma_device=self._rdma_device or "",
729
891
  )
730
892
 
731
893
  dist.all_gather_object(metas_lst, metas)
732
894
 
733
- self._current_global_parameter_metas = {}
734
895
  num_parameters = 0
735
896
  all_hosts: list[str] = []
736
897
  global_device_uuids: list[str] = []
@@ -741,12 +902,24 @@ class ParameterServer:
741
902
  if not self._global_device_uuids:
742
903
  global_device_uuids.append(metas_buckets.device_uuid)
743
904
  if metas_buckets.memory_buffer_metas_list:
744
- self._current_global_parameter_metas[i] = metas_buckets
905
+ self._current_global_parameter_metas[i] = MemoryBufferMetaList(
906
+ memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
907
+ p2p_store_addr=metas_buckets.p2p_store_addr,
908
+ rdma_device=metas_buckets.rdma_device,
909
+ )
745
910
  num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
911
+ self._local_rdma_devices[
912
+ metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
913
+ if metas_buckets.p2p_store_addr
914
+ else metas_buckets.host_ip
915
+ ].add(i)
746
916
  if not self._all_hosts:
747
917
  self._all_hosts = all_hosts
748
918
  if not self._global_device_uuids:
749
919
  self._global_device_uuids = global_device_uuids
920
+ # By default, assume the sender and receiver nodes share the same GPU to rdma_device topology.
921
+ # Calling load_metas overwrites the sender's topology (_remote_rdma_devices).
922
+ self._remote_rdma_devices = self._local_rdma_devices.copy()
750
923
  logger.info(
751
924
  f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
752
925
  )
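
To make the new topology bookkeeping concrete, here is a minimal reproduction of how the gather step above groups ranks by NIC and host (ranks, device names, and addresses are invented):

```python
from collections import defaultdict

# Hypothetical per-rank metas: (rank, rdma_device, p2p_store_addr).
metas = [
    (0, "mlx5_0", "10.0.0.2:17000"),
    (1, "mlx5_0", "10.0.0.2:17001"),
    (2, "mlx5_1", "10.0.0.2:17002"),
]
local_rdma_devices: dict[str, set[int]] = defaultdict(set)
for rank, rdma_device, p2p_store_addr in metas:
    # Same key shape as in the gather loop: "<rdma_device>@<host_ip>".
    local_rdma_devices[rdma_device + "@" + p2p_store_addr.split(":")[0]].add(rank)

assert local_rdma_devices == {"mlx5_0@10.0.0.2": {0, 1}, "mlx5_1@10.0.0.2": {2}}
```
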
@@ -801,6 +974,7 @@ class ParameterServer:
801
974
  If set, p2p will be used to update only the given ranks; this makes it flexible to update a subset of ranks,
802
975
  which is useful in a disaggregated architecture.
803
976
  """
977
+ assert req_func is not None, "req_func is required"
804
978
  try:
805
979
  # if both ranks is None or [], it will use fully broadcast to update to all ranks
806
980
  if not ranks:
@@ -808,17 +982,15 @@ class ParameterServer:
808
982
  self.init_process_group()
809
983
  self._update_per_bucket(checkpoint_name, req_func)
810
984
  else:
811
- if not self._auto_pg and self._rank not in ranks:
812
- return
813
985
  if self._auto_pg:
814
986
  if dist.is_initialized():
815
987
  dist.destroy_process_group()
816
988
  # HACK: wait 2s to ensure destroy is finished
817
989
  time.sleep(2)
818
- if self._rank not in ranks:
819
- return
820
990
  self.init_process_group_for_ranks(ranks)
821
- self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
991
+ if self._rank not in ranks:
992
+ return
993
+ self._update_per_bucket(checkpoint_name, req_func, ranks)
822
994
  if self._auto_pg:
823
995
  dist.destroy_process_group()
824
996
 
@@ -963,71 +1135,6 @@ class ParameterServer:
963
1135
  backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
964
1136
  )
965
1137
 
966
- def _update_per_bucket_p2p(
967
- self,
968
- checkpoint_name: str,
969
- req_func: Callable[[list[tuple[str, str]]], None],
970
- ranks: list[int],
971
- ):
972
- assert self._p2p_store is not None, "p2p store is not initialized"
973
- assert ranks, "ranks should be set"
974
- if len(self._current_global_parameter_metas) == 0:
975
- raise ValueError("parameter metas is empty")
976
- assert dist.is_initialized(), (
977
- "process group is not initialized when update model per bucket p2p"
978
- )
979
-
980
- need_update = self._rank in ranks
981
- logger.info(
982
- f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
983
- f"gpu_count {self._gpu_count}, world_size {self._world_size}"
984
- )
985
-
986
- if not need_update:
987
- return
988
-
989
- # first execute a barrier to avoid subsequent cuda oom
990
- dist.barrier()
991
-
992
- bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
993
- buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
994
- ipc_buffer_name = "__ipc_buffer___"
995
- self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
996
- logger.info(
997
- f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
998
- )
999
- handle = reduce_tensor(buffer)
1000
-
1001
- buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
1002
- socket, socket_paths = self._bind_zmq_socket()
1003
- req_thread = threading.Thread(
1004
- target=req_func,
1005
- args=(socket_paths,),
1006
- )
1007
- req_thread.start()
1008
- socket.send_pyobj(handle)
1009
- for gidx, (owner_rank, bucket) in enumerate(buckets):
1010
- self._logger_rank0(
1011
- f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1012
- )
1013
- _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
1014
- if dist.get_rank() == 0:
1015
- self._copy_to_buffer(checkpoint_name, bucket, _buffer, owner_rank)
1016
- # broadcast the collected data to all ranks
1017
- dist.broadcast(_buffer, src=0)
1018
- socket.recv()
1019
- dist.barrier()
1020
- socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1021
-
1022
- socket.recv()
1023
- socket.send_pyobj(None)
1024
- socket.recv()
1025
- req_thread.join()
1026
- dist.barrier()
1027
- socket.close()
1028
- self._p2p_store.unregister_named_tensors([ipc_buffer_name])
1029
- torch.cuda.empty_cache()
1030
-
1031
1138
  def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
1032
1139
  addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
1033
1140
  metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1057,38 +1164,63 @@ class ParameterServer:
1057
1164
  self,
1058
1165
  checkpoint_name: str,
1059
1166
  req_func: Callable[[list[tuple[str, str]]], None],
1167
+ ranks: list[int] | None = None,
1060
1168
  ):
1061
- if len(self._current_global_parameter_metas) == 0:
1062
- raise ValueError("parameter metas is empty")
1063
-
1169
+ assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
1064
1170
  assert dist.is_initialized(), "process group is not initialized"
1171
+ # if ranks is None or [], use full broadcast to update all ranks
1172
+ if not ranks:
1173
+ logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
1174
+ # if ranks is set, use p2p to update only those ranks
1175
+ else:
1176
+ assert self._p2p_store is not None, "p2p store is not initialized"
1177
+ assert ranks, "ranks should be set"
1065
1178
 
1066
- logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
1179
+ need_update = self._rank in ranks
1180
+ logger.info(
1181
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
1182
+ f"gpu_count {self._gpu_count}, world_size {self._world_size}"
1183
+ )
1184
+
1185
+ if not need_update:
1186
+ return
1187
+ # first execute a barrier to avoid subsequent cuda oom
1188
+ dist.barrier()
1067
1189
 
1068
1190
  bucket_size, disable_h2d_buffer = self._detect_bucket_size()
1069
- buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
1191
+ buckets = _gen_h2d_buckets(
1192
+ self._current_global_parameter_metas,
1193
+ bucket_size,
1194
+ self._local_rdma_devices,
1195
+ self._remote_rdma_devices,
1196
+ ranks,
1197
+ )
1070
1198
 
1071
1199
  h2d_buffer: torch.Tensor | None = (
1072
1200
  None
1073
1201
  if disable_h2d_buffer
1074
1202
  else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
1075
1203
  )
1076
-
1077
- owner_rank_buckets: list[H2DBucket] = []
1078
- for owner_rank, bucket in buckets:
1079
- if owner_rank != self._rank:
1204
+ # p2p store needs to register h2d_buffer so that other ranks can read it
1205
+ if ranks:
1206
+ h2d_buffer_name = "__h2d_buffer__"
1207
+ if h2d_buffer is not None and self._p2p_store is not None:
1208
+ self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
1209
+ receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
1210
+ for receiver_rank, owner_rank, bucket in buckets:
1211
+ if receiver_rank != self._rank:
1080
1212
  continue
1081
- owner_rank_buckets.append(bucket)
1213
+ receiver_rank_buckets.append((owner_rank, bucket))
1082
1214
 
1083
1215
  buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
1084
1216
  handle = reduce_tensor(buffer)
1085
1217
 
1086
- buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)
1218
+ buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
1087
1219
  max_len = 0
1088
- for owner_rank, bucket in buckets:
1089
- buckets_by_owner_rank[owner_rank].append(bucket)
1090
- if len(buckets_by_owner_rank[owner_rank]) > max_len:
1091
- max_len = len(buckets_by_owner_rank[owner_rank])
1220
+ for receiver_rank, _, bucket in buckets:
1221
+ buckets_by_receiver_rank[receiver_rank].append(bucket)
1222
+ if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
1223
+ max_len = len(buckets_by_receiver_rank[receiver_rank])
1092
1224
 
1093
1225
  socket, socket_paths = self._bind_zmq_socket()
1094
1226
  req_thread = threading.Thread(
@@ -1099,11 +1231,16 @@ class ParameterServer:
1099
1231
  socket.send_pyobj(handle)
1100
1232
 
1101
1233
  gidx = 0
1234
+ bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
1102
1235
  for i in range(max_len):
1103
- if i < len(owner_rank_buckets) and not disable_h2d_buffer:
1104
- self._copy_to_buffer(checkpoint_name, owner_rank_buckets[i], h2d_buffer)
1105
-
1106
- for owner_rank, _buckets in buckets_by_owner_rank.items():
1236
+ if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
1237
+ self._copy_to_buffer(
1238
+ checkpoint_name,
1239
+ receiver_rank_buckets[i][1],
1240
+ h2d_buffer,
1241
+ receiver_rank_buckets[i][0] if ranks else None,
1242
+ )
1243
+ for receiver_rank, _buckets in buckets_by_receiver_rank.items():
1107
1244
  if i >= len(_buckets):
1108
1245
  continue
1109
1246
  bucket = _buckets[i]
@@ -1112,18 +1249,19 @@ class ParameterServer:
1112
1249
  torch.cuda.memory_reserved() / 1024 / 1024,
1113
1250
  )
1114
1251
  self._logger_rank0(
1115
- f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1252
+ f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1116
1253
  f"Current CUDA allocated {alloc:.2f} MB, "
1117
1254
  f"reserved {reserved:.2f} MB."
1118
1255
  )
1119
1256
  start = gidx % 2 * bucket_size
1120
1257
  buffer_b: torch.Tensor = buffer[start : start + bucket.size]
1121
- if owner_rank == self._rank:
1258
+ if receiver_rank == self._rank:
1122
1259
  if disable_h2d_buffer:
1123
1260
  self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1124
1261
  else:
1125
1262
  buffer_b.data.copy_(h2d_buffer[: bucket.size])
1126
- dist.broadcast(buffer_b, src=owner_rank)
1263
+ brank = bcast_rank_map[receiver_rank]
1264
+ dist.broadcast(buffer_b, src=brank)
1127
1265
  socket.recv()
1128
1266
  dist.barrier()
1129
1267
  socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
@@ -1135,6 +1273,9 @@ class ParameterServer:
1135
1273
  req_thread.join()
1136
1274
  dist.barrier()
1137
1275
  socket.close()
1276
+ if ranks and h2d_buffer is not None:
1277
+ self._p2p_store.unregister_named_tensors([h2d_buffer_name])
1278
+
1138
1279
  torch.cuda.empty_cache()
1139
1280
 
1140
1281
 
checkpoint_engine-0.2.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.3
3
+ Version: 0.2.0
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -38,8 +38,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
38
38
 
39
39
  The core weight update logic is in `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
40
40
 
41
- - **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
42
- - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
41
+ - **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
42
+ - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. In this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to send weights P2P from CPUs on existing instances to GPUs on new instances. See `_update_per_bucket` with `ranks` specified.
43
43
 
44
44
  ### Optimized Weight Broadcast
45
45
  In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -60,16 +60,22 @@ It then executes the transfer, where it controls the inference engine through a
60
60
 
61
61
  Pipelining naturally requires more GPU memory. When memory is not enough, checkpoint-engine will fallback to serial execution.
62
62
 
63
+ ### Optimized P2P Bucket Assignment
64
+ In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
65
+ To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
66
+ The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
67
+ See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25) for details.
68
+
63
69
  ## Benchmark
64
70
 
65
71
  | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
66
72
  | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
67
- | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.17s | 3.94s (1.42GiB) | 8.83s (4.77GiB) |
68
- | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.46s | 6.75s (2.69GiB) | 16.47s (4.05GiB) |
69
- | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.44s | 12.22s (2.38GiB) | 25.77s (3.61GiB) |
70
- | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.81s | 15.45s (2.93GiB) | 36.24s (4.46GiB) |
71
- | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 1.40s | 13.88s (2.54GiB) | 33.30s (3.86 GiB) |
72
- | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.88s | 21.50s (2.99GiB) | 34.49s (4.57 GiB) |
73
+ | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
74
+ | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
75
+ | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
76
+ | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
77
+ | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
78
+ | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
73
79
 
74
80
  All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
75
81
 
@@ -77,6 +83,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
77
83
  * Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
78
84
  * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
79
85
  * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
86
+ * We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
80
87
 
81
88
  ## Installation
82
89
 
@@ -92,7 +99,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
92
99
  pip install 'checkpoint-engine[p2p]'
93
100
  ```
94
101
 
95
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
102
+ If the `NCCL_IB_HCA` environment variable is set, checkpoint-engine will use it to automatically select network devices for different ranks. The supported patterns are described in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them among the ranks.
96
103
 
97
104
  ## Getting Started
98
105
 
@@ -165,11 +172,11 @@ Run a simple correctness test for checkpoint_engine
165
172
  torchrun --nproc-per-node 8 tests/test_update.py
166
173
  ```
167
174
 
175
+ Other unit tests can be run with `pytest`.
168
176
  ## Limitations and Future Work
169
177
 
170
178
  - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
171
179
  - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
172
- - The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
173
180
 
174
181
  ## Acknowledgments
175
182
 
@@ -0,0 +1,9 @@
1
+ checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
+ checkpoint_engine/_version.py,sha256=Dg8AmJomLVpjKL6prJylOONZAPRtB86LOce7dorQS_A,704
3
+ checkpoint_engine/ps.py,sha256=OpGocqJv0TfGgVC1cPKARfz6qehfCLMzQ5KpDQNxb0o,55291
4
+ checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
5
+ checkpoint_engine-0.2.0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
6
+ checkpoint_engine-0.2.0.dist-info/METADATA,sha256=tbAq45YlRvRAfQHDB0XV8w4ZP0zmVJ3RMTAx_wTm154,9896
7
+ checkpoint_engine-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
+ checkpoint_engine-0.2.0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
9
+ checkpoint_engine-0.2.0.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
- checkpoint_engine/_version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
3
- checkpoint_engine/ps.py,sha256=9dXRXi0QDPoRYrgGKAYvdmDFBXejgusjR0ltbii5_B0,49134
4
- checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
5
- checkpoint_engine-0.1.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
6
- checkpoint_engine-0.1.3.dist-info/METADATA,sha256=y96dMjEOKWaO_PA0h5BX8G3Ku7Tt1jCU09uIf8iYgic,9322
7
- checkpoint_engine-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
- checkpoint_engine-0.1.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
9
- checkpoint_engine-0.1.3.dist-info/RECORD,,