checkpoint-engine 0.2.3__tar.gz → 0.3.0rc1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/PKG-INFO +1 -1
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/_version.py +3 -3
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/ps.py +119 -92
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/PKG-INFO +1 -1
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/SOURCES.txt +2 -1
- checkpoint_engine-0.3.0rc1/tests/test_inplace_unpin.py +81 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_update.py +7 -3
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.github/workflows/cpu-tests.yml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.github/workflows/pre-commit.yaml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.github/workflows/python-publish.yml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.gitignore +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/.pre-commit-config.yaml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/LICENCE +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/README.md +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/__init__.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/device_utils.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/worker.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/requires.txt +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/top_level.txt +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/docs/npu_start.md +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/examples/update.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/figures/checkpoint-engine.png +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/figures/overlap-update-and-copy.png +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/figures/pipeline.png +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/patches/vllm_fp8.patch +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/pyproject.toml +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/setup.cfg +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_assign_receiver_ranks.py +0 -0
- {checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_rdma_parser.py +0 -0
- /checkpoint_engine-0.2.3/tests/test_pin_memory.py → /checkpoint_engine-0.3.0rc1/tests/test_reuse_pin_memory.py +0 -0
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.3
+Version: 0.3.0rc1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.3'
-__version_tuple__ = version_tuple = (0, 2, 3)
+__version__ = version = '0.3.0rc1'
+__version_tuple__ = version_tuple = (0, 3, 0, 'rc1')
 
-__commit_id__ = commit_id = '
+__commit_id__ = commit_id = 'g88370e267'
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine/ps.py
@@ -118,6 +118,7 @@ class MemoryBuffer(BaseModel):
     buffer: _TorchTensor
     size: int
     metas: list[ParameterMeta]
+    manually_pinned: bool = False
 
 
 class MemoryBufferMetaList(BaseModel):
@@ -520,7 +521,7 @@ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
         logger.info(
             f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
         )
-        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)
 
     memory_buffers: list[MemoryBuffer] = []
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
@@ -792,20 +793,6 @@ def _get_master_port(master_port: int | None = None) -> int:
     return master_port
 
 
-def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
-    """
-    map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
-    which are generated in self.init_process_group_for_ranks
-    """
-    bcast_rank_map: dict[int, int] = {}
-    if not ranks:
-        bcast_rank_map = {r: r for r in range(world_size)}
-    else:
-        for i, r in enumerate(ranks):
-            bcast_rank_map[r] = i
-    return bcast_rank_map
-
-
 class P2PStore:
     def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
@@ -888,7 +875,7 @@ class ParameterServer:
         *,
         rank: int | None = None,
         world_size: int | None = None,
-        auto_pg: bool = False,
+        auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
     ):
@@ -897,7 +884,7 @@ class ParameterServer:
 
         Args:
             auto_pg: Whether to automatically initialize the process group.
-                Notice that if auto_pg is True, will destroy the process group after update.
+                Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
             mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
@@ -979,12 +966,12 @@ class ParameterServer:
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
         use_shared_memory_pool: bool = False,
-        use_inplace_pin_memory: bool = False,
+        use_inplace_pin_memory: bool = True,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
         Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
-        Please make sure to copy the files to disks if you need to keep them.
+        Please make sure to copy the files to disks if you need to keep them. NPU does not support inplace pin memory.
 
         Args:
             checkpoint_name: The name of the checkpoint.
@@ -995,9 +982,14 @@ class ParameterServer:
                 cannot accommodate checkpoints with different memory requirements.
                 To free the actual memory of the shared pool or to modify its shape,
                 please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
-            use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files.
-
+            use_inplace_pin_memory: If True (default), allows inplace pin memory for /dev/shm/ safetensors files.
+                This option is ignored when ``use_shared_memory_pool`` is True.
         """
+        if self.device_manager.device_type != "cuda" and use_inplace_pin_memory:
+            logger.warning(
+                f"[rank{self._rank}] Only cuda devices support in-place pin memory, set use_inplace_pin_memory to False"
+            )
+            use_inplace_pin_memory = False
         try:
             if use_shared_memory_pool:
                 logger.info(
@@ -1016,6 +1008,7 @@ class ParameterServer:
                     named_tensors=named_tensors or {},
                     rank=self._rank,
                     shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                    inplace_pin=False,  # inplace pin memory is not compatible with shared memory pool
                 )
                 self._current_shared_memory_pool_user = checkpoint_name
                 if self._p2p_store is not None and _is_first_time:
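For orientation on the new defaults, here is a hedged usage sketch (not taken from the package's examples) of `register_checkpoint` on a /dev/shm safetensors file. It assumes a torchrun-launched process (RANK, WORLD_SIZE, MASTER_ADDR set), at least one CUDA device, and made-up names such as "demo".

```python
# Hedged usage sketch of register_checkpoint with the 0.3.0rc1 defaults
# (use_inplace_pin_memory=True); the checkpoint name and file path are
# illustrative placeholders.
import torch
import safetensors.torch
from checkpoint_engine.ps import ParameterServer

ps = ParameterServer()  # auto_pg now defaults to True

# Write a throwaway checkpoint into /dev/shm so it can be pinned in place.
shm_file = "/dev/shm/demo_rank0.safetensors"
safetensors.torch.save_file({"w": torch.randn(4, 4)}, shm_file)

# On CUDA, the /dev/shm file is pinned in place and then REMOVED; copy it
# elsewhere first if you still need it. On non-CUDA devices the flag is
# forced back to False with a warning.
ps.register_checkpoint("demo", files=[shm_file])
ps.gather_metas("demo")

# Unregistering also unpins buffers that were pinned via cudaHostRegister.
ps.unregister_checkpoint("demo")
```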
@@ -1074,6 +1067,46 @@ class ParameterServer:
             del self._memory_pool[self.shared_memory_pool_name]
             self._memory_pool[self.shared_memory_pool_name] = []
         else:
+
+            def _unpin(t: torch.Tensor):
+                """
+                Un-pin the pinned memory.
+                """
+                p_flags = ctypes.c_uint()
+                try:
+                    libc = ctypes.CDLL(None)  # get all symbols from the current process
+                    cuda_host_get_flags = libc.cudaHostGetFlags
+                    cuda_host_get_flags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
+                    cuda_host_get_flags.restype = ctypes.c_int
+                except AttributeError:
+                    logger.error("cudaHostGetFlags not found in libc, cannot unpin memory manually")
+                    raise
+                r = cuda_host_get_flags(ctypes.byref(p_flags), ctypes.c_void_p(t.data_ptr()))
+                assert r == 0, f"get pin flags error, error code: {r}"
+                # p_flags value meaning from cuda/include/driver_types.h
+                # cudaHostRegisterDefault 0x00 /**< Default host memory registration flag */
+                # cudaHostRegisterPortable 0x01 /**< Pinned memory accessible by all CUDA contexts */
+                # cudaHostRegisterMapped 0x02 /**< Map registered memory into device space */
+                # cudaHostRegisterIoMemory 0x04 /**< Memory-mapped I/O space */
+                # cudaHostRegisterReadOnly 0x08 /**< Memory-mapped read-only */
+                assert p_flags.value == 0x02, (
+                    f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
+                )
+                cudart = torch.cuda.cudart()
+                r = cudart.cudaHostUnregister(t.data_ptr())
+                assert r == 0, f"unpin memory error, error code: {r}"
+
+            # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
+            try:
+                for memory_buffer in self._memory_pool.get(checkpoint_name, []):
+                    if memory_buffer.manually_pinned:
+                        _unpin(memory_buffer.buffer)
+            except Exception as e:
+                logger.error(
+                    f"[rank{self._rank}] fail to unpin memory for checkpoint {checkpoint_name}: {e}"
+                )
+                raise
+            # we won't delete the memory pool if unpinning fails.
             del self._memory_pool[checkpoint_name]
             # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
             # this works by using torch>=2.5.0
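For readers unfamiliar with host registration, the standalone sketch below shows the pin/unpin pairing that the `manually_pinned` flag tracks. It assumes a CUDA build of PyTorch whose `torch.cuda.cudart()` exposes `cudaHostRegister`/`cudaHostUnregister` (the hunk above already relies on the latter); it is an illustration of the mechanism, not the package's `_inplace_pin_memory` implementation.

```python
# Illustrative pairing of cudaHostRegister / cudaHostUnregister on an existing
# CPU buffer; return values are cudaError_t ints, 0 meaning cudaSuccess.
import torch

def pin_in_place(t: torch.Tensor) -> None:
    """Page-lock an already-allocated CPU tensor without copying it."""
    assert t.device.type == "cpu" and t.is_contiguous()
    # 0 == cudaHostRegisterDefault; the range must be unregistered before
    # the backing memory is released (e.g. before the tensor is freed).
    ret = torch.cuda.cudart().cudaHostRegister(
        t.data_ptr(), t.numel() * t.element_size(), 0
    )
    assert ret == 0, f"cudaHostRegister failed with error code {ret}"

def unpin_in_place(t: torch.Tensor) -> None:
    ret = torch.cuda.cudart().cudaHostUnregister(t.data_ptr())
    assert ret == 0, f"cudaHostUnregister failed with error code {ret}"

if __name__ == "__main__":
    buf = torch.empty(1 << 20, dtype=torch.uint8)
    pin_in_place(buf)
    print(buf.is_pinned())  # registered memory is normally reported as pinned
    unpin_in_place(buf)
```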
@@ -1176,15 +1209,41 @@ class ParameterServer:
         )
         logger.info(f"[rank{self._rank}] init process group successfully.")
 
+    def store_based_barrier(
+        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
+    ) -> None:
+        """
+        Perform a store-based barrier synchronization across all ranks.
+
+        This barrier uses a TCP store directly rather than a process group,
+        allowing all ranks to synchronize regardless of which process group
+        they belong to.
+
+        Args:
+            store: The TCPStore instance to use for synchronization.
+        """
+        dist.distributed_c10d._store_based_barrier(
+            rank=self._rank,
+            store=store,
+            group_name="parameter_server_barrier",
+            rendezvous_count=self._world_size,
+            timeout=timeout,
+        )
+
     def update(
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
         *,
+        timeout: timedelta = timedelta(minutes=10),
         ranks: list[int] | None = None,
+        master_addr: str | None = None,
+        master_port: int | None = None,
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
+        Warning: if _auto_pg is False when initializing ParameterServer, please make sure ALL ranks in the WORLD_SIZE call `update` function,
+        otherwise, it will hang.
 
         Args:
             checkpoint_name: The name of the checkpoint.
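The `store_based_barrier` added above wraps a private torch helper (`dist.distributed_c10d._store_based_barrier`). For intuition, a minimal equivalent using only the public `TCPStore` API might look like the sketch below; the key name and polling interval are arbitrary, and the barrier is single-use (pick a fresh key each time).

```python
# Minimal store-based barrier built on public TCPStore calls only; this is an
# illustration of the mechanism, not the helper the diff actually calls.
import time
import torch.distributed as dist

def tcp_store_barrier(
    store: dist.TCPStore, world_size: int, key: str = "demo_barrier"
) -> None:
    store.add(key, 1)                      # announce this rank's arrival
    while store.add(key, 0) < world_size:  # add(key, 0) just reads the counter
        time.sleep(0.01)
```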
@@ -1193,34 +1252,45 @@ class ParameterServer:
                 which is the fastest way to update weights, especially in colocated architecture.
                 If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
                 which is useful in disaggregated architecture.
+            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
+            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
+            timeout: The timeout of the barrier operation.
         """
         assert req_func is not None, "req_func is required"
+        ranks_group = None
         try:
-
-
-
-
-
+            master_addr = os.getenv("MASTER_ADDR") or master_addr
+            assert master_addr, "master_addr is required"
+            if self._auto_pg:
+                if not dist.is_initialized():
+                    self.init_process_group(
+                        timeout=timeout, master_addr=master_addr, master_port=master_port
+                    )
+                manager_store = dist.distributed_c10d._get_default_store()
             else:
-                if
-
-
-
-
-                self.
-
-
-
-
+                # HACK: MASTER_PORT+2 for barrier store if master_port is not provided, _get_master_port() returns MASTER_PORT+1
+                # If master_port is provided, use master_port+1 for barrier store
+                manager_store = dist.TCPStore(
+                    master_addr,
+                    _get_master_port(master_port) + 1,
+                    self._world_size,
+                    timeout=timeout,
+                    is_master=self._rank == 0,
+                )
+            # if ranks is None or [], it will use fully broadcast to update to all ranks
+            ranks_group = dist.new_group(ranks) if ranks else None
+            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
+            self.store_based_barrier(manager_store)
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
         finally:
-            if
+            if ranks_group:
+                dist.destroy_process_group(ranks_group)
+            if self._auto_pg and dist.is_initialized():
                 dist.destroy_process_group()
-
             self.device_manager.device_module.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
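A hedged end-to-end sketch of the reworked `update()` call follows, assuming a torchrun launch and placeholder names ("my-ckpt", the stub `req_func`, the example address); it is not taken from the package's examples.

```python
# Hypothetical driver for the new update() signature; every rank runs this.
from datetime import timedelta
from checkpoint_engine.ps import ParameterServer

ps = ParameterServer()  # auto_pg=True: update() creates/destroys the PG itself
ps.register_checkpoint("my-ckpt", files=["/dev/shm/my-ckpt/rank0.safetensors"])
ps.gather_metas("my-ckpt")

def req_func(socket_paths: list[tuple[str, str]]) -> None:
    # In real use this hands the returned pairs to the inference engine so it
    # can pull the new weights; here it is just a stub.
    print(socket_paths)

# Broadcast path: ranks=None (or []) updates every rank in the world.
ps.update("my-ckpt", req_func, timeout=timedelta(minutes=10))

# P2P path: every rank still calls update(), but only the listed ranks receive
# weights; MASTER_ADDR (or master_addr) must resolve to the rank-0 host.
ps.update("my-ckpt", req_func, ranks=[0, 1], master_addr="10.0.0.1")
```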
@@ -1238,7 +1308,9 @@ class ParameterServer:
         self._zmq_addr_counter += 1
         return socket, socket_paths
 
-    def _detect_bucket_size(
+    def _detect_bucket_size(
+        self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
+    ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
         tensor = torch.tensor(
@@ -1254,7 +1326,7 @@ class ParameterServer:
             dtype=torch.int64,
             device=self.device_manager.device_type,
         )
-        dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
         tensor = tensor.cpu()
         free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
         max_tensor_bytes = 0
@@ -1317,51 +1389,6 @@ class ParameterServer:
         self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
         self.device_manager.device_module.synchronize()
 
-    def init_process_group_for_ranks(
-        self,
-        ranks: list[int],
-        *,
-        master_port: int | None = None,
-        timeout: timedelta = timedelta(minutes=10),
-    ):
-        """
-        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
-
-        Args:
-            ranks: The ranks to initialize the process group. ranks should be a subset of all ranks.
-            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
-            timeout: The timeout of the process group.
-        """
-        assert not dist.is_initialized()
-        assert ranks, "ranks should be set"
-        if self._rank not in ranks:
-            return
-        assert self._all_hosts, "all_hosts should be set"
-        assert len(self._all_hosts) == self._world_size // self._gpu_count, (
-            f"world_size {self._world_size} should be equal to all_hosts {len(self._all_hosts)}"
-        )
-        rank = ranks.index(self._rank)
-        master_addr = self._all_hosts[ranks[0] // self._gpu_count]
-        master_port = _get_master_port(master_port)
-        logger.info(
-            f"[rank{self._rank}] start to init process group as virtual_rank {rank}, "
-            f"master_addr {master_addr}, master_port {master_port}, world_size {len(ranks)}, "
-        )
-        # only initialize process group and store for ranks, other nodes are not initialized
-        # and will not participate in this update. Since they have registered memory addresses
-        # to p2p_store at the beginning, update ranks can directly get the memory addresses
-        # from other nodes and put the weights into the buffer.
-        store = dist.TCPStore(
-            master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
-        )
-        dist.init_process_group(
-            backend=self.device_manager.backend,
-            world_size=len(ranks),
-            rank=rank,
-            timeout=timeout,
-            store=store,
-        )
-
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
         addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
         metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1401,10 +1428,12 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
+        ranks_group: dist.ProcessGroup | None,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
         assert dist.is_initialized(), "process group is not initialized"
+
         # if both ranks is None or [], it will use fully broadcast to update to all ranks
         if not ranks:
             logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
@@ -1422,9 +1451,9 @@ class ParameterServer:
         if not need_update:
             return
         # first execute a barrier to avoid subsequent device oom
-        dist.barrier()
+        dist.barrier(group=ranks_group)
 
-        bucket_size, disable_h2d_buffer = self._detect_bucket_size()
+        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
         buckets = _gen_h2d_buckets(
             self._current_global_parameter_metas,
             bucket_size,
@@ -1471,7 +1500,6 @@ class ParameterServer:
 
         gidx = 0
         ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
-        bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
         try:
             for i in range(max_len):
                 if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
@@ -1501,8 +1529,7 @@ class ParameterServer:
                     self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                 else:
                     buffer_b.data.copy_(h2d_buffer[: bucket.size])
-
-                dist.broadcast(buffer_b, src=brank)
+                dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                 resp = socket.recv()
                 if resp != b"":
                     msg = resp.decode("utf-8")
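The change above drops the rank-remapping helper in favor of passing the real receiver rank together with a subgroup. The sketch below restates that pattern in isolation; `broadcast_to_subset` is a hypothetical helper, not part of the package.

```python
# Subgroup-broadcast pattern adopted by the diff, shown standalone: the default
# (world) process group stays alive, a subgroup is created from global ranks,
# and dist.broadcast keeps interpreting src as a GLOBAL rank, so no remapping
# table (the removed _get_bcast_rank_map) is needed.
import torch
import torch.distributed as dist

def broadcast_to_subset(
    buffer: torch.Tensor, src_rank: int, ranks: list[int] | None
) -> None:
    # new_group() is collective: every rank in the default group must call it,
    # even ranks that are not listed in `ranks`.
    group = dist.new_group(ranks) if ranks else None
    if group is None or dist.get_rank() in ranks:
        dist.broadcast(buffer, src=src_rank, group=group)
    if group is not None:
        # Non-member ranks get a sentinel group; destroy handles it gracefully.
        dist.destroy_process_group(group)
```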
@@ -1510,7 +1537,7 @@ class ParameterServer:
                         f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                     )
                     ret_code.fill_(1)
-                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                 self.device_manager.device_module.synchronize()
                 if ret_code.item() != 0:
                     # quit early if any rank failed
@@ -1524,7 +1551,7 @@ class ParameterServer:
                 socket.recv()
         finally:
             req_thread.join()
-            dist.barrier()
+            dist.barrier(group=ranks_group)
             socket.close()
             if ranks and h2d_buffer is not None:
                 self._p2p_store.unregister_named_tensors([h2d_buffer_name])
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.3
+Version: 0.3.0rc1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/checkpoint_engine.egg-info/SOURCES.txt
@@ -23,6 +23,7 @@ figures/overlap-update-and-copy.png
 figures/pipeline.png
 patches/vllm_fp8.patch
 tests/test_assign_receiver_ranks.py
-tests/test_pin_memory.py
+tests/test_inplace_unpin.py
 tests/test_rdma_parser.py
+tests/test_reuse_pin_memory.py
 tests/test_update.py
checkpoint_engine-0.3.0rc1/tests/test_inplace_unpin.py
@@ -0,0 +1,81 @@
+import os
+import subprocess
+import time
+
+import pytest
+import torch.distributed as dist
+from test_update import device_manager, gen_test_tensors, get_world_size
+
+from checkpoint_engine.ps import ParameterServer
+
+
+dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
+
+
+def get_files() -> list[str]:
+    rank = int(os.getenv("RANK"))
+    named_tensors = dict(gen_test_tensors(rank))
+    import safetensors.torch
+
+    files = []
+    os.makedirs(dev_shm_dir, exist_ok=True)
+    tensors_in_dev_shm = named_tensors
+    time.sleep(1)
+    dev_shm_files = [
+        os.path.join(dev_shm_dir, f"rank{rank}_checkpoint.safetensors")
+        for _ in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
+    time.sleep(1)
+    files.append(dev_shm_files[rank])
+    return files
+
+
+def run_pin_and_unpin(num_runs: int):
+    ps = ParameterServer(auto_pg=True)
+    checkpoint_name = "test_with_files"
+    for _ in range(num_runs):
+        files = get_files()
+        ps.register_checkpoint(checkpoint_name, files=files)
+        ps.gather_metas(checkpoint_name)
+        dist.barrier()
+        ps.unregister_checkpoint(checkpoint_name)
+        if ps._rank == 0:
+            import shutil
+
+            shutil.rmtree(dev_shm_dir)
+
+    dist.destroy_process_group()
+
+
+@pytest.mark.gpu
+def test_unpin_files():
+    world_size = device_manager.device_module.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
+    master_addr = "localhost"
+    master_port = 25400
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        __file__,
+    ]
+
+    result = subprocess.run(  # noqa: S603
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
+if __name__ == "__main__":
+    run_pin_and_unpin(3)
{checkpoint_engine-0.2.3 → checkpoint_engine-0.3.0rc1}/tests/test_update.py
@@ -185,7 +185,6 @@ def run_with_files(
     os.makedirs(dev_shm_dir, exist_ok=True)
     os.makedirs(disk_dir, exist_ok=True)
     tensors_items = list(named_tensors.items())
-    tensors_in_dev_shm = named_tensors
     tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 2])
     tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3])
     tensors_in_memory = dict(tensors_items[1 * len(tensors_items) // 2 :])
@@ -218,7 +217,6 @@ def run_with_files(
     if rank == 0:
         import shutil
 
-        # this test should be run under use_inplace_pin_memory=False. Otherwise, the files in /dev/shm/ will be deleted.
         shutil.rmtree(dev_shm_dir)
         shutil.rmtree(disk_dir)
     assert proc.exitcode == 0
@@ -238,7 +236,13 @@ def run_with_files(
             ],
         ),
         ("test_with_remote_error", [[]]),
-
+        (
+            "test_no_error",
+            [
+                list(random.sample(range(get_world_size()), k=num_ranks))
+                for num_ranks in range(get_world_size() + 1)
+            ],
+        ),
     ],
 )
 def test_update(test_name: str, rank_list: list[list[int]] | None):