checkpoint-engine 0.2.3.tar.gz → 0.3.0rc0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/_version.py +3 -3
  3. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/ps.py +71 -105
  4. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  5. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_update.py +8 -3
  6. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.github/workflows/cpu-tests.yml +0 -0
  7. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.github/workflows/pre-commit.yaml +0 -0
  8. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.github/workflows/python-publish.yml +0 -0
  9. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.gitignore +0 -0
  10. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.pre-commit-config.yaml +0 -0
  11. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/LICENCE +0 -0
  12. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/README.md +0 -0
  13. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/__init__.py +0 -0
  14. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/device_utils.py +0 -0
  15. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/worker.py +0 -0
  16. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/SOURCES.txt +0 -0
  17. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  18. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/requires.txt +0 -0
  19. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/top_level.txt +0 -0
  20. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/docs/npu_start.md +0 -0
  21. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/examples/update.py +0 -0
  22. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/figures/checkpoint-engine.png +0 -0
  23. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/figures/overlap-update-and-copy.png +0 -0
  24. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/figures/pipeline.png +0 -0
  25. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/patches/vllm_fp8.patch +0 -0
  26. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/pyproject.toml +0 -0
  27. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/setup.cfg +0 -0
  28. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_assign_receiver_ranks.py +0 -0
  29. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_pin_memory.py +0 -0
  30. {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_rdma_parser.py +0 -0
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.3
+Version: 0.3.0rc0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.3'
-__version_tuple__ = version_tuple = (0, 2, 3)
+__version__ = version = '0.3.0rc0'
+__version_tuple__ = version_tuple = (0, 3, 0, 'rc0')
 
-__commit_id__ = commit_id = 'g0a6244951'
+__commit_id__ = commit_id = 'gbaf6f6196'
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/ps.py
@@ -622,7 +622,6 @@ def _register_checkpoint(
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
     shared_pin_memory: list[MemoryBuffer] | None = None,
-    inplace_pin: bool = False,
 ) -> list[MemoryBuffer]:
     logger.info(
         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -630,17 +629,12 @@ def _register_checkpoint(
     if not files and not named_tensors:
         return []
     memory_buffers: list[MemoryBuffer] = []
-    if inplace_pin:
-        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
-        files_to_inplace_pin = [
-            file
-            for file in files
-            if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
-        ]
-        files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
-    else:
-        files_to_normal_pin = files
-        files_to_inplace_pin = []
+    files_to_inplace_pin = [
+        file
+        for file in files
+        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+    ]
+    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
     if files_to_normal_pin or named_tensors:
         memory_buffers.extend(
             _normal_pin_memory(
@@ -792,20 +786,6 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port
 
 
-def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
-    """
-    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
-    which are generated in self.init_process_group_for_ranks
-    """
-    bcast_rank_map: dict[int, int] = {}
-    if not ranks:
-        bcast_rank_map = {r: r for r in range(world_size)}
-    else:
-        for i, r in enumerate(ranks):
-            bcast_rank_map[r] = i
-    return bcast_rank_map
-
-
 class P2PStore:
     def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
@@ -979,11 +959,10 @@ class ParameterServer:
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
         use_shared_memory_pool: bool = False,
-        use_inplace_pin_memory: bool = False,
    ) -> None:
        """
        Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
-        Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
         Please make sure to copy the files to disks if you need to keep them.
 
         Args:
@@ -995,8 +974,6 @@ class ParameterServer:
                 cannot accommodate checkpoints with different memory requirements.
                 To free the actual memory of the shared pool or to modify its shape,
                 please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
-            use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
-                Currently, this feature is experimental and may crash.
         """
         try:
             if use_shared_memory_pool:
@@ -1025,10 +1002,7 @@ class ParameterServer:
                     f"checkpoint {checkpoint_name} already registered"
                 )
             self._memory_pool[checkpoint_name] = _register_checkpoint(
-                files=files or [],
-                named_tensors=named_tensors or {},
-                rank=self._rank,
-                inplace_pin=use_inplace_pin_memory,
+                files=files or [], named_tensors=named_tensors or {}, rank=self._rank
             )
             if self._p2p_store is not None:
                 self._register_parameters_to_p2p_store(checkpoint_name)
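Because in-place pinning of /dev/shm safetensors is now unconditional, a caller that still needs those files afterwards has to copy them out before registering. A minimal caller-side sketch (not part of the diff; the helper name and backup path are hypothetical), mirroring the selection predicate used in `_register_checkpoint`:

    import shutil
    from pathlib import Path

    def backup_inplace_pinned_files(files: list[str], backup_dir: str) -> None:
        """Copy the files that 0.3.0rc0 will pin in-place (and then remove)."""
        will_be_removed = [
            f for f in files
            if f.startswith("/dev/shm/") and f.endswith(".safetensors")
        ]
        dest = Path(backup_dir)
        dest.mkdir(parents=True, exist_ok=True)
        for f in will_be_removed:
            shutil.copy2(f, dest / Path(f).name)  # durable copy survives register_checkpoint

    # backup_inplace_pinned_files(ckpt_files, "/data/ckpt_backup")
    # ps.register_checkpoint(...)  # safe to register afterwards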
@@ -1176,12 +1150,36 @@ class ParameterServer:
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
+    def store_based_barrier(
+        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
+    ) -> None:
+        """
+        Perform a store-based barrier synchronization across all ranks.
+
+        This barrier uses a TCP store directly rather than a process group,
+        allowing all ranks to synchronize regardless of which process group
+        they belong to.
+
+        Args:
+            store: The TCPStore instance to use for synchronization.
+        """
+        dist.distributed_c10d._store_based_barrier(
+            rank=self._rank,
+            store=store,
+            group_name="parameter_server_barrier",
+            rendezvous_count=self._world_size,
+            timeout=timeout,
+        )
+
     def update(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
         *,
+        timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
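The new `store_based_barrier` delegates to the private `torch.distributed.distributed_c10d._store_based_barrier` helper. For readers unfamiliar with the pattern, an equivalent-in-spirit sketch using only public `TCPStore` calls (not part of the diff; the function and key names are illustrative):

    from datetime import timedelta
    import torch.distributed as dist

    def tcp_store_barrier(store: dist.TCPStore, rank: int, world_size: int,
                          key: str = "barrier_demo") -> None:
        # Each rank announces itself, then blocks until every rank's key exists.
        store.set(f"{key}/arrived/{rank}", "1")
        store.wait([f"{key}/arrived/{r}" for r in range(world_size)],
                   timedelta(minutes=5))

Because it only needs the store, such a barrier can span ranks that never joined the broadcast subgroup.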
@@ -1193,34 +1191,45 @@ class ParameterServer:
                 which is the fastest way to update weights, especially in colocated architecture.
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
+            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
+            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
+            timeout: The timeout of the barrier operation.
         """
         assert req_func is not None, "req_func is required"
+        ranks_group = None
         try:
-            # if both ranks is None or [], it will use fully broadcast to update to all ranks
-            if not ranks:
-                if self._auto_pg and not dist.is_initialized():
-                    self.init_process_group()
-                self._update_per_bucket(checkpoint_name, req_func)
+            master_addr = os.getenv("MASTER_ADDR") or master_addr
+            assert master_addr, "master_addr is required"
+            if self._auto_pg:
+                if not dist.is_initialized():
+                    self.init_process_group(
+                        timeout=timeout, master_addr=master_addr, master_port=master_port
+                    )
+                manager_store = dist.distributed_c10d._get_default_store()
             else:
-                if self._auto_pg:
-                    if dist.is_initialized():
-                        dist.destroy_process_group()
-                        # HACK: wait 2s to ensure destroy is finished
-                        time.sleep(2)
-                    self.init_process_group_for_ranks(ranks)
-                if self._rank not in ranks:
-                    return
-                self._update_per_bucket(checkpoint_name, req_func, ranks)
-
+                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
+                # If master_port is provided, use master_port+1 for barrier store
+                manager_store = dist.TCPStore(
+                    master_addr,
+                    _get_master_port(master_port) + 1,
+                    self._world_size,
+                    timeout=timeout,
+                    is_master=self._rank == 0,
+                )
+            # if ranks is None or [], it will use fully broadcast to update to all ranks
+            ranks_group = dist.new_group(ranks if ranks else None)
+            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
+            self.store_based_barrier(manager_store)
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
         finally:
-            if self._auto_pg and (not ranks or self._rank in ranks):
+            if ranks_group:
+                dist.destroy_process_group(ranks_group)
+            if self._auto_pg and dist.is_initialized():
                 dist.destroy_process_group()
-
             self.device_manager.device_module.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -1238,7 +1247,9 @@ class ParameterServer:
         self._zmq_addr_counter += 1
         return socket, socket_paths
 
-    def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
+    def _detect_bucket_size(
+        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+    ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
         tensor = torch.tensor(
@@ -1254,7 +1265,7 @@ class ParameterServer:
             dtype=torch.int64,
             device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
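The MIN all-reduce now runs over `ranks_group`, so only the participating ranks negotiate the bucket size. A toy sketch of the idea (not part of the diff; names are illustrative):

    import torch
    import torch.distributed as dist

    def agree_on_bucket_size(local_free_bytes: int,
                             group: dist.ProcessGroup | None,
                             device: str = "cuda") -> int:
        # Every rank contributes its free-memory budget; the MIN reduction makes
        # them all adopt the most constrained value, so bucket sizes match.
        t = torch.tensor([local_free_bytes], dtype=torch.int64, device=device)
        dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
        return int(t.item())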
@@ -1317,51 +1328,6 @@ class ParameterServer:
             self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
         self.device_manager.device_module.synchronize()
 
-    def init_process_group_for_ranks(
-        self,
-        ranks: list[int],
-        *,
-        master_port: int | None = None,
-        timeout: timedelta = timedelta(minutes=10),
-    ):
-        """
-        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
-
-        Args:
-            ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
-            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
-            timeout: The timeout of the process group.
-        """
-        assert not dist.is_initialized()
-        assert ranks, "ranks should be set"
-        if self._rank not in ranks:
-            return
-        assert self._all_hosts, "all_hosts should be set"
-        assert len(self._all_hosts) == self._world_size // self._gpu_count, (
-            f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
-        )
-        rank = ranks.index(self._rank)
-        master_addr = self._all_hosts[ranks[0] // self._gpu_count]
-        master_port = _get_master_port(master_port)
-        logger.info(
-            f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
-            f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
-        )
-        # only initialize process group and store for ranks, other nodes are not initialized
-        # and will not participate in this update. Since they have registered memory addresses
-        # to p2p_store at the beginning, update ranks can directly get the memory addresses
-        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
-        )
-        dist.init_process_group(
-            backend=self.device_manager.backend,
-            world_size=len(ranks),
-            rank=rank,
-            timeout=timeout,
-            store=store,
-        )
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1401,10 +1367,12 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks_group: dist.ProcessGroup,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+
         # if both ranks is None or [], it will use fully broadcast to update to all ranks
         if not ranks:
             logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1422,9 +1390,9 @@ class ParameterServer:
         if not need_update:
             return
         # first execute a barrier to avoid subsequent device oom
-        dist.barrier()
+        dist.barrier(group=ranks_group)
 
-        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
         buckets = _gen_h2d_buckets(
             self._current_global_parameter_metas,
             bucket_size,
@@ -1471,7 +1439,6 @@ class ParameterServer:
 
         gidx = 0
         ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
-        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
         try:
             for i in range(max_len):
                 if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1501,8 +1468,7 @@ class ParameterServer:
                         self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                 else:
                     buffer_b.data.copy_(h2d_buffer[: bucket.size])
-                brank = bcast_rank_map[receiver_rank]
-                dist.broadcast(buffer_b, src=brank)
+                dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                 resp = socket.recv()
                 if resp != b"":
                     msg = resp.decode("utf-8")
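This change also explains why `_get_bcast_rank_map` could be deleted: when broadcasting inside a subgroup created by `dist.new_group`, `src` is still given as the global rank, so no translation to a group-local rank is needed. A small sketch (not part of the diff; the function name is illustrative):

    import torch
    import torch.distributed as dist

    def broadcast_from(receiver_rank: int, buf: torch.Tensor,
                       group: dist.ProcessGroup | None) -> None:
        # `src` is the rank in the global process group; PyTorch maps it into `group`.
        dist.broadcast(buf, src=receiver_rank, group=group)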
@@ -1510,7 +1476,7 @@ class ParameterServer:
                         f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                     )
                     ret_code.fill_(1)
-                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                 self.device_manager.device_module.synchronize()
                 if ret_code.item() != 0:
                     # quit early if any rank failed
@@ -1524,7 +1490,7 @@ class ParameterServer:
                 socket.recv()
         finally:
             req_thread.join()
-            dist.barrier()
+            dist.barrier(group=ranks_group)
             socket.close()
             if ranks and h2d_buffer is not None:
                 self._p2p_store.unregister_named_tensors([h2d_buffer_name])
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.3
+Version: 0.3.0rc0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_update.py
@@ -218,8 +218,7 @@ def run_with_files(
     if rank == 0:
         import shutil
 
-        # this test should be run under use_inplace_pin_memory=False. Otherwise, the files in /dev/shm/ will be deleted.
-        shutil.rmtree(dev_shm_dir)
+        os.removedirs(dev_shm_dir)
         shutil.rmtree(disk_dir)
     assert proc.exitcode == 0
 
@@ -238,7 +237,13 @@ def run_with_files(
         ],
     ),
     ("test_with_remote_error", [[]]),
-    # ("long_test_no_error", [list(random.sample(range(get_world_size()), k=num_ranks)) for num_ranks in range(get_world_size() + 1)]),
+    (
+        "test_no_error",
+        [
+            list(random.sample(range(get_world_size()), k=num_ranks))
+            for num_ranks in range(get_world_size() + 1)
+        ],
+    ),
     ],
 )
 def test_update(test_name: str, rank_list: list[list[int]] | None):