checkpoint-engine 0.3.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/__init__.py +4 -0
- checkpoint_engine/_version.py +34 -0
- checkpoint_engine/device_utils.py +86 -0
- checkpoint_engine/ps.py +1576 -0
- checkpoint_engine/worker.py +168 -0
- checkpoint_engine-0.3.0rc0.dist-info/METADATA +236 -0
- checkpoint_engine-0.3.0rc0.dist-info/RECORD +10 -0
- checkpoint_engine-0.3.0rc0.dist-info/WHEEL +5 -0
- checkpoint_engine-0.3.0rc0.dist-info/licenses/LICENCE +21 -0
- checkpoint_engine-0.3.0rc0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
import gc
|
|
2
|
+
import traceback
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import zmq
|
|
8
|
+
|
|
9
|
+
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
|
|
13
|
+
func, args = handle
|
|
14
|
+
list_args = list(args)
|
|
15
|
+
if device_id is not None:
|
|
16
|
+
# the key is to change device id to the current device id
|
|
17
|
+
# in case two processes have different CUDA_VISIBLE_DEVICES
|
|
18
|
+
list_args[6] = device_id
|
|
19
|
+
buffer = func(*list_args)
|
|
20
|
+
return buffer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FlattenedTensorMetadata(TypedDict):
    """Location and layout of one tensor packed into a shared flat buffer.

    Keys:
        name: Name the tensor is published under.
        shape: Shape used to view the raw bytes back as a tensor.
        dtype: Element type used to reinterpret the raw bytes.
        offset: Start offset of this tensor inside the shared ``ipc_buffer``
            tensor.
    """

    name: str
    shape: torch.Size
    dtype: torch.dtype
    offset: int  # start offset inside the shared ipc_buffer tensor
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _extract_weights(
|
|
32
|
+
payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
|
|
33
|
+
) -> list[tuple[str, torch.Tensor]]:
|
|
34
|
+
assert buffer is not None
|
|
35
|
+
weights: list[tuple[str, torch.Tensor]] = []
|
|
36
|
+
for item in payload:
|
|
37
|
+
shape = item["shape"]
|
|
38
|
+
if isinstance(shape, list | tuple):
|
|
39
|
+
shape = torch.Size(shape)
|
|
40
|
+
assert isinstance(shape, torch.Size)
|
|
41
|
+
dtype, offset = item["dtype"], item["offset"]
|
|
42
|
+
size = dtype.itemsize * shape.numel()
|
|
43
|
+
tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
|
|
44
|
+
weights.append((item["name"], tensor))
|
|
45
|
+
return weights
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def update_weights_from_ipc(
    zmq_ctx: zmq.Context,
    zmq_handle: str,
    device_id: int,
    *,
    run: Callable[[list[tuple[str, torch.Tensor]]], None],
    post_hook: Callable[[], None] | None = None,
):
    """Serve one weight-update session from the Parameter Server over ZMQ.

    A ``zmq.REP`` socket is connected to ``zmq_handle`` and driven through a
    request/reply protocol:

    1. Handshake: receive an IPC handle, rebuild the shared buffer on
       ``device_id`` (it must be ``torch.uint8``) and reply ``b""``. On any
       error, reply with the formatted traceback, wait for an acknowledgement,
       then re-raise.
    2. Update loop, one reply per received payload:

       * ``list`` of metadata: call ``run`` with ``(name, tensor)`` views into
         the shared buffer, synchronize the device and reply ``b""``. If this
         raises, the traceback is sent as the reply and the loop continues,
         so every worker quits the same way by receiving the exception from
         the Parameter Server.
       * ``None`` (done signal): call ``post_hook`` if given, synchronize the
         device, reply ``b""`` and return.
       * ``Exception`` (forced quit from the Parameter Server): raise it.
       * anything else: raise ``TypeError``.

    On every exit path, including a failed handshake, the socket is closed,
    the local buffer reference is dropped, and ``gc.collect()`` plus the
    device cache cleanup are run.

    Args:
        zmq_ctx: ZMQ context used to create the socket.
        zmq_handle: Endpoint the REP socket connects to.
        device_id: Local device index passed to ``_rebuild_ipc``.
        run: Callback receiving each batch of ``(name, tensor)`` pairs.
        post_hook: Optional callback invoked once after the done signal.

    Raises:
        Exception: Whatever the handshake raised, or the exception object sent
            by the Parameter Server as a forced-quit signal.
        TypeError: If a payload of an unexpected type is received.
    """
    # Created before the socket so the ``finally`` block can always rely on it,
    # and a failure here cannot leave an open socket behind.
    device_manager = DeviceManager()
    socket = zmq_ctx.socket(zmq.REP)
    buffer: torch.Tensor | None = None
    # One try/finally spans the handshake as well as the update loop. Otherwise
    # a failed connect or handshake would re-raise without closing the socket.
    try:
        socket.connect(zmq_handle)
        # NOTE(security): recv_pyobj() unpickles whatever arrives on this socket,
        # and the handshake message carries a callable that is then invoked.
        # Only the trusted Parameter Server must be able to reach ``zmq_handle``.
        try:
            ipc_handle: tuple[Callable, tuple] = socket.recv_pyobj()
            assert isinstance(ipc_handle, tuple)
            buffer = _rebuild_ipc(ipc_handle, device_id)
            assert buffer.dtype == torch.uint8
            socket.send(b"")
        except Exception as e:
            # Report the failure to the Parameter Server before propagating it.
            msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
            socket.send_string(msg)
            socket.recv()  # wait for ack
            raise

        while True:
            payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
            if payload is None:  # done signal
                # NOTE(review): if post_hook raises, no reply is sent for the done
                # signal; confirm how the Parameter Server handles that case.
                if post_hook is not None:
                    post_hook()
                device_manager.device_module.synchronize()
                socket.send(b"")
                break
            if isinstance(payload, list):  # still updating weights
                try:
                    run(_extract_weights(payload, buffer))
                    device_manager.device_module.synchronize()
                    socket.send(b"")
                except Exception as e:  # noqa: BLE001
                    # Send the exception back to the Parameter Server instead of
                    # raising: all workers should quit the same way, by receiving
                    # the exception from the PS.
                    msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
                    socket.send_string(msg)
            elif isinstance(payload, Exception):
                # An error occurred elsewhere: force-quit signal from the PS.
                raise payload
            else:
                raise TypeError(f"Unexpected payload type: {type(payload)}")
    finally:
        socket.close()
        del buffer
        gc.collect()
        device_manager.device_module.empty_cache()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class VllmColocateWorkerExtension:
