checkpoint-engine 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/ps.py +247 -38
- checkpoint_engine/worker.py +5 -2
- {checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/METADATA +9 -6
- checkpoint_engine-0.2.2.dist-info/RECORD +10 -0
- checkpoint_engine-0.2.1.dist-info/RECORD +0 -10
- {checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.2.1'
-__version_tuple__ = version_tuple = (0, 2, 1)
+__version__ = version = '0.2.2'
+__version_tuple__ = version_tuple = (0, 2, 2)

 __commit_id__ = commit_id = None
checkpoint_engine/ps.py
CHANGED
@@ -1,6 +1,7 @@
 import argparse
 import concurrent.futures
 import ctypes
+import json
 import os
 import pickle
 import random
@@ -18,7 +19,7 @@ import torch.distributed as dist
 import zmq
 from loguru import logger
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import safe_open
+from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor

 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
     name: str
     dtype: _TorchDtype
     shape: _TorchSize
+    aligned_size: int


 class BucketRange(NamedTuple):
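The new `aligned_size` field caches each parameter's byte footprint after rounding up to the engine's alignment boundary, so the bucketing code later in this diff can sum sizes without recomputing them. Below is a minimal sketch of what such a rounding helper computes; the real `_align_size` implementation and the value of `_ALIGN_SIZE` are not part of this diff, so both are assumptions here.

```python
# Hedged sketch only: the real _align_size/_ALIGN_SIZE in checkpoint_engine/ps.py
# are not shown in this diff; 256 bytes is an assumed alignment for illustration.
import torch

_ALIGN_SIZE = 256  # assumption

def align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    nbytes = shape.numel() * torch.tensor([], dtype=dtype).element_size()
    return (nbytes + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE

# A bf16 tensor of 1000 elements occupies 2000 bytes, padded up to 2048.
print(align_size(torch.bfloat16, torch.Size([1000])))
```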
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
 def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     ret = []
     for meta in metas:
-        size =
+        size = meta.aligned_size
         ret.append(
             {
                 "name": meta.name,
@@ -422,6 +424,7 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
                 name=parameter_name,
                 shape=meta["shape"],
                 dtype=meta["dtype"],
+                aligned_size=_align_size(meta["dtype"], meta["shape"]),
             )
             tp_meta = tp_metas[parameter_name]
             if tp_meta.concat_dim != -1:
@@ -431,7 +434,10 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
         shape = list(parameter_metas[name].shape)
         shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
         parameter_metas[name] = ParameterMeta(
-            name=name,
+            name=name,
+            shape=torch.Size(shape),
+            dtype=parameter_metas[name].dtype,
+            aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
         )
         weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
         # TODO: here concat is serial, which may be slow
@@ -449,17 +455,85 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
     return parameters


-def
-
+def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
+        """
+        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
+        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
+        The actual tensor data is in the remaining bytes and is naturally aligned.
+        We pin the remaining bytes as the buffer, making pinning faster.
+        """
+
+        def _pin(t: torch.Tensor):
+            """
+            Pin the memory of tensor in-place.
+            See: https://github.com/pytorch/pytorch/issues/32167
+            """
+            cudart = torch.cuda.cudart()
+            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+            assert r == 0, f"pin memory error, error code: {r}"
+
+        # TODO: should only support /dev/shm? but we found files in disk also work?
+        size = os.stat(file_path).st_size
+        flag_size = 8
+        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+        assert t.nbytes > flag_size, (
+            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
+        )
+        start_pos = (
+            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+            + flag_size
+        )
+        header_tensor = t[flag_size:start_pos]
+        header = json.loads(header_tensor.numpy().tobytes())
+        if "__metadata__" in header:
+            header.pop("__metadata__")
+
+        metas: list[ParameterMeta] = []
+        offset = 0
+        try:
+            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                start, end = meta["data_offsets"]
+                # safetensors format ensures offsets are aligned
+                assert offset == start, f"offset {offset} should be equal to start {start}"
+                metas.append(
+                    ParameterMeta(
+                        name=name,
+                        dtype=_getdtype(meta["dtype"]),
+                        shape=torch.Size(meta["shape"]),
+                        aligned_size=end - start,
+                    )
+                )
+                offset = end
+        except Exception as e:
+            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
+            raise
+
+        buffer = t[start_pos:]
+        assert offset == buffer.nbytes, (
+            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+        )
+        # Remove the file after successfully loading. This will avoid doubling the memory usage.
+        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
+        os.remove(file_path)
+        _pin(buffer)
+        logger.info(
+            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
+        )
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+
+    memory_buffers: list[MemoryBuffer] = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
+    return memory_buffers
+
+
+def _normal_pin_memory(
     files: list[str],
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
 ) -> list[MemoryBuffer]:
-    logger.info(
-        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-    )
-    if not files and not named_tensors:
-        return []
     parameters = _load_checkpoint(files)
     if named_tensors:
         parameters.update(named_tensors)
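For reference, the header that `_parse_and_pin_from_safetensors` walks is the documented safetensors layout: an 8-byte little-endian length `N`, followed by `N` bytes of JSON mapping each tensor name to its `dtype`, `shape`, and `data_offsets` within the data region that begins right after the header. The sketch below reproduces that parse with plain file I/O, independent of the `torch.from_file`/`cudaHostRegister` machinery above; the file path and helper name are placeholders, not from the package.

```python
# Hedged sketch: parse a safetensors header the way the new in-place pin path
# does, but with ordinary file I/O and no pinning.
import json
import struct

def read_safetensors_header(path: str) -> tuple[dict, int]:
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian length
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)                    # optional free-form metadata
    return header, 8 + header_len                       # header dict, offset where tensor bytes start

# Hypothetical usage: each entry's data_offsets are relative to the data region,
# so end - start is exactly the tensor's byte size (the aligned_size used above).
# header, data_start = read_safetensors_header("/dev/shm/model-00001-of-00002.safetensors")
# for name, meta in sorted(header.items(), key=lambda kv: kv[1]["data_offsets"]):
#     start, end = meta["data_offsets"]
#     print(name, meta["dtype"], tuple(meta["shape"]), end - start)
```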
@@ -469,13 +543,16 @@ def _register_checkpoint(
         size: int
         metas: list[ParameterMeta]

-    buckets: list[MemoryBucket] = [
+    buckets: list[MemoryBucket] = []
+    buckets.append(MemoryBucket(size=0, metas=[]))
     for name, tensor in sorted(parameters.items()):
         size = _align_size(tensor.dtype, tensor.shape)
         if buckets[-1].size + size > bucket_size:
             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
             buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(
+        buckets[-1].metas.append(
+            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
+        )
         buckets[-1].size += size

     memory_buffers = [
@@ -483,16 +560,34 @@ def _register_checkpoint(
         for bucket in buckets
     ]

-    def register_pin_memory(
-
-
+    def register_pin_memory(
+        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
+    ) -> tuple[int, torch.Tensor]:
+        if shared_pin_memory:
+            # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
+            # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
+            assert idx < len(shared_pin_memory), (
+                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
+            )
+            assert shared_pin_memory[idx].size == size, (
+                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
+            )
+            return idx, shared_pin_memory[idx].buffer
+        else:
+            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
+            return idx, buffer

     def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
         buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)

     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
         futures = [
-            executor.submit(
+            executor.submit(
+                register_pin_memory,
+                idx,
+                bucket.size,
+                shared_pin_memory,
+            )
             for idx, bucket in enumerate(buckets)
         ]
         new_futures = []
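`register_pin_memory` now either hands back a slot from `shared_pin_memory` or allocates a fresh page-locked buffer per bucket. Pinned host memory is what makes the `non_blocking=True` host-to-device copies elsewhere in this diff genuinely asynchronous; the sketch below illustrates the allocate-versus-reuse distinction under assumed sizes and assumes a CUDA device is present.

```python
# Hedged sketch only: bucket size and names are illustrative, not from the diff.
import torch

bucket_size = 8 << 20  # assume an 8 MiB staging bucket

# First registration: allocate page-locked (pinned) host memory for the bucket.
pinned = torch.empty(bucket_size, dtype=torch.uint8, pin_memory=True)

# A later registration against a shared pool would reuse the same buffer instead
# of allocating again, provided the bucket sizes match (the asserts in the diff).
staging = pinned

if torch.cuda.is_available():
    dev = torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
    dev.copy_(staging, non_blocking=True)  # H2D copy overlaps only from pinned memory
    torch.cuda.synchronize()
```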
@@ -518,6 +613,39 @@ def _register_checkpoint(
             offset += size
     for future in concurrent.futures.as_completed(new_futures):
         future.result()
+    return memory_buffers
+
+
+def _register_checkpoint(
+    *,
+    files: list[str],
+    named_tensors: dict[str, torch.Tensor],
+    rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
+) -> list[MemoryBuffer]:
+    logger.info(
+        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+    )
+    if not files and not named_tensors:
+        return []
+    memory_buffers: list[MemoryBuffer] = []
+    files_to_inplace_pin = [
+        file
+        for file in files
+        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+    ]
+    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    if files_to_normal_pin or named_tensors:
+        memory_buffers.extend(
+            _normal_pin_memory(
+                files=files_to_normal_pin,
+                named_tensors=named_tensors,
+                rank=rank,
+                shared_pin_memory=shared_pin_memory,
+            )
+        )
+    if files_to_inplace_pin:
+        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
     return memory_buffers


@@ -566,7 +694,7 @@ def _gen_h2d_buckets(
     for idx, metas in enumerate(items.memory_buffer_metas_list):
         start_offset, offset = 0, 0
         for meta in metas.metas:
-            s =
+            s = meta.aligned_size
            if buckets[-1][1].size + s > bucket_size:
                if offset - start_offset > 0:
                    buckets[-1][1].ranges.append(
@@ -747,6 +875,8 @@ class P2PStore:


 class ParameterServer:
+    shared_memory_pool_name = "__shared_memory_pool__"
+
     def __init__(
         self,
         *,
@@ -790,7 +920,10 @@ class ParameterServer:
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0

+        # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
+        self._current_shared_memory_pool_user: str = ""
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
+        self._memory_pool[self.shared_memory_pool_name] = []
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
         # NPU transfer engine initialization requires prior set_device.
@@ -805,6 +938,17 @@ class ParameterServer:
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device

+    def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            assert self._memory_pool[self.shared_memory_pool_name], (
+                f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
+            )
+            return self._memory_pool[self.shared_memory_pool_name]
+        elif checkpoint_name in self._memory_pool:
+            return self._memory_pool[checkpoint_name]
+        else:
+            raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
+
     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
             logger.info(msg)
@@ -828,46 +972,97 @@ class ParameterServer:
         *,
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
+        use_shared_memory_pool: bool = False,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
+        Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Please make sure to copy the files to disks if you need to keep them.

         Args:
             checkpoint_name: The name of the checkpoint.
             files: The safetensors files to register.
             named_tensors: The named tensors to register.
+            use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
+                Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
+                cannot accommodate checkpoints with different memory requirements.
+                To free the actual memory of the shared pool or to modify its shape,
+                please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
         """
         try:
-
-
-
-
-
-
-
-
+            if use_shared_memory_pool:
+                logger.info(
+                    f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
+                )
+                assert self._current_shared_memory_pool_user == "", (
+                    f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
+                    f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
+                    f"This registration may cause unexpected conflicts."
+                )
+                # Since we set the uninitialized shared memory pool to empty list,
+                # we can check whether this is the first time to use shared memory pool
+                _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
+                self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                )
+                self._current_shared_memory_pool_user = checkpoint_name
+                if self._p2p_store is not None and _is_first_time:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
+            else:
+                assert checkpoint_name not in self._memory_pool, (
+                    f"checkpoint {checkpoint_name} already registered"
+                )
+                self._memory_pool[checkpoint_name] = _register_checkpoint(
+                    files=files or [], named_tensors=named_tensors or {}, rank=self._rank
+                )
+                if self._p2p_store is not None:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
         except Exception:
             logger.exception(
                 f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
             )
-            if self._p2p_store is not None:
+            if self._p2p_store is not None and not use_shared_memory_pool:
                 self._unregister_parameters_from_p2p_store(checkpoint_name)
             self.unregister_checkpoint(checkpoint_name)
             raise

-    def unregister_checkpoint(self, checkpoint_name: str):
+    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
         """
         Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
         from p2p store if p2p store is initialized.
+        Args:
+            checkpoint_name: The name of the checkpoint.
+            force: This flag is designed for shared memory pool user. If True, the memory for shared memory pool itself will be freed.
+                If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
         """
-        if
+        if (
+            checkpoint_name not in self._memory_pool
+            and checkpoint_name != self._current_shared_memory_pool_user
+        ):
+            logger.warning(
+                f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
+            )
+            return
+
+        if checkpoint_name == self._current_shared_memory_pool_user and not force:
+            self._current_shared_memory_pool_user = ""
             return
+
         if self._p2p_store is not None:
             num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
             logger.info(
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )
-
+
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            del self._memory_pool[self.shared_memory_pool_name]
+            self._memory_pool[self.shared_memory_pool_name] = []
+        else:
+            del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
         torch._C._host_emptyCache()
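Together, `use_shared_memory_pool` and the new `force` flag on `unregister_checkpoint` define a reuse-friendly lifecycle: register into the shared pool, unregister by name to keep the pinned pool alive for the next same-shaped checkpoint, and unregister with `force=True` to actually release the memory. The hedged usage sketch below is based only on the signatures and docstrings in this diff; the registering method name, the bare constructor call, and the file paths are assumptions and placeholders rather than lines from the package.

```python
# Hedged sketch of the intended lifecycle; ParameterServer construction details are
# omitted because they are not part of this diff (real deployments pass more keyword
# arguments), and the registering method name is inferred from the diff's log messages.
from checkpoint_engine.ps import ParameterServer

ps = ParameterServer()  # assumed minimal construction

# NOTE (from the new docstring): /dev/shm/*.safetensors files are pinned in-place
# and REMOVED after registration, so keep a copy elsewhere if you still need them.
ps.register_checkpoint(
    "step-100",
    files=["/dev/shm/step-100/model-00001-of-00002.safetensors"],  # placeholder path
    use_shared_memory_pool=True,
)
# ... broadcast / update weights ...

# Release the name but keep the pinned pool for the next checkpoint of the same shape.
ps.unregister_checkpoint("step-100")

ps.register_checkpoint(
    "step-200",
    files=["/dev/shm/step-200/model-00001-of-00002.safetensors"],  # placeholder path
    use_shared_memory_pool=True,
)

# Free the shared pool itself (needed before registering a differently shaped checkpoint).
ps.unregister_checkpoint("step-200", force=True)
```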
@@ -882,6 +1077,10 @@ class ParameterServer:
             self.init_process_group()
         assert dist.is_initialized(), "process group is not initialized"
         metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
+        try:
+            memory_pool = self._get_memory_pool(checkpoint_name)
+        except RuntimeError:
+            memory_pool = []
         metas = DataToGather(
             memory_buffer_metas_list=[
                 MemoryBufferMetas(
@@ -889,7 +1088,7 @@ class ParameterServer:
                     ptr=x.buffer.data_ptr(),
                     size=x.size,
                 )
-                for x in
+                for x in memory_pool
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
             host_ip=get_ip(),
@@ -1050,7 +1249,7 @@ class ParameterServer:
         for items in self._current_global_parameter_metas.values():
             for metas_list in items.memory_buffer_metas_list:
                 for meta in metas_list.metas:
-                    max_tensor_bytes = max(max_tensor_bytes,
+                    max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
             self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
@@ -1095,7 +1294,7 @@ class ParameterServer:
                 remote_ptrs.append(ptrs[b.idx][0] + b.offset)
                 lens.append(b.size)
             else:
-                pool = self.
+                pool = self._get_memory_pool(checkpoint_name)[b.idx]
                 buffer[offset : offset + b.size].data.copy_(
                     pool.buffer[b.offset : b.offset + b.size],
                     non_blocking=True,
@@ -1158,22 +1357,32 @@ class ParameterServer:

     def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self.
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
+        register_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         for idx, memory_buffer in enumerate(pool):
-            named_tensors[f"memory_pool_{
+            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
             tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
         self._p2p_store.register_named_tensors(named_tensors)

     def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self.
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
+        unregister_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         return self._p2p_store.unregister_named_tensors(
-            [f"memory_pool_{
+            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
         )

     def _update_per_bucket(
@@ -1284,9 +1493,9 @@ class ParameterServer:
             dist.broadcast(buffer_b, src=brank)
             resp = socket.recv()
             if resp != b"":
-
+                msg = resp.decode("utf-8")
                 logger.error(
-                    f"[rank{self._rank}] receive error response
+                    f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                 )
                 ret_code.fill_(1)
         dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
checkpoint_engine/worker.py
CHANGED
@@ -1,4 +1,5 @@
 import gc
+import traceback
 from collections.abc import Callable
 from typing import TypedDict

@@ -63,7 +64,8 @@ def update_weights_from_ipc(
             assert buffer.dtype == torch.uint8
         socket.send(b"")
     except Exception as e:
-
+        msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        socket.send_string(msg)
         socket.recv()  # wait for ack
         raise
     try:
@@ -83,7 +85,8 @@ def update_weights_from_ipc(
            except Exception as e:  # noqa: BLE001
                # Send exception back to Parameter Server.
                # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
-
+                msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+                socket.send_string(msg)
        elif isinstance(
            payload, Exception
        ):  # error occurred, got force quit signal from Parameter Server
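On the worker side, the previously bare error reply is replaced with the full formatted traceback, which is what the parameter server's new `resp.decode("utf-8")` hunk above logs. The small self-contained sketch below shows that formatting step; the failing function and the `print` are illustrative stand-ins for the real ZeroMQ `send_string` call.

```python
# Hedged sketch: build the multi-line traceback string that worker.py now sends
# back over ZeroMQ with socket.send_string(msg); here we simply print it.
import traceback

def failing_update() -> None:
    raise RuntimeError("dtype mismatch while loading bucket 3")  # illustrative error

try:
    failing_update()
except Exception as e:  # noqa: BLE001
    msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
    print(msg)  # in worker.py: socket.send_string(msg); in ps.py: resp.decode("utf-8")
```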
{checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.1
+Version: 0.2.2
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
 pip install 'checkpoint-engine[p2p]'
 ```

-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
-
 ## Getting Started

-Prepare an H800 or H20 machine with 8 GPUs with
+Prepare an H800 or H20 machine with 8 GPUs with vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.

 ```Bash
-
+mkdir -p /opt/vLLM && cd /opt/vLLM
 uv venv --python 3.12 --seed
 source .venv/bin/activate
-
+uv pip install vllm==0.10.2
 ```

 Install checkpoint-engine
@@ -180,6 +178,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
 pytest tests/ -m "not gpu"
 ```

+### Environment Variables
+- `PS_MAX_BUCKET_SIZE_GB`: An integer is used to set the maximum bucket size for checkpoint-engine. If not set, 8GB is used as default.
+- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA devices' names for P2P transfer. If not set, checkpoint-engine will fall back to use `NCCL_IB_HCA` to detect RDMA devices.
+- `NCCL_IB_HCA`: Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If also not set, all RDMA devices will be used and divided evenly among the ranks.
+
 ## SGLang Integration

 Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
checkpoint_engine-0.2.2.dist-info/RECORD
@@ -0,0 +1,10 @@
+checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
+checkpoint_engine/_version.py,sha256=o3ZTescp-19Z9cvBGq9dQnbppljgzdUYUf98Nov0spY,704
+checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
+checkpoint_engine/ps.py,sha256=cu8Qp5daY1iL30iN69jXP4grlHoAKILblngcKQPA5Bg,67692
+checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
+checkpoint_engine-0.2.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
+checkpoint_engine-0.2.2.dist-info/METADATA,sha256=_bBxy27d0GMc7KzuIBAdw-Lno3-UrVLUhH63YDbY1YA,11559
+checkpoint_engine-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+checkpoint_engine-0.2.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
+checkpoint_engine-0.2.2.dist-info/RECORD,,
checkpoint_engine-0.2.1.dist-info/RECORD
@@ -1,10 +0,0 @@
-checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
-checkpoint_engine/_version.py,sha256=vYqoJTG51NOUmYyL0xt8asRK8vUT4lGAdal_EZ59mvw,704
-checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
-checkpoint_engine/ps.py,sha256=Mgfr_MXYYZ_6JKqD5kIGDBWCYNWAtIPfAppP_cFu604,57781
-checkpoint_engine/worker.py,sha256=5TzDgTPew6Ts9sMOzecalLCR1p_ZwfeKPdzzr68kAQg,6564
-checkpoint_engine-0.2.1.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
-checkpoint_engine-0.2.1.dist-info/METADATA,sha256=7E7NhehWHpS6QVkLup-oCm350wdbiZX8CY3jmmJP0bU,11315
-checkpoint_engine-0.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-checkpoint_engine-0.2.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
-checkpoint_engine-0.2.1.dist-info/RECORD,,
{checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/WHEEL
File without changes

{checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/licenses/LICENCE
File without changes

{checkpoint_engine-0.2.1.dist-info → checkpoint_engine-0.2.2.dist-info}/top_level.txt
File without changes