checkpoint-engine 0.2.2__tar.gz → 0.3.0rc0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/PKG-INFO +1 -1
  2. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/_version.py +3 -3
  3. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/ps.py +63 -85
  4. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/PKG-INFO +1 -1
  5. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_update.py +7 -1
  6. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.github/workflows/cpu-tests.yml +0 -0
  7. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.github/workflows/pre-commit.yaml +0 -0
  8. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.github/workflows/python-publish.yml +0 -0
  9. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.gitignore +0 -0
  10. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.pre-commit-config.yaml +0 -0
  11. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/LICENCE +0 -0
  12. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/README.md +0 -0
  13. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/__init__.py +0 -0
  14. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/device_utils.py +0 -0
  15. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/worker.py +0 -0
  16. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/SOURCES.txt +0 -0
  17. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  18. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/requires.txt +0 -0
  19. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/top_level.txt +0 -0
  20. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/docs/npu_start.md +0 -0
  21. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/examples/update.py +0 -0
  22. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/figures/checkpoint-engine.png +0 -0
  23. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/figures/overlap-update-and-copy.png +0 -0
  24. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/figures/pipeline.png +0 -0
  25. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/patches/vllm_fp8.patch +0 -0
  26. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/pyproject.toml +0 -0
  27. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/setup.cfg +0 -0
  28. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_assign_receiver_ranks.py +0 -0
  29. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_pin_memory.py +0 -0
  30. {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_rdma_parser.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.2.2
3
+ Version: 0.3.0rc0
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.2'
32
- __version_tuple__ = version_tuple = (0, 2, 2)
31
+ __version__ = version = '0.3.0rc0'
32
+ __version_tuple__ = version_tuple = (0, 3, 0, 'rc0')
33
33
 
34
- __commit_id__ = commit_id = 'g089d18598'
34
+ __commit_id__ = commit_id = 'gbaf6f6196'
@@ -786,20 +786,6 @@ def _get_master_port(master_port: int | None = None) -> int:
786
786
  return master_port
787
787
 
788
788
 
789
- def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
790
- """
791
- map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
792
- which are generated in self.init_process_group_for_ranks
793
- """
794
- bcast_rank_map: dict[int, int] = {}
795
- if not ranks:
796
- bcast_rank_map = {r: r for r in range(world_size)}
797
- else:
798
- for i, r in enumerate(ranks):
799
- bcast_rank_map[r] = i
800
- return bcast_rank_map
801
-
802
-
803
789
  class P2PStore:
804
790
  def __init__(self, device_manager: DeviceManager):
805
791
  from mooncake.engine import TransferEngine
@@ -1164,12 +1150,36 @@ class ParameterServer:
1164
1150
  )
1165
1151
  logger.info(f"[rank{self._rank}] init process group successfully.")
1166
1152
 
1153
+ def store_based_barrier(
1154
+ self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
1155
+ ) -> None:
1156
+ """
1157
+ Perform a store-based barrier synchronization across all ranks.
1158
+
1159
+ This barrier uses a TCP store directly rather than a process group,
1160
+ allowing all ranks to synchronize regardless of which process group
1161
+ they belong to.
1162
+
1163
+ Args:
1164
+ store: The TCPStore instance to use for synchronization.
1165
+ """
1166
+ dist.distributed_c10d._store_based_barrier(
1167
+ rank=self._rank,
1168
+ store=store,
1169
+ group_name="parameter_server_barrier",
1170
+ rendezvous_count=self._world_size,
1171
+ timeout=timeout,
1172
+ )
1173
+
1167
1174
  def update(
1168
1175
  self,
1169
1176
  checkpoint_name: str,
1170
1177
  req_func: Callable[[list[tuple[str, str]]], None],
1171
1178
  *,
1179
+ timeout: timedelta = timedelta(minutes=10),
1172
1180
  ranks: list[int] | None = None,
1181
+ master_addr: str | None = None,
1182
+ master_port: int | None = None,
1173
1183
  ) -> None:
1174
1184
  """
1175
1185
  Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -1181,34 +1191,45 @@ class ParameterServer:
1181
1191
  which is the fastest way to update weights, especially in colocated architecture.
1182
1192
  If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
1183
1193
  which is useful in disaggregated architecture.
1194
+ master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
1195
+ master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
1196
+ timeout: The timeout of the barrier operation.
1184
1197
  """