|
|
105
|
+
"""
|
|
106
|
+
Worker extension for vLLM to update weights from checkpoint-engine.
|
|
107
|
+
|
|
108
|
+
This class provides a worker extension mechanism that allows vLLM workers to receive
|
|
109
|
+
and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
|
|
110
|
+
The methods in this worker extension will be injected into the vLLM worker class and
|
|
111
|
+
are callable from the `collective_rpc` API, enabling seamless weight updates for both
|
|
112
|
+
vLLM V0 and V1 versions.
|
|
113
|
+
|
|
114
|
+
Note:
|
|
115
|
+
This class is defined in a separate module. The fully qualified name
|
|
116
|
+
`checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
|
|
117
|
+
`worker_extension_cls` argument when initializing the vLLM worker.
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
|
|
121
|
+
"""
|
|
122
|
+
Update model weights from checkpoint-engine via IPC communication.
|
|
123
|
+
|
|
124
|
+
This method establishes a ZMQ connection to the checkpoint-engine and receives
|
|
125
|
+
weight updates through a shared memory buffer. The update process includes:
|
|
126
|
+
1. Receiving IPC handles to reconstruct shared memory tensors
|
|
127
|
+
2. Extracting flattened metadata describing tensor weights in the shared memory tensor
|
|
128
|
+
3. Loading weights into the model
|
|
129
|
+
4. Post-processing weights after loading
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
|
|
133
|
+
The device UUID is platform-specific:
|
|
134
|
+
- For CUDA: UUID from `current_platform.get_device_uuid()`
|
|
135
|
+
- For NPU: Format "NPU-{generated_uuid}"
|
|
136
|
+
|
|
137
|
+
Raises:
|
|
138
|
+
ValueError: If the device type is not supported (not CUDA or NPU).
|
|
139
|
+
AssertionError: If the device is not properly initialized.
|
|
140
|
+
|
|
141
|
+
Note:
|
|
142
|
+
This method is called by vLLM's collective RPC mechanism. The ZMQ context
|
|
143
|
+
is lazily initialized on first call and reused for subsequent updates.
|
|
144
|
+
"""
|
|
145
|
+
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
|
146
|
+
from vllm.platforms import current_platform
|
|
147
|
+
|
|
148
|
+
# vllm-ascend not init device
|
|
149
|
+
if current_platform.device_type == "npu" and self.device is None:
|
|
150
|
+
self.device = torch.device(f"npu:{self.local_rank}")
|
|
151
|
+
assert self.device is not None
|
|
152
|
+
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
|
|
153
|
+
self._zmq_ctx = zmq.Context()
|
|
154
|
+
if current_platform.device_type == "cuda":
|
|
155
|
+
device_uuid = current_platform.get_device_uuid(self.device.index)
|
|
156
|
+
elif current_platform.device_type == "npu":
|
|
157
|
+
device_uuid = f"NPU-{npu_generate_uuid()}"
|
|
158
|
+
else:
|
|
159
|
+
raise ValueError(f"Unsupported device type: {current_platform.device_type}")
|
|
160
|
+
update_weights_from_ipc(
|
|
161
|
+
self._zmq_ctx,
|
|
162
|
+
zmq_handles[device_uuid],
|
|
163
|
+
device_id=self.device.index,
|
|
164
|
+
run=self.model_runner.model.load_weights,
|
|
165
|
+
post_hook=lambda: process_weights_after_loading(
|
|
166
|
+
self.model_runner.model, self.model_config, self.device
|
|
167
|
+
),
|
|
168
|
+
)
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: checkpoint-engine
|
|
3
|
+
Version: 0.3.0rc0
|
|
4
|
+
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
|
+
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
|
+
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
7
|
+
Project-URL: Documentation, https://github.com/MoonshotAI/checkpoint-engine
|
|
8
|
+
Project-URL: Issue Tracker, https://github.com/MoonshotAI/checkpoint-engine/issues
|
|
9
|
+
Requires-Python: >=3.10
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENCE
|
|
12
|
+
Requires-Dist: torch>=2.5.0
|
|
13
|
+
Requires-Dist: fastapi
|
|
14
|
+
Requires-Dist: pydantic>=2.0.0
|
|
15
|
+
Requires-Dist: safetensors
|
|
16
|
+
Requires-Dist: pyzmq
|
|
17
|
+
Requires-Dist: uvicorn
|
|
18
|
+
Requires-Dist: loguru
|
|
19
|
+
Requires-Dist: numpy
|
|
20
|
+
Requires-Dist: httpx
|
|
21
|
+
Provides-Extra: p2p
|
|
22
|
+
Requires-Dist: mooncake-transfer-engine>=0.3.5; extra == "p2p"
|
|
23
|
+
Dynamic: license-file
|
|
24
|
+
|
|
25
|
+
# Checkpoint Engine
|
|
26
|
+
Checkpoint-engine is a simple middleware to update model weights in LLM inference engines -- a critical step in reinforcement learning.
|
|
27
|
+
We provide an efficient and lightweight implementation for inplace weight update:
|
|
28
|
+
updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion parameters) across thousands of GPUs takes about 20s.
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
<div align="center">
|
|
32
|
+
<picture>
|
|
33
|
+
<img src="figures/checkpoint-engine.png" width="80%" alt="ckpt-engine">
|
|
34
|
+
</picture>
|
|
35
|
+
</div>
|
|
36
|
+
|
|
37
|
+
## Architecture
|
|
38
|
+
|
|
39
|
+
The core weight update logic lives in the `ParameterServer` class, a service colocated with the inference engines. It provides two implementations of weight update: Broadcast and P2P.
|
|
40
|
+
|
|
41
|
+
- **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks == None or []`.
|
|
42
|
+
- **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
|
|
43
|
+
|
|
44
|
+
### Optimized Weight Broadcast
|
|
45
|
+
In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
|
|
46
|
+
We arrange the data transfer into 3 stages:
|
|
47
|
+
1. H2D: moving weights to GPU memory. These weights may come from disk or the training engine.
|
|
48
|
+
2. broadcast: broadcast among checkpoint-engine workers; the data ends up in a CUDA IPC buffer shared with the inference engine.
|
|
49
|
+
3. reload: the inference engine decides which subset of weights to copy from the broadcast data.
|
|
50
|
+
|
|
51
|
+
Checkpoint-engine orchestrates the entire transfer process. It first gathers necessary metadata to create a plan, including deciding the proper bucket size for data transfer.
|
|
52
|
+
It then executes the transfer, where it controls the inference engine through a ZeroMQ socket. To maximize performance, it organizes the data transfers into a pipeline with overlapped communication and copy, illustrated below. The details can be found in [Kimi-K2 Technical Report](https://arxiv.org/abs/2507.20534).
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
<div align="center">
|
|
56
|
+
<picture>
|
|
57
|
+
<img src="figures/pipeline.png" width="80%" alt="pipeline">
|
|
58
|
+
</picture>
|
|
59
|
+
</div>
|
|
60
|
+
|
|
61
|
+
Pipelining naturally requires more GPU memory. When there is not enough memory, checkpoint-engine falls back to serial execution.
|
|
62
|
+
|
|
63
|
+
### Optimized P2P Bucket Assignment
|
|
64
|
+
In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
|
|
65
|
+
To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
|
|
66
|
+
The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
|
|
67
|
+
See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25)
|
|
68
|
+
|
|
69
|
+
## Benchmark
|
|
70
|
+
|
|
71
|
+
| Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
|
|
72
|
+
| :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
|
|
73
|
+
| GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
|
|
74
|
+
| Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
|
|
75
|
+
| DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
|
|
76
|
+
| Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
|
|
77
|
+
| DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
|
|
78
|
+
| Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
|
|
79
|
+
|
|
80
|
+
All results above were measured with [`examples/update.py`](./examples/update.py), using [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as the inference engine. Some notes:
|
|
81
|
+
|
|
82
|
+
* FP8 test needs additional vLLM patches, see [FP8 quantization](#fp8-quantization).
|
|
83
|
+
* Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
|
|
84
|
+
* Since update duration is related to IPC bucket size, we provide the bucket size in the table.
|
|
85
|
+
* The P2P times were measured when updating at most two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
|
|
86
|
+
* We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
|
|
87
|
+
|
|
88
|
+
## Installation
|
|
89
|
+
|
|
90
|
+
Use the fastest broadcast implementation
|
|
91
|
+
|
|
92
|
+
```Bash
|
|
93
|
+
pip install checkpoint-engine
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
Use the flexible P2P implementation. Note that this installs `mooncake-transfer-engine` to support RDMA transfer between different ranks.
|
|
97
|
+
|
|
98
|
+
```Bash
|
|
99
|
+
pip install 'checkpoint-engine[p2p]'
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
## Getting Started
|
|
103
|
+
|
|
104
|
+
Prepare an H800 or H20 machine with 8 GPUs and vLLM installed. Be sure to include the [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in the main branch), since checkpoint-engine uses this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
|
|
105
|
+
|
|
106
|
+
```Bash
|
|
107
|
+
mkdir -p /opt/vLLM && cd /opt/vLLM
|
|
108
|
+
uv venv --python 3.12 --seed
|
|
109
|
+
source .venv/bin/activate
|
|
110
|
+
uv pip install vllm==0.10.2
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
Install checkpoint-engine
|
|
114
|
+
|
|
115
|
+
```Bash
|
|
116
|
+
uv pip install 'checkpoint-engine[p2p]'
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
We use `Qwen/Qwen3-235B-A22B-Instruct-2507` (BF16) as the test model
|
|
120
|
+
|
|
121
|
+
```Bash
|
|
122
|
+
hf download Qwen/Qwen3-235B-A22B-Instruct-2507 --local-dir /opt/models/Qwen/Qwen3-235B-A22B-Instruct-2507/
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
Start vLLM in dev mode and set `--load-format dummy`. Notice that we also set `--worker-extension-cls=checkpoint_engine.worker.VllmColocateWorkerExtension`
|
|
126
|
+
|
|
127
|
+
```Bash
|
|
128
|
+
VLLM_SERVER_DEV_MODE=1 python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 19730 --trust-remote-code \
|
|
129
|
+
--tensor-parallel-size=8 --max-model-len 4096 --load-format dummy \
|
|
130
|
+
--served-model-name checkpoint-engine-demo --model /opt/models/Qwen/Qwen3-235B-A22B-Instruct-2507/ \
|
|
131
|
+
--worker-extension-cls checkpoint_engine.worker.VllmColocateWorkerExtension
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
Meanwhile, use this command to update weights by checkpoint-engine. No need to wait for vLLM to get ready.
|
|
135
|
+
|
|
136
|
+
```Bash
|
|
137
|
+
torchrun --nproc-per-node 8 examples/update.py --update-method all --checkpoint-path /opt/models/Qwen/Qwen3-235B-A22B-Instruct-2507/
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
### Reuse weights from existing instances
|
|
141
|
+
|
|
142
|
+
New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
|
|
143
|
+
|
|
144
|
+
First, start the existing instances with `--save-metas-file global_metas.pkl` to save global metas to a file and use `--sleep-time 300` to make sure they stay alive.
|
|
145
|
+
|
|
146
|
+
```Bash
|
|
147
|
+
torchrun --nproc-per-node 8 examples/update.py --checkpoint-path $MODEL_PATH \
|
|
148
|
+
--sleep-time 300 --save-metas-file global_metas.pkl
|
|
149
|
+
```
|
|
150
|
+
|
|
151
|
+
After a checkpoint is registered, new instances can obtain a copy of the checkpoint by setting `--load-metas-file global_metas.pkl`.
|
|
152
|
+
|
|
153
|
+
```Bash
|
|
154
|
+
torchrun --nproc-per-node 8 examples/update.py --load-metas-file global_metas.pkl
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
### FP8 quantization
|
|
158
|
+
|
|
159
|
+
FP8 quantization currently does not work natively in vLLM when updating weights.
|
|
160
|
+
We provide a simple patch in [`patches/vllm_fp8.patch`](./patches/vllm_fp8.patch) to handle the correct weight update.
|
|
161
|
+
Note that this patch has only been tested with DeepSeek-V3.1 and Kimi-K2; other models may run into compatibility issues.
|
|
162
|
+
|
|
163
|
+
A [PR](https://github.com/vllm-project/vllm/pull/24488) has been opened against the vLLM project and is awaiting discussion and review.
|
|
164
|
+
|
|
165
|
+
### Test
|
|
166
|
+
|
|
167
|
+
Run a simple correctness test for checkpoint_engine
|
|
168
|
+
|
|
169
|
+
```bash
|
|
170
|
+
pytest tests/test_update.py
|
|
171
|
+
```
|
|
172
|
+
|
|
173
|
+
`test_update.py` is only designed to run with `pytest`. Please don't run it directly with `torchrun`.
|
|
174
|
+
|
|
175
|
+
Other unit tests can also be run with pytest. Only `test_update.py` requires GPUs; the other tests can run on CPUs. To run only the CPU tests, use:
|
|
176
|
+
|
|
177
|
+
```bash
|
|
178
|
+
pytest tests/ -m "not gpu"
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
### Environment Variables
|
|
182
|
+
- `PS_MAX_BUCKET_SIZE_GB`: An integer that sets the maximum bucket size, in GB, for checkpoint-engine. If not set, 8GB is used by default.
|
|
183
|
+
- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA devices' names for P2P transfer. If not set, checkpoint-engine will fall back to use `NCCL_IB_HCA` to detect RDMA devices.
|
|
184
|
+
- `NCCL_IB_HCA`: Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If also not set, all RDMA devices will be used and divided evenly among the ranks.
|
|
185
|
+
|
|
186
|
+
## SGLang Integration
|
|
187
|
+
|
|
188
|
+
Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
|
|
189
|
+
|
|
190
|
+
### Quick Start
|
|
191
|
+
|
|
192
|
+
**1. Install checkpoint-engine:**
|
|
193
|
+
```bash
|
|
194
|
+
pip install 'checkpoint-engine[p2p]'
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
**2. Launch SGLang server:**
|
|
198
|
+
```bash
|
|
199
|
+
python -m sglang.launch_server \
|
|
200
|
+
--model-path $MODEL_PATH \
|
|
201
|
+
--tp 8 \
|
|
202
|
+
--load-format dummy \
|
|
203
|
+
--wait-for-initial-weights
|
|
204
|
+
```
|
|
205
|
+
|
|
206
|
+
**3. Run checkpoint engine:**
|
|
207
|
+
```bash
|
|
208
|
+
python -m sglang.srt.checkpoint_engine.update \
|
|
209
|
+
--update-method broadcast \
|
|
210
|
+
--checkpoint-path $MODEL_PATH \
|
|
211
|
+
--inference-parallel-size 8
|
|
212
|
+
```
|
|
213
|
+
|
|
214
|
+
### Multi-Node Setup
|
|
215
|
+
|
|
216
|
+
For a 2-node setup, run the same commands on both nodes with the appropriate `--host` and distributed training parameters.
|
|
217
|
+
|
|
218
|
+
### Key Options
|
|
219
|
+
|
|
220
|
+
**SGLang Server:**
|
|
221
|
+
- `--wait-for-initial-weights`: Wait for checkpoint engine before becoming ready
|
|
222
|
+
- `--load-format dummy`: Enable overlapping initialization tasks
|
|
223
|
+
|
|
224
|
+
**Checkpoint Engine:**
|
|
225
|
+
- `--update-method`: Choose `broadcast`, `p2p`, or `all`
|
|
226
|
+
- `--inference-parallel-size`: Number of parallel processes
|
|
227
|
+
- `--checkpoint-path`: Model checkpoint directory
|
|
228
|
+
|
|
229
|
+
## Limitations and Future Work
|
|
230
|
+
|
|
231
|
+
- This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
|
|
232
|
+
- The perfect three-stage pipeline mentioned in our paper is not currently implemented. This could be useful for architectures where H2D and broadcast do not conflict on PCIe.
|
|
233
|
+
|
|
234
|
+
## Acknowledgments
|
|
235
|
+
|
|
236
|
+
This open source project uses the same vLLM interface in https://github.com/vllm-project/vllm/pull/24295 . Thanks for the comments and insights from [youkaichao](https://github.com/youkaichao).
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
|
|
2
|
+
checkpoint_engine/_version.py,sha256=v0iyeXv9HxMc4JmYu_bJTIGKXRQVfpijACyjq2P_sk0,714
|
|
3
|
+
checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
|
|
4
|
+
checkpoint_engine/ps.py,sha256=eIvg_eI7HMedacoQQer62NRnGDjANtxsHVxgM93ccXQ,66977
|
|
5
|
+
checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
|
|
6
|
+
checkpoint_engine-0.3.0rc0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
|
|
7
|
+
checkpoint_engine-0.3.0rc0.dist-info/METADATA,sha256=iVd2qPdNyTPPX3XIEiuM0ASk8As72zSGfFIYicpZG3E,11562
|
|
8
|
+
checkpoint_engine-0.3.0rc0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
9
|
+
checkpoint_engine-0.3.0rc0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
10
|
+
checkpoint_engine-0.3.0rc0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Moonshot AI
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
checkpoint_engine
|