checkpoint-engine 0.2.1__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/PKG-INFO +9 -6
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/README.md +8 -5
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/_version.py +3 -3
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/ps.py +259 -38
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/worker.py +5 -2
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/PKG-INFO +9 -6
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/SOURCES.txt +1 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/examples/update.py +5 -2
- checkpoint_engine-0.2.3/tests/test_pin_memory.py +77 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/tests/test_update.py +96 -2
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.github/workflows/cpu-tests.yml +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.github/workflows/pre-commit.yaml +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.github/workflows/python-publish.yml +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.gitignore +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.pre-commit-config.yaml +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/LICENCE +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/__init__.py +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/device_utils.py +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/requires.txt +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/top_level.txt +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/docs/npu_start.md +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/figures/checkpoint-engine.png +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/figures/overlap-update-and-copy.png +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/figures/pipeline.png +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/patches/vllm_fp8.patch +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/pyproject.toml +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/setup.cfg +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/tests/test_assign_receiver_ranks.py +0 -0
- {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/tests/test_rdma_parser.py +0 -0
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.1
+Version: 0.2.3
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
 pip install 'checkpoint-engine[p2p]'
 ```
 
-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
-
 ## Getting Started
 
-Prepare an H800 or H20 machine with 8 GPUs with
+Prepare an H800 or H20 machine with 8 GPUs with vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
 
 ```Bash
-
+mkdir -p /opt/vLLM && cd /opt/vLLM
 uv venv --python 3.12 --seed
 source .venv/bin/activate
-
+uv pip install vllm==0.10.2
 ```
 
 Install checkpoint-engine
@@ -180,6 +178,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
 pytest tests/ -m "not gpu"
 ```
 
+### Environment Variables
+- `PS_MAX_BUCKET_SIZE_GB`: An integer is used to set the maximum bucket size for checkpoint-engine. If not set, 8GB is used as default.
+- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA devices' names for P2P transfer. If not set, checkpoint-engine will fall back to use `NCCL_IB_HCA` to detect RDMA devices.
+- `NCCL_IB_HCA`: Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If also not set, all RDMA devices will be used and divided evenly among the ranks.
+
 ## SGLang Integration
 
 Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
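For readers skimming the new environment-variable section: a minimal sketch of how a consumer might resolve these settings. The helper names are hypothetical (not part of checkpoint-engine's API); the 8 GB default and the comma-separated device list mirror the documented behavior.

```python
import os

# Hypothetical helper: resolve PS_MAX_BUCKET_SIZE_GB, defaulting to 8 GiB.
def max_bucket_size_bytes() -> int:
    gb = int(os.environ.get("PS_MAX_BUCKET_SIZE_GB", "8"))
    return gb * 1024**3

# Hypothetical helper: PS_P2P_STORE_RDMA_DEVICES is documented as comma-separated.
def p2p_rdma_devices() -> list[str]:
    raw = os.environ.get("PS_P2P_STORE_RDMA_DEVICES", "")
    return [d for d in raw.split(",") if d]
```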
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/README.md

@@ -75,17 +75,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
 pip install 'checkpoint-engine[p2p]'
 ```
 
-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
-
 ## Getting Started
 
-Prepare an H800 or H20 machine with 8 GPUs with
+Prepare an H800 or H20 machine with 8 GPUs with vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
 
 ```Bash
-
+mkdir -p /opt/vLLM && cd /opt/vLLM
 uv venv --python 3.12 --seed
 source .venv/bin/activate
-
+uv pip install vllm==0.10.2
 ```
 
 Install checkpoint-engine
@@ -156,6 +154,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
 pytest tests/ -m "not gpu"
 ```
 
+### Environment Variables
+- `PS_MAX_BUCKET_SIZE_GB`: An integer is used to set the maximum bucket size for checkpoint-engine. If not set, 8GB is used as default.
+- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA devices' names for P2P transfer. If not set, checkpoint-engine will fall back to use `NCCL_IB_HCA` to detect RDMA devices.
+- `NCCL_IB_HCA`: Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If also not set, all RDMA devices will be used and divided evenly among the ranks.
+
 ## SGLang Integration
 
 Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/_version.py

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.1'
-__version_tuple__ = version_tuple = (0, 2, 1)
+__version__ = version = '0.2.3'
+__version_tuple__ = version_tuple = (0, 2, 3)
 
-__commit_id__ = commit_id = '
+__commit_id__ = commit_id = 'g0a6244951'
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/ps.py

@@ -1,6 +1,7 @@
 import argparse
 import concurrent.futures
 import ctypes
+import json
 import os
 import pickle
 import random
@@ -18,7 +19,7 @@ import torch.distributed as dist
 import zmq
 from loguru import logger
 from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import safe_open
+from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
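`_getdtype` is a private helper inside `safetensors.torch` that maps the dtype strings found in a safetensors header to `torch.dtype` objects; being private, it is subject to change between safetensors releases. A quick illustration:

```python
import torch
from safetensors.torch import _getdtype  # private helper, as imported in the diff above

# Header dtype strings map directly to torch dtypes:
assert _getdtype("F32") is torch.float32
assert _getdtype("BF16") is torch.bfloat16
```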
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
     name: str
     dtype: _TorchDtype
     shape: _TorchSize
+    aligned_size: int
 
 
 class BucketRange(NamedTuple):
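The new `aligned_size` field caches each parameter's padded byte size so later bucketing code can read `meta.aligned_size` instead of recomputing it at every use site. A sketch of the underlying rounding, assuming a power-of-two alignment constant in the spirit of the module's `_ALIGN_SIZE` (the real value lives in ps.py):

```python
import torch

_ALIGN = 256  # assumed value; the real constant is ps.py's _ALIGN_SIZE

def align_size_sketch(dtype: torch.dtype, shape: torch.Size) -> int:
    # Round the raw tensor byte size up to the next multiple of the alignment.
    nbytes = dtype.itemsize * shape.numel()
    return (nbytes + _ALIGN - 1) // _ALIGN * _ALIGN

print(align_size_sketch(torch.bfloat16, torch.Size([1024, 1024])))  # 2097152
```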
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
 def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     ret = []
     for meta in metas:
-        size = _align_size(meta.dtype, meta.shape)
+        size = meta.aligned_size
         ret.append(
             {
                 "name": meta.name,
@@ -422,6 +424,7 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
                 name=parameter_name,
                 shape=meta["shape"],
                 dtype=meta["dtype"],
+                aligned_size=_align_size(meta["dtype"], meta["shape"]),
             )
             tp_meta = tp_metas[parameter_name]
             if tp_meta.concat_dim != -1:
@@ -431,7 +434,10 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
         shape = list(parameter_metas[name].shape)
         shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
         parameter_metas[name] = ParameterMeta(
-            name=name,
+            name=name,
+            shape=torch.Size(shape),
+            dtype=parameter_metas[name].dtype,
+            aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
         )
         weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
         # TODO: here concat is serial, which may be slow
@@ -449,17 +455,85 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
     return parameters
 
 
-def _register_checkpoint(
-    *,
+def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
+    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
+        """
+        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
+        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
+        The actual tensor data is in the remaining bytes and is naturally aligned.
+        We pin the remaining bytes as the buffer, making pinning faster.
+        """
+
+        def _pin(t: torch.Tensor):
+            """
+            Pin the memory of tensor in-place.
+            See: https://github.com/pytorch/pytorch/issues/32167
+            """
+            cudart = torch.cuda.cudart()
+            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
+            assert r == 0, f"pin memory error, error code: {r}"
+
+        # TODO: should only support /dev/shm? but we found files in disk also work?
+        size = os.stat(file_path).st_size
+        flag_size = 8
+        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
+        assert t.nbytes > flag_size, (
+            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
+        )
+        start_pos = (
+            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
+            + flag_size
+        )
+        header_tensor = t[flag_size:start_pos]
+        header = json.loads(header_tensor.numpy().tobytes())
+        if "__metadata__" in header:
+            header.pop("__metadata__")
+
+        metas: list[ParameterMeta] = []
+        offset = 0
+        try:
+            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
+                start, end = meta["data_offsets"]
+                # safetensors format ensures offsets are aligned
+                assert offset == start, f"offset {offset} should be equal to start {start}"
+                metas.append(
+                    ParameterMeta(
+                        name=name,
+                        dtype=_getdtype(meta["dtype"]),
+                        shape=torch.Size(meta["shape"]),
+                        aligned_size=end - start,
+                    )
+                )
+                offset = end
+        except Exception as e:
+            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
+            raise
+
+        buffer = t[start_pos:]
+        assert offset == buffer.nbytes, (
+            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
+        )
+        # Remove the file after successfully loading. This will avoid doubling the memory usage.
+        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
+        os.remove(file_path)
+        _pin(buffer)
+        logger.info(
+            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
+        )
+        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
+
+    memory_buffers: list[MemoryBuffer] = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
+    return memory_buffers
+
+
+def _normal_pin_memory(
     files: list[str],
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
 ) -> list[MemoryBuffer]:
-    logger.info(
-        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-    )
-    if not files and not named_tensors:
-        return []
     parameters = _load_checkpoint(files)
     if named_tensors:
         parameters.update(named_tensors)
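For context on what `_parse_and_pin_from_safetensors` is decoding: a safetensors file begins with an 8-byte little-endian unsigned length, followed by a JSON header that maps tensor names to `dtype`, `shape`, and `data_offsets` (byte ranges relative to the data buffer that follows the header). A standalone sketch of just the header parse:

```python
import json
import struct

def read_safetensors_header(path: str) -> tuple[dict, int]:
    """Return (header, data_start): the parsed JSON header and the byte
    offset at which the tensor data buffer begins."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return header, 8 + header_len
```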
@@ -469,13 +543,16 @@ def _register_checkpoint(
     size: int
     metas: list[ParameterMeta]
 
-    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
+    buckets: list[MemoryBucket] = []
+    buckets.append(MemoryBucket(size=0, metas=[]))
     for name, tensor in sorted(parameters.items()):
         size = _align_size(tensor.dtype, tensor.shape)
         if buckets[-1].size + size > bucket_size:
             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
             buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
+        buckets[-1].metas.append(
+            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
+        )
         buckets[-1].size += size
 
     memory_buffers = [
@@ -483,16 +560,34 @@ def _register_checkpoint(
         for bucket in buckets
     ]
 
-    def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
-        buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-        return idx, buffer
+    def register_pin_memory(
+        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
+    ) -> tuple[int, torch.Tensor]:
+        if shared_pin_memory:
+            # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
+            # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
+            assert idx < len(shared_pin_memory), (
+                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
+            )
+            assert shared_pin_memory[idx].size == size, (
+                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
+            )
+            return idx, shared_pin_memory[idx].buffer
+        else:
+            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
+            return idx, buffer
 
     def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
         buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
         futures = [
-            executor.submit(register_pin_memory, idx, bucket.size)
+            executor.submit(
+                register_pin_memory,
+                idx,
+                bucket.size,
+                shared_pin_memory,
+            )
             for idx, bucket in enumerate(buckets)
         ]
         new_futures = []
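Why all the pinning machinery matters: only page-locked (pinned) host memory allows truly asynchronous host-to-device copies, which is what lets the weight upload overlap with other work. A minimal sketch (requires a CUDA device):

```python
import torch

# A pinned staging buffer, as allocated by register_pin_memory above.
src = torch.empty(1 << 20, dtype=torch.uint8, pin_memory=True)
dst = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")

# With a pinned source this copy is asynchronous w.r.t. the current stream;
# with ordinary pageable memory it would silently synchronize.
dst.copy_(src, non_blocking=True)
torch.cuda.synchronize()
```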
@@ -518,6 +613,45 @@ def _register_checkpoint(
         offset += size
     for future in concurrent.futures.as_completed(new_futures):
         future.result()
+    return memory_buffers
+
+
+def _register_checkpoint(
+    *,
+    files: list[str],
+    named_tensors: dict[str, torch.Tensor],
+    rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
+    inplace_pin: bool = False,
+) -> list[MemoryBuffer]:
+    logger.info(
+        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
+    )
+    if not files and not named_tensors:
+        return []
+    memory_buffers: list[MemoryBuffer] = []
+    if inplace_pin:
+        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
+        files_to_inplace_pin = [
+            file
+            for file in files
+            if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
+        ]
+        files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
+    else:
+        files_to_normal_pin = files
+        files_to_inplace_pin = []
+    if files_to_normal_pin or named_tensors:
+        memory_buffers.extend(
+            _normal_pin_memory(
+                files=files_to_normal_pin,
+                named_tensors=named_tensors,
+                rank=rank,
+                shared_pin_memory=shared_pin_memory,
+            )
+        )
+    if files_to_inplace_pin:
+        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
     return memory_buffers
 
 
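A hedged sketch of how the new dispatcher splits its input; the call shape is taken from the signature above, the file paths are illustrative only. `/dev/shm` safetensors files take the zero-copy in-place path (and are deleted afterwards), everything else is staged through freshly pinned buffers.

```python
from checkpoint_engine.ps import _register_checkpoint  # private helper from the diff

buffers = _register_checkpoint(
    files=[
        "/dev/shm/model-00001.safetensors",    # in-place pinned, then REMOVED
        "/data/ckpt/model-00002.safetensors",  # staged through a new pinned buffer
    ],
    named_tensors={},
    rank=0,
    inplace_pin=True,
)
```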
@@ -566,7 +700,7 @@ def _gen_h2d_buckets(
     for idx, metas in enumerate(items.memory_buffer_metas_list):
         start_offset, offset = 0, 0
         for meta in metas.metas:
-            s = _align_size(meta.dtype, meta.shape)
+            s = meta.aligned_size
             if buckets[-1][1].size + s > bucket_size:
                 if offset - start_offset > 0:
                     buckets[-1][1].ranges.append(
@@ -747,6 +881,8 @@ class P2PStore:
 
 
 class ParameterServer:
+    shared_memory_pool_name = "__shared_memory_pool__"
+
     def __init__(
         self,
         *,
@@ -790,7 +926,10 @@ class ParameterServer:
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
 
+        # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
+        self._current_shared_memory_pool_user: str = ""
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
+        self._memory_pool[self.shared_memory_pool_name] = []
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
         # NPU transfer engine initialization requires prior set_device.
@@ -805,6 +944,17 @@ class ParameterServer:
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+    def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            assert self._memory_pool[self.shared_memory_pool_name], (
+                f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
+            )
+            return self._memory_pool[self.shared_memory_pool_name]
+        elif checkpoint_name in self._memory_pool:
+            return self._memory_pool[checkpoint_name]
+        else:
+            raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
+
     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
             logger.info(msg)
@@ -828,46 +978,103 @@ class ParameterServer:
         *,
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
+        use_shared_memory_pool: bool = False,
+        use_inplace_pin_memory: bool = False,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
+        Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Please make sure to copy the files to disks if you need to keep them.
 
         Args:
             checkpoint_name: The name of the checkpoint.
             files: The safetensors files to register.
             named_tensors: The named tensors to register.
+            use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
+                Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
+                cannot accommodate checkpoints with different memory requirements.
+                To free the actual memory of the shared pool or to modify its shape,
+                please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
+            use_inplace_pin_memory: If True, allows inplace pin memory for /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
+                Currently, this feature is experimental and may crash.
         """
         try:
-            assert checkpoint_name not in self._memory_pool
-            self._memory_pool[checkpoint_name] = _register_checkpoint(
-                files=files or [],
-                named_tensors=named_tensors or {},
-                rank=self._rank,
-            )
-            if self._p2p_store is not None:
-                self._register_parameters_to_p2p_store(checkpoint_name)
+            if use_shared_memory_pool:
+                logger.info(
+                    f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
+                )
+                assert self._current_shared_memory_pool_user == "", (
+                    f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
+                    f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
+                    f"This registration may cause unexpected conflicts."
+                )
+                # Since we set the uninitialized shared memory pool to empty list,
+                # we can check whether this is the first time to use shared memory pool
+                _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
+                self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                )
+                self._current_shared_memory_pool_user = checkpoint_name
+                if self._p2p_store is not None and _is_first_time:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
+            else:
+                assert checkpoint_name not in self._memory_pool, (
+                    f"checkpoint {checkpoint_name} already registered"
+                )
+                self._memory_pool[checkpoint_name] = _register_checkpoint(
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    inplace_pin=use_inplace_pin_memory,
+                )
+                if self._p2p_store is not None:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
         except Exception:
             logger.exception(
                 f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
            )
-            if self._p2p_store is not None:
+            if self._p2p_store is not None and not use_shared_memory_pool:
                 self._unregister_parameters_from_p2p_store(checkpoint_name)
             self.unregister_checkpoint(checkpoint_name)
             raise
 
-    def unregister_checkpoint(self, checkpoint_name: str):
+    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
         """
         Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
         from p2p store if p2p store is initialized.
+        Args:
+            checkpoint_name: The name of the checkpoint.
+            force: This flag is designed for shared memory pool user. If True, the memory for shared memory pool itself will be freed.
+                If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
         """
-        if checkpoint_name not in self._memory_pool:
+        if (
+            checkpoint_name not in self._memory_pool
+            and checkpoint_name != self._current_shared_memory_pool_user
+        ):
+            logger.warning(
+                f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
+            )
+            return
+
+        if checkpoint_name == self._current_shared_memory_pool_user and not force:
+            self._current_shared_memory_pool_user = ""
             return
+
         if self._p2p_store is not None:
             num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
             logger.info(
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )
-        del self._memory_pool[checkpoint_name]
+
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
+            del self._memory_pool[self.shared_memory_pool_name]
+            self._memory_pool[self.shared_memory_pool_name] = []
+        else:
+            del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
         torch._C._host_emptyCache()
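Putting the two new flags together, a usage sketch of the shared-pool lifecycle described in the docstring above; `ps` is assumed to be an initialized `ParameterServer` and `files` a list of safetensors paths, and the checkpoint names are illustrative:

```python
# First registration sizes and allocates the shared pool.
ps.register_checkpoint("step-100", files=files, use_shared_memory_pool=True)
# Default unregister releases the name but keeps the pinned pool for reuse.
ps.unregister_checkpoint("step-100")
# A later same-shape checkpoint reuses the existing pinned buffers.
ps.register_checkpoint("step-200", files=files, use_shared_memory_pool=True)
# force=True actually frees the pool (required before changing checkpoint shape).
ps.unregister_checkpoint("step-200", force=True)
```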
@@ -882,6 +1089,10 @@ class ParameterServer:
             self.init_process_group()
         assert dist.is_initialized(), "process group is not initialized"
         metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
+        try:
+            memory_pool = self._get_memory_pool(checkpoint_name)
+        except RuntimeError:
+            memory_pool = []
         metas = DataToGather(
             memory_buffer_metas_list=[
                 MemoryBufferMetas(
@@ -889,7 +1100,7 @@ class ParameterServer:
                     ptr=x.buffer.data_ptr(),
                     size=x.size,
                 )
-                for x in self._memory_pool[checkpoint_name]
+                for x in memory_pool
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
             host_ip=get_ip(),
@@ -1050,7 +1261,7 @@ class ParameterServer:
         for items in self._current_global_parameter_metas.values():
             for metas_list in items.memory_buffer_metas_list:
                 for meta in metas_list.metas:
-                    max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
+                    max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
         free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
         if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
             self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
@@ -1095,7 +1306,7 @@ class ParameterServer:
                 remote_ptrs.append(ptrs[b.idx][0] + b.offset)
                 lens.append(b.size)
             else:
-                pool = self._memory_pool[checkpoint_name][b.idx]
+                pool = self._get_memory_pool(checkpoint_name)[b.idx]
                 buffer[offset : offset + b.size].data.copy_(
                     pool.buffer[b.offset : b.offset + b.size],
                     non_blocking=True,
@@ -1158,22 +1369,32 @@ class ParameterServer:
 
     def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
+        register_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         for idx, memory_buffer in enumerate(pool):
-            named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
+            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
             tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
         self._p2p_store.register_named_tensors(named_tensors)
 
     def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
+        unregister_name = (
+            checkpoint_name
+            if checkpoint_name != self._current_shared_memory_pool_user
+            else self.shared_memory_pool_name
+        )
         return self._p2p_store.unregister_named_tensors(
-            [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
+            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
         )
 
     def _update_per_bucket(
@@ -1284,9 +1505,9 @@ class ParameterServer:
                 dist.broadcast(buffer_b, src=brank)
             resp = socket.recv()
             if resp != b"":
-
+                msg = resp.decode("utf-8")
                 logger.error(
-                    f"[rank{self._rank}] receive error response
+                    f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                 )
                 ret_code.fill_(1)
         dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/worker.py

@@ -1,4 +1,5 @@
 import gc
+import traceback
 from collections.abc import Callable
 from typing import TypedDict
 
@@ -63,7 +64,8 @@ def update_weights_from_ipc(
             assert buffer.dtype == torch.uint8
         socket.send(b"")
     except Exception as e:
-        socket.send_string(str(e))
+        msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        socket.send_string(msg)
         socket.recv()  # wait for ack
         raise
     try:
@@ -83,7 +85,8 @@ def update_weights_from_ipc(
     except Exception as e:  # noqa: BLE001
         # Send exception back to Parameter Server.
         # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
-        socket.send_string(str(e))
+        msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+        socket.send_string(msg)
     elif isinstance(
         payload, Exception
     ):  # error occurred, got force quit signal from Parameter Server
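The net effect of the worker change: instead of a short error string, the parameter server now receives the full formatted traceback over the ZMQ socket, matching the richer error log added in ps.py above. A condensed sketch of the pattern:

```python
import traceback

import zmq

def send_failure(socket: zmq.Socket, exc: Exception) -> None:
    # Serialize the complete traceback so the peer can log the real failure site.
    msg = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    socket.send_string(msg)
```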
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.1
+Version: 0.2.3
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
 pip install 'checkpoint-engine[p2p]'
 ```
 
-If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
-
 ## Getting Started
 
-Prepare an H800 or H20 machine with 8 GPUs with
+Prepare an H800 or H20 machine with 8 GPUs with vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
 
 ```Bash
-
+mkdir -p /opt/vLLM && cd /opt/vLLM
 uv venv --python 3.12 --seed
 source .venv/bin/activate
-
+uv pip install vllm==0.10.2
 ```
 
 Install checkpoint-engine
@@ -180,6 +178,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
 pytest tests/ -m "not gpu"
 ```
 
+### Environment Variables
+- `PS_MAX_BUCKET_SIZE_GB`: An integer is used to set the maximum bucket size for checkpoint-engine. If not set, 8GB is used as default.
+- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA devices' names for P2P transfer. If not set, checkpoint-engine will fall back to use `NCCL_IB_HCA` to detect RDMA devices.
+- `NCCL_IB_HCA`: Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If also not set, all RDMA devices will be used and divided evenly among the ranks.
+
 ## SGLang Integration
 
 Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/examples/update.py

@@ -100,8 +100,9 @@ def update_weights(
     update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
     uds: str | None = None,
 ):
-    ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
     ps.init_process_group()
+    dist.barrier()
+    ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
     check_vllm_ready(endpoint, inference_parallel_size, uds)
     dist.barrier()
     with timer("Gather metas"):
@@ -173,7 +174,9 @@ if __name__ == "__main__":
             args.uds,
         )
     else:
-        if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")):
+        if os.path.exists(
+            os.path.join(args.checkpoint_path, "model.safetensors.index.json")
+        ) and not args.checkpoint_path.startswith("/dev/shm/"):  # noqa: S108
             named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
             checkpoint_files = []
         else:
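The condition change in the example is easiest to read in isolation: sharded checkpoints are normally pre-split per rank, but `/dev/shm` paths are now left as files so they can take the in-place pin path. A sketch, with an illustrative path:

```python
import os

checkpoint_path = "/dev/shm/my_ckpt"  # illustrative only
has_index = os.path.exists(
    os.path.join(checkpoint_path, "model.safetensors.index.json")
)
# Split tensors per rank only for on-disk sharded checkpoints, never for /dev/shm.
use_split_tensors = has_index and not checkpoint_path.startswith("/dev/shm/")
```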
checkpoint_engine-0.2.3/tests/test_pin_memory.py (new file)

@@ -0,0 +1,77 @@
+import os
+
+import pytest
+import torch
+
+from checkpoint_engine.ps import ParameterServer
+
+
+def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
+    """
+    Generate dummy checkpoint data
+    """
+    named_tensors = {
+        "layer1.weight": torch.randn(1024, 1024),
+        "layer1.bias": torch.randn(1024),
+        "layer2.weight": torch.randn(2048, 1024),
+        "layer2.bias": torch.randn(2048),
+    }
+    return named_tensors
+
+
+@pytest.mark.gpu
+def test_register_pin_memory():
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    ps = ParameterServer()
+    checkpoint1 = generate_dummy_checkpoint()
+    checkpoint_shared1 = generate_dummy_checkpoint()
+    checkpoint2 = generate_dummy_checkpoint()
+    checkpoint_shared2 = generate_dummy_checkpoint()
+    checkpoint_shared3 = generate_dummy_checkpoint()
+    checkpoint_shared3["layer3.weight"] = torch.randn(4096, 2048)
+    checkpoint_shared3["layer3.bias"] = torch.randn(4096)
+    ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
+    ps.unregister_checkpoint("test_checkpoint1")
+    assert "test_checkpoint1" not in ps._memory_pool
+    ps.register_checkpoint(
+        "test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True
+    )
+    ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2)
+    assert "test_checkpoint_shared1" not in ps._memory_pool
+    assert "__shared_memory_pool__" in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
+    assert "test_checkpoint2" in ps._memory_pool
+    try:
+        ps.register_checkpoint(
+            "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
+        )  # this will fail
+    except AssertionError:
+        print("Caught expected AssertionError when registering second shared memory pool user")
+    assert "test_checkpoint_shared2" not in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
+    ps.unregister_checkpoint("test_checkpoint_shared1")
+    assert ps._current_shared_memory_pool_user == ""
+    assert "__shared_memory_pool__" in ps._memory_pool
+    ps.register_checkpoint(
+        "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
+    )
+    assert "test_checkpoint_shared2" not in ps._memory_pool
+    assert "__shared_memory_pool__" in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2"
+    ps.unregister_checkpoint("test_checkpoint1")  # this will trigger a warning
+    assert "test_checkpoint1" not in ps._memory_pool
+    ps.unregister_checkpoint("test_checkpoint2")
+    assert "test_checkpoint2" not in ps._memory_pool
+    ps.unregister_checkpoint("test_checkpoint_shared2", force=True)
+    assert ps._current_shared_memory_pool_user == ""
+    assert "__shared_memory_pool__" in ps._memory_pool
+    ps.register_checkpoint(
+        "test_checkpoint_shared3", named_tensors=checkpoint_shared3, use_shared_memory_pool=True
+    )
+    assert "test_checkpoint_shared3" not in ps._memory_pool
+    assert "__shared_memory_pool__" in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared3"
+    ps.unregister_checkpoint("test_checkpoint_shared3")
+    assert ps._current_shared_memory_pool_user == ""
+    assert "__shared_memory_pool__" in ps._memory_pool
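Because the new test is marked `gpu`, the documented CPU run (`pytest tests/ -m "not gpu"`) skips it. An assumed way to select it explicitly on a CUDA machine, mirroring the subprocess style the test suite itself uses:

```python
import subprocess

# Assumed invocation; requires a CUDA device since the test is @pytest.mark.gpu.
subprocess.run(["pytest", "tests/test_pin_memory.py", "-m", "gpu"], check=True)
```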
{checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/tests/test_update.py

@@ -82,7 +82,7 @@ def checker_proc_with_error(
     try:
         trigger_error(socket_paths)
     except RuntimeError as e:
-        assert str(e) == "
+        assert str(e) == "Some workers failed to update weights"
 
 
 def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -96,7 +96,7 @@ def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Ten
         for name, weight in weights:
             if name not in named_tensors:
                 continue
-            assert (weight == named_tensors[name]).all()
+            assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
             names_to_check[name] = True
 
     def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
@@ -163,6 +163,67 @@ def run(
     assert proc.exitcode == 0
 
 
+def run_with_files(
+    checker_func: callable,
+):
+    rank = int(os.getenv("RANK"))
+    ctx = get_context("spawn")
+    queue = ctx.Queue()
+    _device_uuid = _get_physical_gpu_id(device_manager, rank)
+    ps = ParameterServer(auto_pg=True)
+    _device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
+    named_tensors = dict(gen_test_tensors(rank))
+
+    # Save 1/3 tensors to /dev/shm/ as .safetensors files
+    # Save 1/3 tensors to ./tmp (disk) as .safetensors files
+    # Keep 1/3 tensors in memory
+    import safetensors.torch
+
+    files = []
+    dev_shm_dir = "/dev/shm/checkpoint_engine_tests"  # noqa: S108
+    disk_dir = "/tmp/checkpoint_engine_tests"  # noqa: S108
+    os.makedirs(dev_shm_dir, exist_ok=True)
+    os.makedirs(disk_dir, exist_ok=True)
+    tensors_items = list(named_tensors.items())
+    tensors_in_dev_shm = named_tensors
+    tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 2])
+    tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3])
+    tensors_in_memory = dict(tensors_items[1 * len(tensors_items) // 2 :])
+    disk_files = [
+        os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors")
+        for _rank in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_disk, disk_files[rank])
+    time.sleep(1)
+    files.append(disk_files[rank])
+    dev_shm_files = [
+        os.path.join(dev_shm_dir, f"rank{rank}_checkpoint.safetensors")
+        for _ in range(get_world_size())
+    ]
+    safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
+    time.sleep(1)
+    files.append(dev_shm_files[rank])
+
+    checkpoint_name = "test_with_files"
+    proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue))
+    proc.start()
+    ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files)
+    ps.gather_metas(checkpoint_name)
+    ps.update(checkpoint_name, queue.put, ranks=[])
+    # sleep 3s to wait process group is destroyed
+    time.sleep(3)
+    ps.unregister_checkpoint(checkpoint_name)
+    queue.put(None)
+    proc.join()
+    if rank == 0:
+        import shutil
+
+        # this test should be run under use_inplace_pin_memory=False. Otherwise, the files in /dev/shm/ will be deleted.
+        shutil.rmtree(dev_shm_dir)
+        shutil.rmtree(disk_dir)
+    assert proc.exitcode == 0
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize(
     "test_name,rank_list",
@@ -211,6 +272,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
     assert result.returncode == 0
 
 
+@pytest.mark.gpu
+def test_update_with_files(test_name: str = "test_with_files"):
+    world_size = device_manager.device_module.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
+    master_addr = "localhost"
+    master_port = 25400
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        __file__,
+        test_name,
+        "[]",
+    ]
+
+    result = subprocess.run(  # noqa: S603
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
 if __name__ == "__main__":
     run_with_pytest = "PYTEST_CURRENT_TEST" in os.environ
     if not run_with_pytest:
@@ -230,5 +322,7 @@ if __name__ == "__main__":
             expected_exception=RuntimeError,
             exception_msg="Failed to update weights due to remote errors",
         )
+    elif test_type == "test_with_files":
+        run_with_files(checker_proc)
     else:
         raise ValueError(f"Unknown TEST_TYPE: {test_type}")