1185
1198
  assert req_func is not None, "req_func is required"
1199
+ ranks_group = None
1186
1200
  try:
1187
- # if both ranks is None or [], it will use fully broadcast to update to all ranks
1188
- if not ranks:
1189
- if self._auto_pg and not dist.is_initialized():
1190
- self.init_process_group()
1191
- self._update_per_bucket(checkpoint_name, req_func)
1201
+ master_addr = os.getenv("MASTER_ADDR") or master_addr
1202
+ assert master_addr, "master_addr is required"
1203
+ if self._auto_pg:
1204
+ if not dist.is_initialized():
1205
+ self.init_process_group(
1206
+ timeout=timeout, master_addr=master_addr, master_port=master_port
1207
+ )
1208
+ manager_store = dist.distributed_c10d._get_default_store()
1192
1209
  else:
1193
- if self._auto_pg:
1194
- if dist.is_initialized():
1195
- dist.destroy_process_group()
1196
- # HACK: wait 2s to ensure destroy is finished
1197
- time.sleep(2)
1198
- self.init_process_group_for_ranks(ranks)
1199
- if self._rank not in ranks:
1200
- return
1201
- self._update_per_bucket(checkpoint_name, req_func, ranks)
1202
-
1210
+ # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
1211
+ # If master_port is provided, use master_port+1 for barrier store
1212
+ manager_store = dist.TCPStore(
1213
+ master_addr,
1214
+ _get_master_port(master_port) + 1,
1215
+ self._world_size,
1216
+ timeout=timeout,
1217
+ is_master=self._rank == 0,
1218
+ )
1219
+ # if ranks is None or [], it will use fully broadcast to update to all ranks
1220
+ ranks_group = dist.new_group(ranks if ranks else None)
1221
+ self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
1222
+ self.store_based_barrier(manager_store)
1203
1223
  except Exception as e:
1204
1224
  logger.exception(
1205
1225
  f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
1206
1226
  )
1207
1227
  raise
1208
1228
  finally:
1209
- if self._auto_pg and (not ranks or self._rank in ranks):
1229
+ if ranks_group:
1230
+ dist.destroy_process_group(ranks_group)
1231
+ if self._auto_pg and dist.is_initialized():
1210
1232
  dist.destroy_process_group()
1211
-
1212
1233
  self.device_manager.device_module.empty_cache()
1213
1234
  logger.info(
1214
1235
  f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -1226,7 +1247,9 @@ class ParameterServer:
1226
1247
  self._zmq_addr_counter += 1
1227
1248
  return socket, socket_paths
1228
1249
 
1229
- def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
1250
+ def _detect_bucket_size(
1251
+ self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
1252
+ ) -> tuple[int, bool]:
1230
1253
  GiB = 1 << 30 # noqa: N806
1231
1254
  # auto detect bucket size
1232
1255
  tensor = torch.tensor(
@@ -1242,7 +1265,7 @@ class ParameterServer:
1242
1265
  dtype=torch.int64,
1243
1266
  device=self.device_manager.device_type,
1244
1267
  )
1245
- dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
1268
+ dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
1246
1269
  tensor = tensor.cpu()
1247
1270
  free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
1248
1271
  max_tensor_bytes = 0
@@ -1305,51 +1328,6 @@ class ParameterServer:
1305
1328
  self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
1306
1329
  self.device_manager.device_module.synchronize()
1307
1330
 
1308
- def init_process_group_for_ranks(
1309
- self,
1310
- ranks: list[int],
1311
- *,
1312
- master_port: int | None = None,
1313
- timeout: timedelta = timedelta(minutes=10),
1314
- ):
1315
- """
1316
- Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
1317
-
1318
- Args:
1319
- ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
1320
- master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
1321
- timeout: The timeout of the process group.
1322
- """
1323
- assert not dist.is_initialized()
1324
- assert ranks, "ranks should be set"
1325
- if self._rank not in ranks:
1326
- return
1327
- assert self._all_hosts, "all_hosts should be set"
1328
- assert len(self._all_hosts) == self._world_size // self._gpu_count, (
1329
- f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
1330
- )
1331
- rank = ranks.index(self._rank)
1332
- master_addr = self._all_hosts[ranks[0] // self._gpu_count]
1333
- master_port = _get_master_port(master_port)
1334
- logger.info(
1335
- f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
1336
- f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
1337
- )
1338
- # only initialize process group and store for ranks, other nodes are not initialized
1339
- # and will not participate in this update. Since they have registered memory addresses
1340
- # to p2p_store at the beginning, update ranks can directly get the memory addresses
1341
- # from other nodes and put the weights into the buffer.
1342
- store = dist.TCPStore(
1343
- master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
1344
- )
1345
- dist.init_process_group(
1346
- backend=self.device_manager.backend,
1347
- world_size=len(ranks),
1348
- rank=rank,
1349
- timeout=timeout,
1350
- store=store,
1351
- )
1352
-
1353
1331
  def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
1354
1332
  addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
1355
1333
  metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1389,10 +1367,12 @@ class ParameterServer:
1389
1367
  self,
1390
1368
  checkpoint_name: str,
1391
1369
  req_func: Callable[[list[tuple[str, str]]], None],
1370
+ ranks_group: dist.ProcessGroup,
1392
1371
  ranks: list[int] | None = None,
1393
1372
  ):
