checkpoint-engine 0.2.2__tar.gz → 0.3.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/PKG-INFO +1 -1
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/_version.py +3 -3
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/ps.py +63 -85
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/PKG-INFO +1 -1
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_update.py +7 -1
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.github/workflows/cpu-tests.yml +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.github/workflows/pre-commit.yaml +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.github/workflows/python-publish.yml +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.gitignore +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/.pre-commit-config.yaml +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/LICENCE +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/README.md +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/__init__.py +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/device_utils.py +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/worker.py +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/SOURCES.txt +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/requires.txt +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/top_level.txt +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/docs/npu_start.md +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/examples/update.py +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/figures/checkpoint-engine.png +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/figures/overlap-update-and-copy.png +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/figures/pipeline.png +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/patches/vllm_fp8.patch +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/pyproject.toml +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/setup.cfg +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_assign_receiver_ranks.py +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_pin_memory.py +0 -0
- {checkpoint_engine-0.2.2 → checkpoint_engine-0.3.0rc0}/tests/test_rdma_parser.py +0 -0
--- checkpoint_engine-0.2.2/PKG-INFO
+++ checkpoint_engine-0.3.0rc0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.2
+Version: 0.3.0rc0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
--- checkpoint_engine-0.2.2/checkpoint_engine/_version.py
+++ checkpoint_engine-0.3.0rc0/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.2'
-__version_tuple__ = version_tuple = (0, 2, 2)
+__version__ = version = '0.3.0rc0'
+__version_tuple__ = version_tuple = (0, 3, 0, 'rc0')
 
-__commit_id__ = commit_id = '…'
+__commit_id__ = commit_id = 'gbaf6f6196'
--- checkpoint_engine-0.2.2/checkpoint_engine/ps.py
+++ checkpoint_engine-0.3.0rc0/checkpoint_engine/ps.py
@@ -786,20 +786,6 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port
 
 
-def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
-    """
-    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
-    which are generated in self.init_process_group_for_ranks
-    """
-    bcast_rank_map: dict[int, int] = {}
-    if not ranks:
-        bcast_rank_map = {r: r for r in range(world_size)}
-    else:
-        for i, r in enumerate(ranks):
-            bcast_rank_map[r] = i
-    return bcast_rank_map
-
-
 class P2PStore:
     def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
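The removed helper is no longer needed because torch.distributed already exposes the same global-to-group rank translation once a subgroup exists. A minimal sketch, not part of the package, assuming the default process group is initialized and a torch version that provides `dist.get_group_rank`:

```python
# Sketch only: the mapping _get_bcast_rank_map computed by hand is available
# from torch.distributed once a subgroup has been created.
import torch.distributed as dist


def bcast_rank_of(global_rank: int, ranks: list[int] | None) -> int:
    if not ranks:
        # broadcast to all ranks: identity mapping, as in the removed helper
        return global_rank
    # every process in the default group must call new_group(), even non-members
    group = dist.new_group(ranks)
    # group-local rank in 0 .. len(ranks) - 1
    return dist.get_group_rank(group, global_rank)
```

In 0.3.0rc0 the ParameterServer does not even need the explicit mapping: it passes the subgroup via `group=` to the collectives (see the later hunks), and `src` stays a global rank.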
@@ -1164,12 +1150,36 @@ class ParameterServer:
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
+    def store_based_barrier(
+        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
+    ) -> None:
+        """
+        Perform a store-based barrier synchronization across all ranks.
+
+        This barrier uses a TCP store directly rather than a process group,
+        allowing all ranks to synchronize regardless of which process group
+        they belong to.
+
+        Args:
+            store: The TCPStore instance to use for synchronization.
+        """
+        dist.distributed_c10d._store_based_barrier(
+            rank=self._rank,
+            store=store,
+            group_name="parameter_server_barrier",
+            rendezvous_count=self._world_size,
+            timeout=timeout,
+        )
+
     def update(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
         *,
+        timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
@@ -1181,34 +1191,45 @@ class ParameterServer:
                 which is the fastest way to update weights, especially in colocated architecture.
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
+            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
+            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
+            timeout: The timeout of the barrier operation.
         """
         assert req_func is not None, "req_func is required"
+        ranks_group = None
         try:
-            …  (removed lines truncated in this diff view)
+            master_addr = os.getenv("MASTER_ADDR") or master_addr
+            assert master_addr, "master_addr is required"
+            if self._auto_pg:
+                if not dist.is_initialized():
+                    self.init_process_group(
+                        timeout=timeout, master_addr=master_addr, master_port=master_port
+                    )
+                manager_store = dist.distributed_c10d._get_default_store()
             else:
-                if …
-                self.…
-                …  (removed lines truncated in this diff view)
+                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
+                # If master_port is provided, use master_port+1 for barrier store
+                manager_store = dist.TCPStore(
+                    master_addr,
+                    _get_master_port(master_port) + 1,
+                    self._world_size,
+                    timeout=timeout,
+                    is_master=self._rank == 0,
+                )
+            # if ranks is None or [], it will use fully broadcast to update to all ranks
+            ranks_group = dist.new_group(ranks if ranks else None)
+            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
+            self.store_based_barrier(manager_store)
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
         finally:
-            if …
+            if ranks_group:
+                dist.destroy_process_group(ranks_group)
+            if self._auto_pg and dist.is_initialized():
                 dist.destroy_process_group()
-            …
             self.device_manager.device_module.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
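A hedged usage sketch of the reworked `update()` call path: `ps` is assumed to be an already constructed `ParameterServer` whose parameter metas were gathered, and the checkpoint name, rank subset, address, port, and `req_func` below are illustrative placeholders, not values from the diff.

```python
from datetime import timedelta

from checkpoint_engine.ps import ParameterServer


def push_checkpoint(ps: ParameterServer) -> None:
    """Placeholder driver: `ps` must already have gathered parameter metas."""

    def req_func(socket_paths: list[tuple[str, str]]) -> None:
        # placeholder callback forwarding the ZMQ socket paths to inference workers
        ...

    ps.update(
        "demo-checkpoint",              # checkpoint_name (placeholder)
        req_func,
        ranks=[0, 1, 2, 3],             # subset -> p2p update; None/[] -> broadcast to all
        master_addr="10.0.0.1",         # falls back to env MASTER_ADDR when unset
        master_port=29500,              # per the comment above, the barrier store binds to this port + 1
        timeout=timedelta(minutes=10),  # applies to process-group init and barriers
    )
```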
@@ -1226,7 +1247,9 @@ class ParameterServer:
         self._zmq_addr_counter += 1
         return socket, socket_paths
 
-    def _detect_bucket_size(…
+    def _detect_bucket_size(
+        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+    ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
         tensor = torch.tensor(
@@ -1242,7 +1265,7 @@ class ParameterServer:
             dtype=torch.int64,
             device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
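As context for the hunk above, a small sketch (assuming an already initialized process group) of the packing trick `_detect_bucket_size` relies on: a single MIN all-reduce returns both the global minimum of the free-byte count and the global maximum of the ZMQ address counter, because the counter is negated before the reduction.

```python
import torch
import torch.distributed as dist


def min_free_and_max_counter(
    free_bytes: int, counter: int, group: dist.ProcessGroup | None = None
) -> tuple[int, int]:
    # Pack both values into one tensor; negate the counter so that an
    # elementwise MIN across ranks yields its maximum after re-negation.
    # CPU tensor works with gloo; with NCCL, move it to the device first,
    # as ps.py does.
    t = torch.tensor([free_bytes, -counter], dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
    return int(t[0].item()), int(-t[1].item())
```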
@@ -1305,51 +1328,6 @@ class ParameterServer:
         self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
         self.device_manager.device_module.synchronize()
 
-    def init_process_group_for_ranks(
-        self,
-        ranks: list[int],
-        *,
-        master_port: int | None = None,
-        timeout: timedelta = timedelta(minutes=10),
-    ):
-        """
-        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
-
-        Args:
-            ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
-            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
-            timeout: The timeout of the process group.
-        """
-        assert not dist.is_initialized()
-        assert ranks, "ranks should be set"
-        if self._rank not in ranks:
-            return
-        assert self._all_hosts, "all_hosts should be set"
-        assert len(self._all_hosts) == self._world_size // self._gpu_count, (
-            f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
-        )
-        rank = ranks.index(self._rank)
-        master_addr = self._all_hosts[ranks[0] // self._gpu_count]
-        master_port = _get_master_port(master_port)
-        logger.info(
-            f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
-            f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
-        )
-        # only initialize process group and store for ranks, other nodes are not initialized
-        # and will not participate in this update. Since they have registered memory addresses
-        # to p2p_store at the beginning, update ranks can directly get the memory addresses
-        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
-        )
-        dist.init_process_group(
-            backend=self.device_manager.backend,
-            world_size=len(ranks),
-            rank=rank,
-            timeout=timeout,
-            store=store,
-        )
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
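For contrast with the removed method, a hedged sketch of the pattern that replaces it in 0.3.0rc0: the global group is initialized once, the rank subset becomes a `dist.new_group` subgroup, and a store-based barrier (the same private torch helper the new `update()` calls) lets ranks outside the subset rendezvous at the end. Host, port, and group name below are placeholders.

```python
from datetime import timedelta

import torch.distributed as dist


def update_on_subset(rank: int, world_size: int, ranks: list[int]) -> None:
    # assumes dist.init_process_group() already ran on every rank
    store = dist.TCPStore(
        "10.0.0.1", 29501, world_size, is_master=(rank == 0), timeout=timedelta(minutes=5)
    )
    group = dist.new_group(ranks or None)  # collective: all ranks must call it
    if not ranks or rank in ranks:
        dist.barrier(group=group)          # only subgroup members join collectives
        # ... broadcast weight buckets with group=group ...
    # every rank, member or not, meets here (private torch helper, as used in ps.py)
    dist.distributed_c10d._store_based_barrier(
        rank=rank,
        store=store,
        group_name="update_barrier",
        rendezvous_count=world_size,
        timeout=timedelta(minutes=5),
    )
```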
@@ -1389,10 +1367,12 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks_group: dist.ProcessGroup,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+
         # if both ranks is None or [], it will use fully broadcast to update to all ranks
         if not ranks:
             logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1410,9 +1390,9 @@ class ParameterServer:
         if not need_update:
             return
         # first execute a barrier to avoid subsequent device oom
-        dist.barrier()
+        dist.barrier(group=ranks_group)
 
-        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
         buckets = _gen_h2d_buckets(
             self._current_global_parameter_metas,
             bucket_size,
@@ -1459,7 +1439,6 @@ class ParameterServer:
 
         gidx = 0
         ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
-        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
        try:
             for i in range(max_len):
                 if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1489,8 +1468,7 @@ class ParameterServer:
                     self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                 else:
                     buffer_b.data.copy_(h2d_buffer[: bucket.size])
-
-                dist.broadcast(buffer_b, src=brank)
+                dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                 resp = socket.recv()
                 if resp != b"":
                     msg = resp.decode("utf-8")
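One detail worth noting in the hunk above: when a `group=` argument is passed, `dist.broadcast` still takes the sender's global rank as `src` and translates it to the group-local rank internally, which is why `receiver_rank` can be passed directly instead of the previously hand-mapped `brank`. A minimal sketch with placeholder tensor and ranks, assuming the default group is already initialized:

```python
import torch
import torch.distributed as dist


def broadcast_bucket(buffer: torch.Tensor, receiver_rank: int, ranks: list[int]) -> None:
    group = dist.new_group(ranks)  # subgroup described by *global* ranks
    if dist.get_rank() in ranks:
        # src is the sender's global rank, not its index inside `ranks`
        dist.broadcast(buffer, src=receiver_rank, group=group)
```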
@@ -1498,7 +1476,7 @@ class ParameterServer:
                         f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                     )
                     ret_code.fill_(1)
-                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                 self.device_manager.device_module.synchronize()
                 if ret_code.item() != 0:
                     # quit early if any rank failed
@@ -1512,7 +1490,7 @@ class ParameterServer:
                     socket.recv()
         finally:
             req_thread.join()
-            dist.barrier()
+            dist.barrier(group=ranks_group)
             socket.close()
         if ranks and h2d_buffer is not None:
             self._p2p_store.unregister_named_tensors([h2d_buffer_name])
--- checkpoint_engine-0.2.2/checkpoint_engine.egg-info/PKG-INFO
+++ checkpoint_engine-0.3.0rc0/checkpoint_engine.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.2
+Version: 0.3.0rc0
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
--- checkpoint_engine-0.2.2/tests/test_update.py
+++ checkpoint_engine-0.3.0rc0/tests/test_update.py
@@ -237,7 +237,13 @@ def run_with_files(
             ],
         ),
         ("test_with_remote_error", [[]]),
-        …  (removed line truncated in this diff view)
+        (
+            "test_no_error",
+            [
+                list(random.sample(range(get_world_size()), k=num_ranks))
+                for num_ranks in range(get_world_size() + 1)
+            ],
+        ),
     ],
 )
 def test_update(test_name: str, rank_list: list[list[int]] | None):
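To make the new parametrization concrete, a quick sketch of what the `test_no_error` case expands to for a hypothetical 4-rank world (the real test derives the size from `get_world_size()`):

```python
import random


def make_rank_lists(world_size: int) -> list[list[int]]:
    # one random, duplicate-free rank subset per size 0 .. world_size
    return [list(random.sample(range(world_size), k=n)) for n in range(world_size + 1)]


print(make_rank_lists(4))
# e.g. [[], [2], [0, 3], [3, 1, 0], [1, 0, 3, 2]]
```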