checkpoint-engine 0.2.3__tar.gz → 0.3.0rc1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/_version.py +3 -3
  3. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/ps.py +119 -92
  4. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  5. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/SOURCES.txt +2 -1
  6. checkpoint_engine-0.3.0rc1/tests/test_inplace_unpin.py +81 -0
  7. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_update.py +7 -3
  8. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.github/workflows/cpu-tests.yml +0 -0
  9. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.github/workflows/pre-commit.yaml +0 -0
  10. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.github/workflows/python-publish.yml +0 -0
  11. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.gitignore +0 -0
  12. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.pre-commit-config.yaml +0 -0
  13. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/LICENCE +0 -0
  14. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/README.md +0 -0
  15. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/__init__.py +0 -0
  16. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/device_utils.py +0 -0
  17. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/worker.py +0 -0
  18. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  19. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/requires.txt +0 -0
  20. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/top_level.txt +0 -0
  21. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/docs/npu_start.md +0 -0
  22. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/examples/update.py +0 -0
  23. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/figures/checkpoint-engine.png +0 -0
  24. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/figures/overlap-update-and-copy.png +0 -0
  25. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/figures/pipeline.png +0 -0
  26. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/patches/vllm_fp8.patch +0 -0
  27. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/pyproject.toml +0 -0
  28. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/setup.cfg +0 -0
  29. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_assign_receiver_ranks.py +0 -0
  30. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_rdma_parser.py +0 -0
  31. /checkpoint_engine-0.2.3/tests/test_pin_memory.py → /checkpoint_engine-0.3.0rc1/tests/test_reuse_pin_memory.py +0 -0
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.3
+Version: 0.3.0rc1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.2.3'
-__version_tuple__ = version_tuple = (0, 2, 3)
+__version__ = version = '0.3.0rc1'
+__version_tuple__ = version_tuple = (0, 3, 0, 'rc1')

-__commit_id__ = commit_id = 'g0a6244951'
+__commit_id__ = commit_id = 'g88370e267'
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/ps.py
@@ -118,6 +118,7 @@ class MemoryBuffer(BaseModel):
     buffer: _TorchTensor
     size: int
     metas: list[ParameterMeta]
+    manually_pinned: bool = False


 class MemoryBufferMetaList(BaseModel):
@@ -520,7 +521,7 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[Memor
         logger.info(
             f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
         )
-        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)

     memory_buffers: list[MemoryBuffer] = []
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
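Note: as a point of reference, a minimal sketch of the idea behind `_inplace_pin_memory` (illustrative only, not the package's implementation; the helper name and the use of the `cudaHostRegisterMapped` flag are assumptions, and safetensors metadata parsing is omitted): map the /dev/shm file as shared storage so no host copy is made, then page-lock the mapping — exactly the kind of buffer that now gets tagged `manually_pinned=True`.

    import os
    import torch

    def pin_shm_file_inplace(path: str) -> torch.Tensor:
        """Hypothetical sketch: page-lock a /dev/shm file's mapping without copying it."""
        nbytes = os.path.getsize(path)
        # Shared mapping of the tmpfs file; the bytes stay where they are.
        storage = torch.UntypedStorage.from_file(path, shared=True, nbytes=nbytes)
        buffer = torch.empty(0, dtype=torch.uint8)
        buffer.set_(storage)  # 1-D uint8 view over the mapped bytes
        # 0x02 == cudaHostRegisterMapped, the flag the new _unpin path asserts on later in this diff.
        ret = torch.cuda.cudart().cudaHostRegister(buffer.data_ptr(), nbytes, 0x02)
        assert ret == 0, f"cudaHostRegister failed with error code {ret}"
        return buffer  # a MemoryBuffer built from this would carry manually_pinned=True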
@@ -792,20 +793,6 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port


-def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
-    """
-    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
-    which are generated in self.init_process_group_for_ranks
-    """
-    bcast_rank_map: dict[int, int] = {}
-    if not ranks:
-        bcast_rank_map = {r: r for r in range(world_size)}
-    else:
-        for i, r in enumerate(ranks):
-            bcast_rank_map[r] = i
-    return bcast_rank_map
-
-
 class P2PStore:
     def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
@@ -888,7 +875,7 @@ class ParameterServer:
         *,
         rank: int | None = None,
         world_size: int | None = None,
-        auto_pg: bool = False,
+        auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
     ):
@@ -897,7 +884,7 @@

        Args:
            auto_pg: Whether to automatically initialize the process group.
-                Notice that if auto_pg is True, will destroy the process group after update.
+                Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
            mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
        """
        self._rank = rank or int(os.environ.get("RANK", None))
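Note: the constructor call below is taken from the new test added later in this diff; the torchrun launch comment is an assumption about the surrounding setup.

    # Each rank is started by torchrun, so RANK and the rest of the rendezvous
    # environment are already set when the constructor reads them.
    from checkpoint_engine.ps import ParameterServer

    ps = ParameterServer(auto_pg=True)  # same as ParameterServer() now that auto_pg defaults to True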
@@ -979,12 +966,12 @@
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
         use_shared_memory_pool: bool = False,
-        use_inplace_pin_memory: bool = False,
+        use_inplace_pin_memory: bool = True,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
         Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
-        Please make sure to copy the files to disks if you need to keep them.
+        Please make sure to copy the files to disks if you need to keep them. NPU does not support inplace pin memory.

         Args:
             checkpoint_name: The name of the checkpoint.
@@ -995,9 +982,14 @@
                 cannot accommodate checkpoints with different memory requirements.
                 To free the actual memory of the shared pool or to modify its shape,
                 please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
-            use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
-                Currently, this feature is experimental and may crash.
+            use_inplace_pin_memory: If True (default), allows inplace pin memory for /dev/shm/ safetensors files.
+                This option is ignored when ``use_shared_memory_pool`` is True.
         """
+        if self.device_manager.device_type != "cuda" and use_inplace_pin_memory:
+            logger.warning(
+                f"[rank{self._rank}] Only cuda devices support in-place pin memory, set use_inplace_pin_memory to False"
+            )
+            use_inplace_pin_memory = False
         try:
             if use_shared_memory_pool:
                 logger.info(
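Note: a short usage sketch of the new default (the checkpoint name and file path are made up; the call shape follows the tests in this diff). On CUDA the /dev/shm file is pinned in place and then removed; on non-CUDA devices the flag silently falls back to False with the warning above.

    # Hypothetical registration. Keep a copy of the file on disk if it is still needed:
    # with use_inplace_pin_memory=True (now the default) the /dev/shm file is pinned
    # in place and then deleted.
    ps.register_checkpoint(
        "ckpt-step-100",
        files=["/dev/shm/model-00001-of-00002.safetensors"],
    )
    ps.gather_metas("ckpt-step-100")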
@@ -1016,6 +1008,7 @@
                     named_tensors=named_tensors or {},
                     rank=self._rank,
                     shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                    inplace_pin=False,  # inplace pin memory is not compatible with shared memory pool
                 )
                 self._current_shared_memory_pool_user = checkpoint_name
                 if self._p2p_store is not None and _is_first_time:
@@ -1074,6 +1067,46 @@
             del self._memory_pool[self.shared_memory_pool_name]
             self._memory_pool[self.shared_memory_pool_name] = []
         else:
+
+            def _unpin(t: torch.Tensor):
+                """
+                Un-pin the pinned memory.
+                """
+                p_flags = ctypes.c_uint()
+                try:
+                    libc = ctypes.CDLL(None)  # get all symbols from the current process
+                    cuda_host_get_flags = libc.cudaHostGetFlags
+                    cuda_host_get_flags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
+                    cuda_host_get_flags.restype = ctypes.c_int
+                except AttributeError:
+                    logger.error("cudaHostGetFlags not found in libc, cannot unpin memory manually")
+                    raise
+                r = cuda_host_get_flags(ctypes.byref(p_flags), ctypes.c_void_p(t.data_ptr()))
+                assert r == 0, f"get pin flags error, error code: {r}"
+                # p_flags value meaning from cuda/include/driver_types.h
+                # cudaHostRegisterDefault 0x00 /**< Default host memory registration flag */
+                # cudaHostRegisterPortable 0x01 /**< Pinned memory accessible by all CUDA contexts */
+                # cudaHostRegisterMapped 0x02 /**< Map registered memory into device space */
+                # cudaHostRegisterIoMemory 0x04 /**< Memory-mapped I/O space */
+                # cudaHostRegisterReadOnly 0x08 /**< Memory-mapped read-only */
+                assert p_flags.value == 0x02, (
+                    f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
+                )
+                cudart = torch.cuda.cudart()
+                r = cudart.cudaHostUnregister(t.data_ptr())
+                assert r == 0, f"unpin memory error, error code: {r}"
+
+            # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
+            try:
+                for memory_buffer in self._memory_pool.get(checkpoint_name, []):
+                    if memory_buffer.manually_pinned:
+                        _unpin(memory_buffer.buffer)
+            except Exception as e:
+                logger.error(
+                    f"[rank{self._rank}] fail to unpin memory for checkpoint {checkpoint_name}: {e}"
+                )
+                raise
+            # we won't delete the memory pool if unpinning fails.
             del self._memory_pool[checkpoint_name]
             # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
             # this works by using torch>=2.5.0
@@ -1176,15 +1209,41 @@
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")

+    def store_based_barrier(
+        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
+    ) -> None:
+        """
+        Perform a store-based barrier synchronization across all ranks.
+
+        This barrier uses a TCP store directly rather than a process group,
+        allowing all ranks to synchronize regardless of which process group
+        they belong to.
+
+        Args:
+            store: The TCPStore instance to use for synchronization.
+        """
+        dist.distributed_c10d._store_based_barrier(
+            rank=self._rank,
+            store=store,
+            group_name="parameter_server_barrier",
+            rendezvous_count=self._world_size,
+            timeout=timeout,
+        )
+
     def update(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
         *,
+        timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
+        Warning: if _auto_pg is False when initializing ParameterServer, please make sure ALL ranks in the WORLD_SIZE call `update` function,
+        otherwise, it will hang.

         Args:
             checkpoint_name: The name of the checkpoint.
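Note: `dist.distributed_c10d._store_based_barrier` is a private torch.distributed helper; conceptually it is a counter rendezvous on the store. The toy equivalent below shows why this can synchronize ranks regardless of process-group membership; the key name and polling interval are arbitrary, and the real helper additionally handles key bookkeeping, timeouts and logging.

    import time
    import torch.distributed as dist

    def toy_store_barrier(store: dist.TCPStore, world_size: int, key: str = "toy_barrier") -> None:
        # Every rank bumps a shared counter once...
        arrived = store.add(key, 1)
        # ...then polls until all ranks have arrived. Only the store is involved,
        # so it does not matter which process group (if any) a rank belongs to.
        while arrived < world_size:
            time.sleep(0.01)
            arrived = int(store.get(key))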
@@ -1193,34 +1252,45 @@
                 which is the fastest way to update weights, especially in colocated architecture.
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
+            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
+            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
+            timeout: The timeout of the barrier operation.
         """
         assert req_func is not None, "req_func is required"
+        ranks_group = None
         try:
-            # if both ranks is None or [], it will use fully broadcast to update to all ranks
-            if not ranks:
-                if self._auto_pg and not dist.is_initialized():
-                    self.init_process_group()
-                self._update_per_bucket(checkpoint_name, req_func)
+            master_addr = os.getenv("MASTER_ADDR") or master_addr
+            assert master_addr, "master_addr is required"
+            if self._auto_pg:
+                if not dist.is_initialized():
+                    self.init_process_group(
+                        timeout=timeout, master_addr=master_addr, master_port=master_port
+                    )
+                manager_store = dist.distributed_c10d._get_default_store()
             else:
-                if self._auto_pg:
-                    if dist.is_initialized():
-                        dist.destroy_process_group()
-                        # HACK: wait 2s to ensure destroy is finished
-                        time.sleep(2)
-                    self.init_process_group_for_ranks(ranks)
-                if self._rank not in ranks:
-                    return
-                self._update_per_bucket(checkpoint_name, req_func, ranks)
-
+                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
+                # If master_port is provided, use master_port+1 for barrier store
+                manager_store = dist.TCPStore(
+                    master_addr,
+                    _get_master_port(master_port) + 1,
+                    self._world_size,
+                    timeout=timeout,
+                    is_master=self._rank == 0,
+                )
+            # if ranks is None or [], it will use fully broadcast to update to all ranks
+            ranks_group = dist.new_group(ranks) if ranks else None
+            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
+            self.store_based_barrier(manager_store)
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
         finally:
-            if self._auto_pg and (not ranks or self._rank in ranks):
+            if ranks_group:
+                dist.destroy_process_group(ranks_group)
+            if self._auto_pg and dist.is_initialized():
                 dist.destroy_process_group()
-
             self.device_manager.device_module.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -1238,7 +1308,9 @@
         self._zmq_addr_counter += 1
         return socket, socket_paths

-    def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
+    def _detect_bucket_size(
+        self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
+    ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
         tensor = torch.tensor(
@@ -1254,7 +1326,7 @@
             dtype=torch.int64,
             device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
@@ -1317,51 +1389,6 @@
         self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
         self.device_manager.device_module.synchronize()

-    def init_process_group_for_ranks(
-        self,
-        ranks: list[int],
-        *,
-        master_port: int | None = None,
-        timeout: timedelta = timedelta(minutes=10),
-    ):
-        """
-        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
-
-        Args:
-            ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
-            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
-            timeout: The timeout of the process group.
-        """
-        assert not dist.is_initialized()
-        assert ranks, "ranks should be set"
-        if self._rank not in ranks:
-            return
-        assert self._all_hosts, "all_hosts should be set"
-        assert len(self._all_hosts) == self._world_size // self._gpu_count, (
-            f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
-        )
-        rank = ranks.index(self._rank)
-        master_addr = self._all_hosts[ranks[0] // self._gpu_count]
-        master_port = _get_master_port(master_port)
-        logger.info(
-            f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
-            f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
-        )
-        # only initialize process group and store for ranks, other nodes are not initialized
-        # and will not participate in this update. Since they have registered memory addresses
-        # to p2p_store at the beginning, update ranks can directly get the memory addresses
-        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
-        )
-        dist.init_process_group(
-            backend=self.device_manager.backend,
-            world_size=len(ranks),
-            rank=rank,
-            timeout=timeout,
-            store=store,
-        )
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1401,10 +1428,12 @@
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks_group: dist.ProcessGroup | None,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+
         # if both ranks is None or [], it will use fully broadcast to update to all ranks
         if not ranks:
             logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1422,9 +1451,9 @@
         if not need_update:
             return
         # first execute a barrier to avoid subsequent device oom
-        dist.barrier()
+        dist.barrier(group=ranks_group)

-        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
         buckets = _gen_h2d_buckets(
             self._current_global_parameter_metas,
             bucket_size,
@@ -1471,7 +1500,6 @@

         gidx = 0
         ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
-        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
         try:
             for i in range(max_len):
                 if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1501,8 +1529,7 @@
                     self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                 else:
                     buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                brank = bcast_rank_map[receiver_rank]
-                dist.broadcast(buffer_b, src=brank)
+                dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                 resp = socket.recv()
                 if resp != b"":
                     msg = resp.decode("utf-8")
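Note: `_get_bcast_rank_map` could be dropped because the subgroup is now created with `dist.new_group` over the existing world group, and `torch.distributed.broadcast` interprets `src` as a global rank even when a `group` argument is passed. A toy illustration (world size, rank subset, and tensor are assumed):

    import torch
    import torch.distributed as dist

    # Assume a world of 8 ranks of which only 2, 3, 6 and 7 take part in this update.
    # new_group must be called by every rank, including those outside the subgroup.
    group = dist.new_group([2, 3, 6, 7])

    buf = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")
    if dist.get_rank() in (2, 3, 6, 7):
        # src=6 is the sender's *global* rank; no remapping to a group-local index is needed.
        dist.broadcast(buf, src=6, group=group)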
@@ -1510,7 +1537,7 @@
                         f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                     )
                     ret_code.fill_(1)
-                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                 self.device_manager.device_module.synchronize()
                 if ret_code.item() != 0:
                     # quit early if any rank failed
@@ -1524,7 +1551,7 @@
                 socket.recv()
         finally:
             req_thread.join()
-            dist.barrier()
+            dist.barrier(group=ranks_group)
             socket.close()
             if ranks and h2d_buffer is not None:
                 self._p2p_store.unregister_named_tensors([h2d_buffer_name])
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.3
+Version: 0.3.0rc1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/SOURCES.txt
@@ -23,6 +23,7 @@ figures/overlap-update-and-copy.png
 figures/pipeline.png
 patches/vllm_fp8.patch
 tests/test_assign_receiver_ranks.py
-tests/test_pin_memory.py
+tests/test_inplace_unpin.py
 tests/test_rdma_parser.py
+tests/test_reuse_pin_memory.py
 tests/test_update.py
checkpoint_engine-0.3.0rc1/tests/test_inplace_unpin.py (new file)
@@ -0,0 +1,81 @@
+import os
+import subprocess
+import time
+
+import pytest
+import torch.distributed as dist
+from test_update import device_manager, gen_test_tensors, get_world_size
+
+from checkpoint_engine.ps import ParameterServer
+
+
+dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
+
+
+def get_files() -> list[str]:
+    rank = int(os.getenv("RANK"))
+    named_tensors = dict(gen_test_tensors(rank))
+    import safetensors.torch
+
+    files = []
+    os.makedirs(dev_shm_dir, exist_ok=True)
+    tensors_in_dev_shm = named_tensors
+    time.sleep(1)
+    dev_shm_files = [
+        os.path.join(dev_shm_dir, f"rank{rank}_checkpoint.safetensors")
+        for _ in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
+    time.sleep(1)
+    files.append(dev_shm_files[rank])
+    return files
+
+
+def run_pin_and_unpin(num_runs: int):
+    ps = ParameterServer(auto_pg=True)
+    checkpoint_name = "test_with_files"
+    for _ in range(num_runs):
+        files = get_files()
+        ps.register_checkpoint(checkpoint_name, files=files)
+        ps.gather_metas(checkpoint_name)
+        dist.barrier()
+        ps.unregister_checkpoint(checkpoint_name)
+        if ps._rank == 0:
+            import shutil
+
+            shutil.rmtree(dev_shm_dir)
+
+    dist.destroy_process_group()
+
+
+@pytest.mark.gpu
+def test_unpin_files():
+    world_size = device_manager.device_module.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
+    master_addr = "localhost"
+    master_port = 25400
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        __file__,
+    ]
+
+    result = subprocess.run(  # noqa: S603
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
+if __name__ == "__main__":
+    run_pin_and_unpin(3)
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_update.py
@@ -185,7 +185,6 @@ def run_with_files(
     os.makedirs(dev_shm_dir, exist_ok=True)
     os.makedirs(disk_dir, exist_ok=True)
     tensors_items = list(named_tensors.items())
-    tensors_in_dev_shm = named_tensors
     tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 2])
     tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3])
     tensors_in_memory = dict(tensors_items[1 * len(tensors_items) // 2 :])
@@ -218,7 +217,6 @@ def run_with_files(
     if rank == 0:
         import shutil

-        # this test should be run under use_inplace_pin_memory=False. Otherwise, the files in /dev/shm/ will be deleted.
         shutil.rmtree(dev_shm_dir)
         shutil.rmtree(disk_dir)
     assert proc.exitcode == 0
@@ -238,7 +236,13 @@ def run_with_files(
             ],
         ),
         ("test_with_remote_error", [[]]),
-        # ("long_test_no_error", [list(random.sample(range(get_world_size()), k=num_ranks)) for num_ranks in range(get_world_size() + 1)]),
+        (
+            "test_no_error",
+            [
+                list(random.sample(range(get_world_size()), k=num_ranks))
+                for num_ranks in range(get_world_size() + 1)
+            ],
+        ),
     ],
 )
 def test_update(test_name: str, rank_list: list[list[int]] | None):