1394
1373
  assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
1395
1374
  assert dist.is_initialized(), "process group is not initialized"
1375
+
1396
1376
  # if both ranks is None or [], it will use fully broadcast to update to all ranks
1397
1377
  if not ranks:
1398
1378
  logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1410,9 +1390,9 @@ class ParameterServer:
1410
1390
  if not need_update:
1411
1391
  return
1412
1392
  # first execute a barrier to avoid subsequent device oom
1413
- dist.barrier()
1393
+ dist.barrier(group=ranks_group)
1414
1394
 
1415
- bucket_size, disable_h2d_buffer = self._detect_bucket_size()
1395
+ bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
1416
1396
  buckets = _gen_h2d_buckets(
1417
1397
  self._current_global_parameter_metas,
1418
1398
  bucket_size,
@@ -1459,7 +1439,6 @@ class ParameterServer:
1459
1439
 
1460
1440
  gidx = 0
1461
1441
  ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
1462
- bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
1463
1442
  try:
1464
1443
  for i in range(max_len):
1465
1444
  if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1489,8 +1468,7 @@ class ParameterServer:
1489
1468
  self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1490
1469
  else:
1491
1470
  buffer_b.data.copy_(h2d_buffer[: bucket.size])
1492
- brank = bcast_rank_map[receiver_rank]
1493
- dist.broadcast(buffer_b, src=brank)
1471
+ dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
1494
1472
  resp = socket.recv()
1495
1473
  if resp != b"":
1496
1474
  msg = resp.decode("utf-8")
@@ -1498,7 +1476,7 @@ class ParameterServer:
1498
1476
  f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
1499
1477
  )
1500
1478
  ret_code.fill_(1)
1501
- dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
1479
+ dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
1502
1480
  self.device_manager.device_module.synchronize()
1503
1481
  if ret_code.item() != 0:
1504
1482
  # quit early if any rank failed
@@ -1512,7 +1490,7 @@ class ParameterServer:
1512
1490
  socket.recv()
1513
1491
  finally:
1514
1492
  req_thread.join()
1515
- dist.barrier()
1493
+ dist.barrier(group=ranks_group)
1516
1494
  socket.close()
1517
1495
  if ranks and h2d_buffer is not None:
1518
1496
  self._p2p_store.unregister_named_tensors([h2d_buffer_name])
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.2.2
3
+ Version: 0.3.0rc0
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -237,7 +237,13 @@ def run_with_files(
237
237
  ],
238
238
  ),
239
239
  ("test_with_remote_error", [[]]),
240
- # ("long_test_no_error", [list(random.sample(range(get_world_size()), k=num_ranks)) for num_ranks in range(get_world_size() + 1)]),
240
+ (
241
+ "test_no_error",
242
+ [
243
+ list(random.sample(range(get_world_size()), k=num_ranks))
244
+ for num_ranks in range(get_world_size() + 1)
245
+ ],
246
+ ),
241
247
  ],
242
248
  )
243
249
  def test_update(test_name: str, rank_list: list[list[int]] | None):