checkpoint-engine 0.2.3__tar.gz → 0.3.0rc0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/PKG-INFO +1 -1
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/_version.py +3 -3
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/ps.py +71 -105
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/PKG-INFO +1 -1
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_update.py +8 -3
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.github/workflows/cpu-tests.yml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.github/workflows/pre-commit.yaml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.github/workflows/python-publish.yml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.gitignore +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/.pre-commit-config.yaml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/LICENCE +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/README.md +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/__init__.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/device_utils.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine/worker.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/SOURCES.txt +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/requires.txt +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/top_level.txt +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/docs/npu_start.md +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/examples/update.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/figures/checkpoint-engine.png +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/figures/overlap-update-and-copy.png +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/figures/pipeline.png +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/patches/vllm_fp8.patch +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/pyproject.toml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/setup.cfg +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_assign_receiver_ranks.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_pin_memory.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/tests/test_rdma_parser.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: checkpoint-engine
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0rc0
|
|
4
4
|
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
5
|
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
6
|
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.
|
|
32
|
-
__version_tuple__ = version_tuple = (0,
|
|
31
|
+
__version__ = version = '0.3.0rc0'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 0, 'rc0')
|
|
33
33
|
|
|
34
|
-
__commit_id__ = commit_id = '
|
|
34
|
+
__commit_id__ = commit_id = 'gbaf6f6196'
|
|
@@ -622,7 +622,6 @@ def _register_checkpoint(
|
|
|
622
622
|
named_tensors: dict[str, torch.Tensor],
|
|
623
623
|
rank: int | None = None,
|
|
624
624
|
shared_pin_memory: list[MemoryBuffer] | None = None,
|
|
625
|
-
inplace_pin: bool = False,
|
|
626
625
|
) -> list[MemoryBuffer]:
|
|
627
626
|
logger.info(
|
|
628
627
|
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
|
|
@@ -630,17 +629,12 @@ def _register_checkpoint(
|
|
|
630
629
|
if not files and not named_tensors:
|
|
631
630
|
return []
|
|
632
631
|
memory_buffers: list[MemoryBuffer] = []
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
]
|
|
640
|
-
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
|
|
641
|
-
else:
|
|
642
|
-
files_to_normal_pin = files
|
|
643
|
-
files_to_inplace_pin = []
|
|
632
|
+
files_to_inplace_pin = [
|
|
633
|
+
file
|
|
634
|
+
for file in files
|
|
635
|
+
if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
|
|
636
|
+
]
|
|
637
|
+
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
|
|
644
638
|
if files_to_normal_pin or named_tensors:
|
|
645
639
|
memory_buffers.extend(
|
|
646
640
|
_normal_pin_memory(
|
|
@@ -792,20 +786,6 @@ def _get_master_port(master_port: int | None = None) -> int:
|
|
|
792
786
|
return master_port
|
|
793
787
|
|
|
794
788
|
|
|
795
|
-
def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
|
|
796
|
-
"""
|
|
797
|
-
map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
|
|
798
|
-
which are generated in self.init_process_group_for_ranks
|
|
799
|
-
"""
|
|
800
|
-
bcast_rank_map: dict[int, int] = {}
|
|
801
|
-
if not ranks:
|
|
802
|
-
bcast_rank_map = {r: r for r in range(world_size)}
|
|
803
|
-
else:
|
|
804
|
-
for i, r in enumerate(ranks):
|
|
805
|
-
bcast_rank_map[r] = i
|
|
806
|
-
return bcast_rank_map
|
|
807
|
-
|
|
808
|
-
|
|
809
789
|
class P2PStore:
|
|
810
790
|
def __init__(self, device_manager: DeviceManager):
|
|
811
791
|
from mooncake.engine import TransferEngine
|
|
@@ -979,11 +959,10 @@ class ParameterServer:
|
|
|
979
959
|
files: list[str] | None = None,
|
|
980
960
|
named_tensors: dict[str, torch.Tensor] | None = None,
|
|
981
961
|
use_shared_memory_pool: bool = False,
|
|
982
|
-
use_inplace_pin_memory: bool = False,
|
|
983
962
|
) -> None:
|
|
984
963
|
"""
|
|
985
964
|
Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
|
|
986
|
-
Warning:
|
|
965
|
+
Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
|
|
987
966
|
Please make sure to copy the files to disks if you need to keep them.
|
|
988
967
|
|
|
989
968
|
Args:
|
|
@@ -995,8 +974,6 @@ class ParameterServer:
|
|
|
995
974
|
cannot accommodate checkpoints with different memory requirements.
|
|
996
975
|
To free the actual memory of the shared pool or to modify its shape,
|
|
997
976
|
please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
|
|
998
|
-
use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
|
|
999
|
-
Currently, this feature is experimental and may crash.
|
|
1000
977
|
"""
|
|
1001
978
|
try:
|
|
1002
979
|
if use_shared_memory_pool:
|
|
@@ -1025,10 +1002,7 @@ class ParameterServer:
|
|
|
1025
1002
|
f"checkpoint {checkpoint_name} already registered"
|
|
1026
1003
|
)
|
|
1027
1004
|
self._memory_pool[checkpoint_name] = _register_checkpoint(
|
|
1028
|
-
files=files or [],
|
|
1029
|
-
named_tensors=named_tensors or {},
|
|
1030
|
-
rank=self._rank,
|
|
1031
|
-
inplace_pin=use_inplace_pin_memory,
|
|
1005
|
+
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
|
|
1032
1006
|
)
|
|
1033
1007
|
if self._p2p_store is not None:
|
|
1034
1008
|
self._register_parameters_to_p2p_store(checkpoint_name)
|
|
@@ -1176,12 +1150,36 @@ class ParameterServer:
|
|
|
1176
1150
|
)
|
|
1177
1151
|
logger.info(f"[rank{self._rank}] init process group successfully.")
|
|
1178
1152
|
|
|
1153
|
+
def store_based_barrier(
|
|
1154
|
+
self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
|
|
1155
|
+
) -> None:
|
|
1156
|
+
"""
|
|
1157
|
+
Perform a store-based barrier synchronization across all ranks.
|
|
1158
|
+
|
|
1159
|
+
This barrier uses a TCP store directly rather than a process group,
|
|
1160
|
+
allowing all ranks to synchronize regardless of which process group
|
|
1161
|
+
they belong to.
|
|
1162
|
+
|
|
1163
|
+
Args:
|
|
1164
|
+
store: The TCPStore instance to use for synchronization.
|
|
1165
|
+
"""
|
|
1166
|
+
dist.distributed_c10d._store_based_barrier(
|
|
1167
|
+
rank=self._rank,
|
|
1168
|
+
store=store,
|
|
1169
|
+
group_name="parameter_server_barrier",
|
|
1170
|
+
rendezvous_count=self._world_size,
|
|
1171
|
+
timeout=timeout,
|
|
1172
|
+
)
|
|
1173
|
+
|
|
1179
1174
|
def update(
|
|
1180
1175
|
self,
|
|
1181
1176
|
checkpoint_name: str,
|
|
1182
1177
|
req_func: Callable[[list[tuple[str, str]]], None],
|
|
1183
1178
|
*,
|
|
1179
|
+
timeout: timedelta = timedelta(minutes=10),
|
|
1184
1180
|
ranks: list[int] | None = None,
|
|
1181
|
+
master_addr: str | None = None,
|
|
1182
|
+
master_port: int | None = None,
|
|
1185
1183
|
) -> None:
|
|
1186
1184
|
"""
|
|
1187
1185
|
Update the checkpoint to inference engine. This function should be called after gather_metas.
|
|
@@ -1193,34 +1191,45 @@ class ParameterServer:
|
|
|
1193
1191
|
which is the fastest way to update weights, especially in colocated architecture.
|
|
1194
1192
|
If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
|
|
1195
1193
|
which is useful in disaggregated architecture.
|
|
1194
|
+
master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
|
|
1195
|
+
master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
|
|
1196
|
+
timeout: The timeout of the barrier operation.
|
|
1196
1197
|
"""
|
|
1197
1198
|
assert req_func is not None, "req_func is required"
|
|
1199
|
+
ranks_group = None
|
|
1198
1200
|
try:
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1201
|
+
master_addr = os.getenv("MASTER_ADDR") or master_addr
|
|
1202
|
+
assert master_addr, "master_addr is required"
|
|
1203
|
+
if self._auto_pg:
|
|
1204
|
+
if not dist.is_initialized():
|
|
1205
|
+
self.init_process_group(
|
|
1206
|
+
timeout=timeout, master_addr=master_addr, master_port=master_port
|
|
1207
|
+
)
|
|
1208
|
+
manager_store = dist.distributed_c10d._get_default_store()
|
|
1204
1209
|
else:
|
|
1205
|
-
if
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
self.
|
|
1211
|
-
|
|
1212
|
-
|
|
1213
|
-
|
|
1214
|
-
|
|
1210
|
+
# HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
|
|
1211
|
+
# If master_port is provided, use master_port+1 for barrier store
|
|
1212
|
+
manager_store = dist.TCPStore(
|
|
1213
|
+
master_addr,
|
|
1214
|
+
_get_master_port(master_port) + 1,
|
|
1215
|
+
self._world_size,
|
|
1216
|
+
timeout=timeout,
|
|
1217
|
+
is_master=self._rank == 0,
|
|
1218
|
+
)
|
|
1219
|
+
# if ranks is None or [], it will use fully broadcast to update to all ranks
|
|
1220
|
+
ranks_group = dist.new_group(ranks if ranks else None)
|
|
1221
|
+
self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
|
|
1222
|
+
self.store_based_barrier(manager_store)
|
|
1215
1223
|
except Exception as e:
|
|
1216
1224
|
logger.exception(
|
|
1217
1225
|
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
|
|
1218
1226
|
)
|
|
1219
1227
|
raise
|
|
1220
1228
|
finally:
|
|
1221
|
-
if
|
|
1229
|
+
if ranks_group:
|
|
1230
|
+
dist.destroy_process_group(ranks_group)
|
|
1231
|
+
if self._auto_pg and dist.is_initialized():
|
|
1222
1232
|
dist.destroy_process_group()
|
|
1223
|
-
|
|
1224
1233
|
self.device_manager.device_module.empty_cache()
|
|
1225
1234
|
logger.info(
|
|
1226
1235
|
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
|
|
@@ -1238,7 +1247,9 @@ class ParameterServer:
|
|
|
1238
1247
|
self._zmq_addr_counter += 1
|
|
1239
1248
|
return socket, socket_paths
|
|
1240
1249
|
|
|
1241
|
-
def _detect_bucket_size(
|
|
1250
|
+
def _detect_bucket_size(
|
|
1251
|
+
self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
|
|
1252
|
+
) -> tuple[int, bool]:
|
|
1242
1253
|
GiB = 1 << 30 # noqa: N806
|
|
1243
1254
|
# auto detect bucket size
|
|
1244
1255
|
tensor = torch.tensor(
|
|
@@ -1254,7 +1265,7 @@ class ParameterServer:
|
|
|
1254
1265
|
dtype=torch.int64,
|
|
1255
1266
|
device=self.device_manager.device_type,
|
|
1256
1267
|
)
|
|
1257
|
-
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
|
|
1268
|
+
dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
|
|
1258
1269
|
tensor = tensor.cpu()
|
|
1259
1270
|
free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
|
|
1260
1271
|
max_tensor_bytes = 0
|
|
@@ -1317,51 +1328,6 @@ class ParameterServer:
|
|
|
1317
1328
|
self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
|
|
1318
1329
|
self.device_manager.device_module.synchronize()
|
|
1319
1330
|
|
|
1320
|
-
def init_process_group_for_ranks(
|
|
1321
|
-
self,
|
|
1322
|
-
ranks: list[int],
|
|
1323
|
-
*,
|
|
1324
|
-
master_port: int | None = None,
|
|
1325
|
-
timeout: timedelta = timedelta(minutes=10),
|
|
1326
|
-
):
|
|
1327
|
-
"""
|
|
1328
|
-
Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
|
|
1329
|
-
|
|
1330
|
-
Args:
|
|
1331
|
-
ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
|
|
1332
|
-
master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
|
|
1333
|
-
timeout: The timeout of the process group.
|
|
1334
|
-
"""
|
|
1335
|
-
assert not dist.is_initialized()
|
|
1336
|
-
assert ranks, "ranks should be set"
|
|
1337
|
-
if self._rank not in ranks:
|
|
1338
|
-
return
|
|
1339
|
-
assert self._all_hosts, "all_hosts should be set"
|
|
1340
|
-
assert len(self._all_hosts) == self._world_size // self._gpu_count, (
|
|
1341
|
-
f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
|
|
1342
|
-
)
|
|
1343
|
-
rank = ranks.index(self._rank)
|
|
1344
|
-
master_addr = self._all_hosts[ranks[0] // self._gpu_count]
|
|
1345
|
-
master_port = _get_master_port(master_port)
|
|
1346
|
-
logger.info(
|
|
1347
|
-
f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
|
|
1348
|
-
f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
|
|
1349
|
-
)
|
|
1350
|
-
# only initialize process group and store for ranks, other nodes are not initialized
|
|
1351
|
-
# and will not participate in this update. Since they have registered memory addresses
|
|
1352
|
-
# to p2p_store at the beginning, update ranks can directly get the memory addresses
|
|
1353
|
-
# from other nodes and put the weights into the buffer.
|
|
1354
|
-
store = dist.TCPStore(
|
|
1355
|
-
master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
|
|
1356
|
-
)
|
|
1357
|
-
dist.init_process_group(
|
|
1358
|
-
backend=self.device_manager.backend,
|
|
1359
|
-
world_size=len(ranks),
|
|
1360
|
-
rank=rank,
|
|
1361
|
-
timeout=timeout,
|
|
1362
|
-
store=store,
|
|
1363
|
-
)
|
|
1364
|
-
|
|
1365
1331
|
def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
|
|
1366
1332
|
addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
|
|
1367
1333
|
metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
|
|
@@ -1401,10 +1367,12 @@ class ParameterServer:
|
|
|
1401
1367
|
self,
|
|
1402
1368
|
checkpoint_name: str,
|
|
1403
1369
|
req_func: Callable[[list[tuple[str, str]]], None],
|
|
1370
|
+
ranks_group: dist.ProcessGroup,
|
|
1404
1371
|
ranks: list[int] | None = None,
|
|
1405
1372
|
):
|
|
1406
1373
|
assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
|
|
1407
1374
|
assert dist.is_initialized(), "process group is not initialized"
|
|
1375
|
+
|
|
1408
1376
|
# if both ranks is None or [], it will use fully broadcast to update to all ranks
|
|
1409
1377
|
if not ranks:
|
|
1410
1378
|
logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
|
|
@@ -1422,9 +1390,9 @@ class ParameterServer:
|
|
|
1422
1390
|
if not need_update:
|
|
1423
1391
|
return
|
|
1424
1392
|
# first execute a barrier to avoid subsequent device oom
|
|
1425
|
-
dist.barrier()
|
|
1393
|
+
dist.barrier(group=ranks_group)
|
|
1426
1394
|
|
|
1427
|
-
bucket_size, disable_h2d_buffer = self._detect_bucket_size()
|
|
1395
|
+
bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
|
|
1428
1396
|
buckets = _gen_h2d_buckets(
|
|
1429
1397
|
self._current_global_parameter_metas,
|
|
1430
1398
|
bucket_size,
|
|
@@ -1471,7 +1439,6 @@ class ParameterServer:
|
|
|
1471
1439
|
|
|
1472
1440
|
gidx = 0
|
|
1473
1441
|
ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
|
|
1474
|
-
bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
|
|
1475
1442
|
try:
|
|
1476
1443
|
for i in range(max_len):
|
|
1477
1444
|
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
|
|
@@ -1501,8 +1468,7 @@ class ParameterServer:
|
|
|
1501
1468
|
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
|
|
1502
1469
|
else:
|
|
1503
1470
|
buffer_b.data.copy_(h2d_buffer[: bucket.size])
|
|
1504
|
-
|
|
1505
|
-
dist.broadcast(buffer_b, src=brank)
|
|
1471
|
+
dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
|
|
1506
1472
|
resp = socket.recv()
|
|
1507
1473
|
if resp != b"":
|
|
1508
1474
|
msg = resp.decode("utf-8")
|
|
@@ -1510,7 +1476,7 @@ class ParameterServer:
|
|
|
1510
1476
|
f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
|
|
1511
1477
|
)
|
|
1512
1478
|
ret_code.fill_(1)
|
|
1513
|
-
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
|
|
1479
|
+
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
|
|
1514
1480
|
self.device_manager.device_module.synchronize()
|
|
1515
1481
|
if ret_code.item() != 0:
|
|
1516
1482
|
# quit early if any rank failed
|
|
@@ -1524,7 +1490,7 @@ class ParameterServer:
|
|
|
1524
1490
|
socket.recv()
|
|
1525
1491
|
finally:
|
|
1526
1492
|
req_thread.join()
|
|
1527
|
-
dist.barrier()
|
|
1493
|
+
dist.barrier(group=ranks_group)
|
|
1528
1494
|
socket.close()
|
|
1529
1495
|
if ranks and h2d_buffer is not None:
|
|
1530
1496
|
self._p2p_store.unregister_named_tensors([h2d_buffer_name])
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: checkpoint-engine
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0rc0
|
|
4
4
|
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
5
|
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
6
|
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
@@ -218,8 +218,7 @@ def run_with_files(
|
|
|
218
218
|
if rank == 0:
|
|
219
219
|
import shutil
|
|
220
220
|
|
|
221
|
-
|
|
222
|
-
shutil.rmtree(dev_shm_dir)
|
|
221
|
+
os.removedirs(dev_shm_dir)
|
|
223
222
|
shutil.rmtree(disk_dir)
|
|
224
223
|
assert proc.exitcode == 0
|
|
225
224
|
|
|
@@ -238,7 +237,13 @@ def run_with_files(
|
|
|
238
237
|
],
|
|
239
238
|
),
|
|
240
239
|
("test_with_remote_error", [[]]),
|
|
241
|
-
|
|
240
|
+
(
|
|
241
|
+
"test_no_error",
|
|
242
|
+
[
|
|
243
|
+
list(random.sample(range(get_world_size()), k=num_ranks))
|
|
244
|
+
for num_ranks in range(get_world_size() + 1)
|
|
245
|
+
],
|
|
246
|
+
),
|
|
242
247
|
],
|
|
243
248
|
)
|
|
244
249
|
def test_update(test_name: str, rank_list: list[list[int]] | None):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/SOURCES.txt
RENAMED
|
File without changes
|
|
File without changes
|
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/requires.txt
RENAMED
|
File without changes
|
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc0}/checkpoint_engine.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|