checkpoint-engine 0.1.2__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. checkpoint_engine-0.2.0/.github/workflows/cpu-tests.yml +30 -0
  2. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/PKG-INFO +18 -11
  3. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/README.md +17 -10
  4. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine/_version.py +3 -3
  5. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine/ps.py +272 -116
  6. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine.egg-info/PKG-INFO +18 -11
  7. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine.egg-info/SOURCES.txt +3 -0
  8. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/pyproject.toml +5 -0
  9. checkpoint_engine-0.2.0/tests/test_assign_receiver_ranks.py +68 -0
  10. checkpoint_engine-0.2.0/tests/test_rdma_parser.py +197 -0
  11. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/.github/workflows/pre-commit.yaml +0 -0
  12. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/.github/workflows/python-publish.yml +0 -0
  13. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/.gitignore +0 -0
  14. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/.pre-commit-config.yaml +0 -0
  15. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/LICENCE +0 -0
  16. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine/__init__.py +0 -0
  17. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine/worker.py +0 -0
  18. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  19. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine.egg-info/requires.txt +0 -0
  20. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/checkpoint_engine.egg-info/top_level.txt +0 -0
  21. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/examples/update.py +0 -0
  22. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/figures/checkpoint-engine.png +0 -0
  23. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/figures/overlap-update-and-copy.png +0 -0
  24. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/figures/pipeline.png +0 -0
  25. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/patches/vllm_fp8.patch +0 -0
  26. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/setup.cfg +0 -0
  27. {checkpoint_engine-0.1.2 → checkpoint_engine-0.2.0}/tests/test_update.py +0 -0
@@ -0,0 +1,30 @@
1
+ name: CPU Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ types: [opened, synchronize, reopened]
8
+
9
+
10
+ permissions:
11
+ contents: read
12
+
13
+ jobs:
14
+ build:
15
+ runs-on: ubuntu-latest
16
+ steps:
17
+ - name: Checkout code
18
+ uses: actions/checkout@v4
19
+ - name: Set up Python
20
+ uses: actions/setup-python@v3
21
+ with:
22
+ python-version: "3.10"
23
+ - name: Install dependencies
24
+ run: |
25
+ python -m pip install --upgrade pip
26
+ pip install pytest
27
+ pip install .[p2p]
28
+ - name: Do CPU tests with pytest
29
+ run: |
30
+ pytest -v -m "not gpu" tests/
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.2
3
+ Version: 0.2.0
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -38,8 +38,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
38
38
 
39
39
  The core weight update logic is in the `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
40
40
 
41
- - **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
42
- - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
41
+ - **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks` left as `None` or `[]`.
42
+ - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
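Both paths go through the same public entry point; passing `ranks` switches from broadcast to P2P. A minimal usage sketch (argument order mirrors the internal `_update_per_bucket` calls visible later in this diff; check `examples/update.py` for the exact public call, and note that the `req_func` callback driving the inference engine is application-specific and only stubbed here):

```python
from checkpoint_engine.ps import ParameterServer

# RANK, WORLD_SIZE and MASTER_ADDR are expected in the environment.
ps = ParameterServer(auto_pg=True)

def req_func(socket_paths: list[tuple[str, str]]) -> None:
    # Application-specific: ask the colocated inference engine to pull the
    # new weights through the ZMQ sockets listed in socket_paths.
    ...

# After the checkpoint has been registered and parameter metas gathered on all ranks:
ps.update("my-checkpoint", req_func)                      # Broadcast: update every rank
ps.update("my-checkpoint", req_func, ranks=range(0, 16))  # P2P: update only the listed ranks
```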
43
43
 
44
44
  ### Optimized Weight Broadcast
45
45
  In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -60,16 +60,22 @@ It then executes the transfer, where it controls the inference engine through a
60
60
 
61
61
  Pipelining naturally requires more GPU memory. When memory is insufficient, checkpoint-engine will fall back to serial execution.
62
62
 
63
+ ### Optimized P2P Bucket Assignment
64
+ In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
65
+ To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
66
+ The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
67
+ See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25) for details.
68
+
63
69
  ## Benchmark
64
70
 
65
71
  | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
66
72
  | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
67
- | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.17s | 3.94s (1.42GiB) | 8.83s (4.77GiB) |
68
- | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.46s | 6.75s (2.69GiB) | 16.47s (4.05GiB) |
69
- | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.44s | 12.22s (2.38GiB) | 25.77s (3.61GiB) |
70
- | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.81s | 15.45s (2.93GiB) | 36.24s (4.46GiB) |
71
- | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 1.40s | 13.88s (2.54GiB) | 33.30s (3.86 GiB) |
72
- | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.88s | 21.50s (2.99GiB) | 34.49s (4.57 GiB) |
73
+ | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
74
+ | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
75
+ | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
76
+ | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
77
+ | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
78
+ | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
73
79
 
74
80
  All results above are produced by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as the inference engine. Some notes:
75
81
 
@@ -77,6 +83,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
77
83
  * Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
78
84
  * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
79
85
  * The P2P times were measured while updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
86
+ * We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
80
87
 
81
88
  ## Installation
82
89
 
@@ -92,7 +99,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
92
99
  pip install 'checkpoint-engine[p2p]'
93
100
  ```
94
101
 
95
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
102
+ If the `NCCL_IB_HCA` env variable is set, checkpoint-engine will use it to auto-select network devices for different ranks. The accepted patterns are described in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If it is not set, checkpoint-engine will read all RDMA devices and try to divide them among the ranks.
96
103
 
97
104
  ## Getting Started
98
105
 
@@ -165,11 +172,11 @@ Run a simple correctness test for checkpoint_engine
165
172
  torchrun --nproc-per-node 8 tests/test_update.py
166
173
  ```
167
174
 
175
+ The remaining unit tests can be run with pytest, e.g. `pytest -v -m "not gpu" tests/` for the CPU-only tests.
168
176
  ## Limitations and Future Work
169
177
 
170
178
  - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
171
179
  - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
172
- - The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
173
180
 
174
181
  ## Acknowledgments
175
182
 
@@ -14,8 +14,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
14
14
 
15
15
  The core weight update logic is in the `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
16
16
 
17
- - **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
18
- - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
17
+ - **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks` left as `None` or `[]`.
18
+ - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
19
19
 
20
20
  ### Optimized Weight Broadcast
21
21
  In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -36,16 +36,22 @@ It then executes the transfer, where it controls the inference engine through a
36
36
 
37
37
  Pipelining naturally requires more GPU memory. When memory is insufficient, checkpoint-engine will fall back to serial execution.
38
38
 
39
+ ### Optimized P2P Bucket Assignment
40
+ In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
41
+ To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
42
+ The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
43
+ See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25) for the discussion; a small illustration follows.
44
+
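A small, hypothetical illustration of that assignment, using the `_assign_receiver_ranks` helper added to `checkpoint_engine/ps.py` in this release (NIC names and bucket labels are made up; real buckets are `H2DBucket` objects):

```python
from checkpoint_engine.ps import _assign_receiver_ranks

# Four buckets owned by sender ranks 0-3. The senders sit behind two NICs,
# while each receiver rank has a NIC of its own.
buckets = [(0, "b0"), (1, "b1"), (2, "b2"), (3, "b3")]
remote_topo = {"nicA": {0, 1}, "nicB": {2, 3}}                     # sender side
local_topo = {"nic0": {0}, "nic1": {1}, "nic2": {2}, "nic3": {3}}  # receiver side

# Each tuple is (receiver_rank, owner_rank, bucket): receiver 0 reads from the
# nicA senders and receiver 1 from the nicB senders, so both sender NICs stay
# busy in every round instead of funneling all traffic through one link.
print(_assign_receiver_ranks(buckets, local_topo, remote_topo))
# [(0, 0, 'b0'), (1, 2, 'b2'), (0, 1, 'b1'), (1, 3, 'b3')]
```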
39
45
  ## Benchmark
40
46
 
41
47
  | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
42
48
  | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
43
- | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.17s | 3.94s (1.42GiB) | 8.83s (4.77GiB) |
44
- | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.46s | 6.75s (2.69GiB) | 16.47s (4.05GiB) |
45
- | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.44s | 12.22s (2.38GiB) | 25.77s (3.61GiB) |
46
- | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.81s | 15.45s (2.93GiB) | 36.24s (4.46GiB) |
47
- | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 1.40s | 13.88s (2.54GiB) | 33.30s (3.86 GiB) |
48
- | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.88s | 21.50s (2.99GiB) | 34.49s (4.57 GiB) |
49
+ | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
50
+ | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
51
+ | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
52
+ | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
53
+ | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
54
+ | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
49
55
 
50
56
  All results above are produced by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as the inference engine. Some notes:
51
57
 
@@ -53,6 +59,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
53
59
  * Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
54
60
  * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
55
61
  * The P2P times were measured while updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
62
+ * We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
56
63
 
57
64
  ## Installation
58
65
 
@@ -68,7 +75,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
68
75
  pip install 'checkpoint-engine[p2p]'
69
76
  ```
70
77
 
71
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
78
+ If the `NCCL_IB_HCA` env variable is set, checkpoint-engine will use it to auto-select network devices for different ranks. The accepted patterns are described in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If it is not set, checkpoint-engine will read all RDMA devices and try to divide them among the ranks.
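The parsing rules mirror NCCL's own semantics. A short sketch of the `_parse_NCCL_IB_HCA` helper that implements them in this release (device names are illustrative, matching the unit tests added in `tests/test_rdma_parser.py`):

```python
from checkpoint_engine.ps import _parse_NCCL_IB_HCA

devices = ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]  # as reported by the local ibverbs device list

_parse_NCCL_IB_HCA("mlx5", devices)            # prefix match   -> ['mlx5_0', 'mlx5_1']
_parse_NCCL_IB_HCA("=mlx5_0,mlx5_1", devices)  # exact names    -> ['mlx5_0', 'mlx5_1']
_parse_NCCL_IB_HCA("^mlx5", devices)           # exclusion list -> ['mlx4_0', 'mlx4_1']
_parse_NCCL_IB_HCA("", devices)                # unset / empty  -> all devices (capped at 32)
```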
72
79
 
73
80
  ## Getting Started
74
81
 
@@ -141,11 +148,11 @@ Run a simple correctness test for checkpoint_engine
141
148
  torchrun --nproc-per-node 8 tests/test_update.py
142
149
  ```
143
150
 
151
+ The remaining unit tests can be run with pytest, e.g. `pytest -v -m "not gpu" tests/` for the CPU-only tests.
144
152
  ## Limitations and Future Work
145
153
 
146
154
  - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
147
155
  - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
148
- - The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
149
156
 
150
157
  ## Acknowledgments
151
158
 
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.2'
32
- __version_tuple__ = version_tuple = (0, 1, 2)
31
+ __version__ = version = '0.2.0'
32
+ __version_tuple__ = version_tuple = (0, 2, 0)
33
33
 
34
- __commit_id__ = commit_id = 'g716c0dad9'
34
+ __commit_id__ = commit_id = 'ga29178282'
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
1
  import argparse
4
2
  import concurrent.futures
5
3
  import ctypes
@@ -10,6 +8,7 @@ import socket
10
8
  import threading
11
9
  import time
12
10
  from collections import defaultdict
11
+ from collections.abc import Callable
13
12
  from datetime import timedelta
14
13
  from functools import lru_cache
15
14
  from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
@@ -26,7 +25,7 @@ from torch.multiprocessing.reductions import reduce_tensor
26
25
 
27
26
 
28
27
  if TYPE_CHECKING:
29
- from collections.abc import Callable
28
+ from typing import TypeVar
30
29
 
31
30
  from typing_extensions import TypedDict
32
31
 
@@ -37,6 +36,8 @@ if TYPE_CHECKING:
37
36
  type: type
38
37
  tp_concat_dim: int
39
38
 
39
+ T = TypeVar("T")
40
+
40
41
 
41
42
  def _dt_validate(value: Any) -> torch.dtype:
42
43
  if isinstance(value, str):
@@ -120,6 +121,7 @@ class MemoryBuffer(BaseModel):
120
121
  class MemoryBufferMetaList(BaseModel):
121
122
  p2p_store_addr: str | None
122
123
  memory_buffer_metas_list: list[MemoryBufferMetas]
124
+ rdma_device: str
123
125
 
124
126
 
125
127
  class DataToGather(MemoryBufferMetaList):
@@ -151,8 +153,8 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
151
153
  return ret
152
154
 
153
155
 
154
- def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
155
- def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
156
+ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
157
+ def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
156
158
  ret = {}
157
159
  with safe_open(fn, framework="pt") as f:
158
160
  for name in f.keys(): # noqa: SIM118
@@ -168,7 +170,7 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
168
170
  return ret
169
171
 
170
172
  # deprecated, will be removed in the future
171
- def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
173
+ def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
172
174
  """load *.np file and return memmap and related tensor meta"""
173
175
 
174
176
  def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
@@ -306,14 +308,7 @@ def _get_rdma_devices() -> list[str]:
306
308
  return devices_str.split(",")
307
309
  # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
308
310
  hca = os.getenv("NCCL_IB_HCA", None)
309
- if hca:
310
- hca_list = hca.split(",")
311
- if len(hca_list) > 1:
312
- # if NCCL_IB_HCA has multiple values, just return
313
- return hca_list
314
- else:
315
- hca = hca_list[0]
316
- return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
311
+ return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
317
312
 
318
313
 
319
314
  def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
@@ -331,6 +326,75 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
331
326
  return devices[local_rank // (gpu_count // len(devices))]
332
327
 
333
328
 
329
+ def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
330
+ """
331
+ The accepted values of NCCL_IB_HCA are documented at https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
332
+ This Python parser follows the C++ parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
333
+
334
+ The list is comma-separated; port numbers are NOT supported yet.
335
+ An optional prefix '^' indicates the list is an exclude list.
336
+ A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
337
+ Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
338
+
339
+ Examples:
340
+ - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
341
+ - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
342
+ - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
343
+ - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
344
+ """
345
+ max_hcas = 32
346
+ if not value or value.strip() == "":
347
+ return available_devices[:max_hcas]
348
+
349
+ value = value.strip()
350
+ result = []
351
+ is_exclude = value.startswith("^")
352
+ if is_exclude:
353
+ value = value.removeprefix("^")
354
+ is_exact_match = value.startswith("=")
355
+ if is_exact_match:
356
+ value = value.removeprefix("=")
357
+
358
+ device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
359
+
360
+ result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
361
+ if is_exclude:
362
+ result = [dev for dev in available_devices if dev not in result]
363
+ if len(result) > max_hcas:
364
+ result = result[:max_hcas]
365
+
366
+ logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
367
+
368
+ return result
369
+
370
+
371
+ def _resolve_device_specs(
372
+ device_specs: list[str], is_exact_match: bool, available_devices: list[str]
373
+ ) -> list[str]:
374
+ devices = set()
375
+ for spec in device_specs:
376
+ parts = spec.split(":", 1)
377
+ device_name = parts[0].strip()
378
+ # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
379
+ # port = parts[1].strip() if len(parts) > 1 else None
380
+ base_devices = (
381
+ [device_name]
382
+ if device_name in available_devices
383
+ else []
384
+ if is_exact_match
385
+ else [dev for dev in available_devices if dev.startswith(device_name)]
386
+ )
387
+
388
+ if not base_devices:
389
+ logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
390
+ continue
391
+
392
+ for base_dev in base_devices:
393
+ devices.add(base_dev)
394
+
395
+ return sorted(devices)
396
+
397
+
334
398
  def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
335
399
  class TPMeta(BaseModel):
336
400
  concat_dim: int
@@ -493,8 +557,12 @@ def request_inference_to_update(
493
557
 
494
558
 
495
559
  def _gen_h2d_buckets(
496
- global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
497
- ) -> list[tuple[int, H2DBucket]]:
560
+ global_metas: dict[int, MemoryBufferMetaList],
561
+ bucket_size: int,
562
+ local_topo: dict[str, set[int]],
563
+ remote_topo: dict[str, set[int]],
564
+ ranks: list[int] | None = None,
565
+ ) -> list[tuple[int, int, H2DBucket]]:
498
566
  buckets: list[tuple[int, H2DBucket]] = []
499
567
 
500
568
  for owner_rank, items in global_metas.items():
@@ -517,7 +585,73 @@ def _gen_h2d_buckets(
517
585
  assert buckets[-1][1].size > 0, (
518
586
  f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
519
587
  )
520
- return buckets
588
+ ranks_set = set(ranks) if ranks else set()
589
+ actual_local_topo = (
590
+ {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
591
+ )
592
+ # if ranks is empty, assign the owner_rank as receiver_rank; this is used for the colocated architecture
593
+ if not ranks:
594
+ return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
595
+ else:
596
+ return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)
597
+
598
+
599
+ def _assign_receiver_ranks(
600
+ buckets: list[tuple[int, "T"]],
601
+ local_topo: dict[str, set[int]],
602
+ remote_topo: dict[str, set[int]],
603
+ ) -> list[tuple[int, int, "T"]]:
604
+ """
605
+ (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)
606
+
607
+ Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
608
+ The GPU-rdma_device topology is taken into account to make full use of the available bandwidth.
609
+ """
610
+ if not buckets:
611
+ logger.warning("bucket list is empty, no need to assign receiver ranks")
612
+ return []
613
+ rank_to_rdma_device = {
614
+ rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
615
+ }
616
+
617
+ # group buckets by owner RDMA devices
618
+ buckets_by_rdma_device = defaultdict(list)
619
+ for owner_rank, bucket in buckets:
620
+ owner_rdma_device = rank_to_rdma_device[owner_rank]
621
+ buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))
622
+
623
+ buckets_matrix = list(buckets_by_rdma_device.values())
624
+ assert buckets_matrix, "buckets_matrix should not be empty"
625
+
626
+ # Select receiver ranks. We use the minimum rank in each local RDMA device group as receiver rank
627
+ num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
628
+ receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]
629
+
630
+ flattened_buckets = [
631
+ buckets_matrix[row][col]
632
+ for col in range(
633
+ max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
634
+ )
635
+ for row in range(len(buckets_matrix))
636
+ if col < len(buckets_matrix[row])
637
+ ]
638
+
639
+ buckets_with_receiver = []
640
+ assigned_cnt = 0
641
+ while assigned_cnt < len(flattened_buckets):
642
+ occupied_devices = set()
643
+ for receiver_rank in receiver_list:
644
+ if assigned_cnt >= len(flattened_buckets):
645
+ break
646
+ owner_rank, bucket = flattened_buckets[assigned_cnt]
647
+ rdma_device = rank_to_rdma_device[owner_rank]
648
+ if rdma_device in occupied_devices:
649
+ break
650
+ buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
651
+ occupied_devices.add(rdma_device)
652
+ assigned_cnt += 1
653
+
654
+ return buckets_with_receiver
521
655
 
522
656
 
523
657
  def _get_master_port(master_port: int | None = None) -> int:
@@ -528,6 +662,20 @@ def _get_master_port(master_port: int | None = None) -> int:
528
662
  return master_port
529
663
 
530
664
 
665
+ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:
666
+ """
667
+ map the real ranks (receiver_rank) to the bcast ranks (0 ~ len(ranks) - 1),
668
+ which are generated in self.init_process_group_for_ranks
669
+ """
670
+ bcast_rank_map: dict[int, int] = {}
671
+ if not ranks:
672
+ bcast_rank_map = {r: r for r in range(world_size)}
673
+ else:
674
+ for i, r in enumerate(ranks):
675
+ bcast_rank_map[r] = i
676
+ return bcast_rank_map
677
+
678
+
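For reference, the mapping the new helper produces (values traced from the code above; the P2P case relies on `init_process_group_for_ranks` numbering the participating ranks from zero, as the docstring states):

```python
_get_bcast_rank_map(8, None)          # {0: 0, 1: 1, ..., 7: 7}   broadcast path: identity map
_get_bcast_rank_map(8, [4, 5, 6, 7])  # {4: 0, 5: 1, 6: 2, 7: 3}  P2P path: positions in the sub-group
```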
531
679
  class P2PStore:
532
680
  def __init__(self):
533
681
  from mooncake.engine import TransferEngine
@@ -535,14 +683,14 @@ class P2PStore:
535
683
  self.rank = int(os.getenv("RANK"))
536
684
  gpu_count = torch.cuda.device_count()
537
685
  local_rank = self.rank % gpu_count
538
- device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
686
+ self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
539
687
  self.ip = _get_ip()
540
688
 
541
689
  # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
542
690
  retry_count = 8
543
691
  for i in range(retry_count):
544
692
  self.engine = TransferEngine()
545
- ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", device)
693
+ ret = self.engine.initialize(self.ip, "P2PHANDSHAKE", "rdma", self.device)
546
694
  if ret == 0:
547
695
  break
548
696
  # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
@@ -556,7 +704,7 @@ class P2PStore:
556
704
  self.port = self.engine.get_rpc_port()
557
705
  self.named_tensors: dict[str, torch.Tensor] = {}
558
706
  logger.info(
559
- f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
707
+ f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
560
708
  )
561
709
 
562
710
  @property
@@ -595,7 +743,13 @@ class P2PStore:
595
743
 
596
744
  class ParameterServer:
597
745
  def __init__(
598
- self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
746
+ self,
747
+ *,
748
+ rank: int | None = None,
749
+ world_size: int | None = None,
750
+ auto_pg: bool = False,
751
+ gpu_count: int | None = None,
752
+ mem_fraction: float | None = None,
599
753
  ):
600
754
  """
601
755
  Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
@@ -603,17 +757,29 @@ class ParameterServer:
603
757
  Args:
604
758
  auto_pg: Whether to automatically initialize the process group.
605
759
  Notice that if auto_pg is True, will destroy the process group after update.
760
+ mem_fraction: The fraction of the currently free CUDA memory to use for allocation.
606
761
  """
607
762
  self._rank = rank or int(os.environ.get("RANK", None))
608
763
  self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
609
- self._gpu_count = torch.cuda.device_count()
764
+ self._gpu_count = gpu_count or torch.cuda.device_count()
610
765
  self._local_rank = self._rank % self._gpu_count
611
766
  self._auto_pg = auto_pg
612
767
  self._all_hosts = []
613
768
  self._global_device_uuids: list[str] = []
769
+ self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
770
+ self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
771
+ self._mem_fraction = mem_fraction or 0.9
614
772
 
615
773
  assert self._rank is not None and self._rank >= 0, self._rank
616
774
  assert self._world_size and self._world_size > 0, self._world_size
775
+ assert (
776
+ self._gpu_count is not None
777
+ and self._gpu_count > 0
778
+ and self._gpu_count <= torch.cuda.device_count()
779
+ ), self._gpu_count
780
+ assert (
781
+ self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
782
+ ), self._mem_fraction
617
783
 
618
784
  self._zmq_ctx = zmq.Context()
619
785
  self._zmq_addr_counter = 0
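A constructor sketch with the two new keyword arguments (hypothetical values; both fall back to the previous behaviour, `torch.cuda.device_count()` and `0.9`, when omitted):

```python
from checkpoint_engine.ps import ParameterServer

# RANK, WORLD_SIZE and MASTER_ADDR are expected in the environment.
ps = ParameterServer(
    auto_pg=True,
    gpu_count=4,       # parameter-server processes per node; must not exceed torch.cuda.device_count()
    mem_fraction=0.8,  # use 80% of the currently free CUDA memory when auto-detecting the bucket size
)
```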
@@ -630,6 +796,7 @@ class ParameterServer:
630
796
  device_index = self._local_rank
631
797
  torch.cuda.set_device(device_index)
632
798
  self._device_uuid = _get_physical_gpu_id(device_index)
799
+ self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
633
800
 
634
801
  def _logger_rank0(self, msg: str):
635
802
  if self._local_rank == 0:
@@ -640,6 +807,13 @@ class ParameterServer:
640
807
 
641
808
  def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
642
809
  self._current_global_parameter_metas = metas
810
+ self._remote_rdma_devices = defaultdict(set)
811
+ for i, meta in self._current_global_parameter_metas.items():
812
+ assert meta.rdma_device is not None, "meta.rdma_device should not be None"
813
+ assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
814
+ self._remote_rdma_devices[
815
+ meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
816
+ ].add(i)
643
817
 
644
818
  def register_checkpoint(
645
819
  self,
@@ -713,11 +887,11 @@ class ParameterServer:
713
887
  p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
714
888
  host_ip=_get_ip(),
715
889
  device_uuid=self._device_uuid,
890
+ rdma_device=self._rdma_device or "",
716
891
  )
717
892
 
718
893
  dist.all_gather_object(metas_lst, metas)
719
894
 
720
- self._current_global_parameter_metas = {}
721
895
  num_parameters = 0
722
896
  all_hosts: list[str] = []
723
897
  global_device_uuids: list[str] = []
@@ -728,12 +902,24 @@ class ParameterServer:
728
902
  if not self._global_device_uuids:
729
903
  global_device_uuids.append(metas_buckets.device_uuid)
730
904
  if metas_buckets.memory_buffer_metas_list:
731
- self._current_global_parameter_metas[i] = metas_buckets
905
+ self._current_global_parameter_metas[i] = MemoryBufferMetaList(
906
+ memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
907
+ p2p_store_addr=metas_buckets.p2p_store_addr,
908
+ rdma_device=metas_buckets.rdma_device,
909
+ )
732
910
  num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
911
+ self._local_rdma_devices[
912
+ metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
913
+ if metas_buckets.p2p_store_addr
914
+ else metas_buckets.host_ip
915
+ ].add(i)
733
916
  if not self._all_hosts:
734
917
  self._all_hosts = all_hosts
735
918
  if not self._global_device_uuids:
736
919
  self._global_device_uuids = global_device_uuids
920
+ # By default, assume the sender and receiver nodes share the same GPU-rdma_device topology.
921
+ # Rewrite the sender's topology (_remote_rdma_devices) by calling load_metas.
922
+ self._remote_rdma_devices = self._local_rdma_devices.copy()
737
923
  logger.info(
738
924
  f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
739
925
  )
@@ -788,6 +974,7 @@ class ParameterServer:
788
974
  If set, will use p2p to update to the ranks, this is flexible to update to a group of ranks,
789
975
  which is useful in disaggregated architecture.
790
976
  """
977
+ assert req_func is not None, "req_func is required"
791
978
  try:
792
979
  # if both ranks is None or [], it will use fully broadcast to update to all ranks
793
980
  if not ranks:
@@ -795,15 +982,15 @@ class ParameterServer:
795
982
  self.init_process_group()
796
983
  self._update_per_bucket(checkpoint_name, req_func)
797
984
  else:
798
- if self._rank not in ranks:
799
- return
800
985
  if self._auto_pg:
801
986
  if dist.is_initialized():
802
987
  dist.destroy_process_group()
803
988
  # HACK: wait 2s to ensure destroy is finished
804
989
  time.sleep(2)
805
990
  self.init_process_group_for_ranks(ranks)
806
- self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
991
+ if self._rank not in ranks:
992
+ return
993
+ self._update_per_bucket(checkpoint_name, req_func, ranks)
807
994
  if self._auto_pg:
808
995
  dist.destroy_process_group()
809
996
 
@@ -835,8 +1022,8 @@ class ParameterServer:
835
1022
  # auto detect bucket size
836
1023
  tensor = torch.tensor(
837
1024
  [
838
- # 90% of current cuda free memory bytes
839
- int(float(torch.cuda.mem_get_info()[0]) * 0.9),
1025
+ # proportion of current cuda free memory bytes
1026
+ int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
840
1027
  # we use negative value to reuse allreduce min operation
841
1028
  # for getting the max value of zmq_addr_counter in all ranks
842
1029
  -self._zmq_addr_counter,
@@ -948,71 +1135,6 @@ class ParameterServer:
948
1135
  backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
949
1136
  )
950
1137
 
951
- def _update_per_bucket_p2p(
952
- self,
953
- checkpoint_name: str,
954
- req_func: Callable[[list[tuple[str, str]]], None],
955
- ranks: list[int],
956
- ):
957
- assert self._p2p_store is not None, "p2p store is not initialized"
958
- assert ranks, "ranks should be set"
959
- if len(self._current_global_parameter_metas) == 0:
960
- raise ValueError("parameter metas is empty")
961
- assert dist.is_initialized(), (
962
- "process group is not initialized when update model per bucket p2p"
963
- )
964
-
965
- need_update = self._rank in ranks
966
- logger.info(
967
- f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
968
- f"gpu_count {self._gpu_count}, world_size {self._world_size}"
969
- )
970
-
971
- if not need_update:
972
- return
973
-
974
- # first execute a barrier to avoid subsequent cuda oom
975
- dist.barrier()
976
-
977
- bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
978
- buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
979
- ipc_buffer_name = "__ipc_buffer___"
980
- self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
981
- logger.info(
982
- f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
983
- )
984
- handle = reduce_tensor(buffer)
985
-
986
- buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
987
- socket, socket_paths = self._bind_zmq_socket()
988
- req_thread = threading.Thread(
989
- target=req_func,
990
- args=(socket_paths,),
991
- )
992
- req_thread.start()
993
- socket.send_pyobj(handle)
994
- for gidx, (owner_rank, bucket) in enumerate(buckets):
995
- self._logger_rank0(
996
- f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
997
- )
998
- _buffer = buffer[gidx % 2 * bucket_size : gidx % 2 * bucket_size + bucket.size]
999
- if dist.get_rank() == 0:
1000
- self._copy_to_buffer(checkpoint_name, bucket, _buffer, owner_rank)
1001
- # broadcast the collected data to all ranks
1002
- dist.broadcast(_buffer, src=0)
1003
- socket.recv()
1004
- dist.barrier()
1005
- socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
1006
-
1007
- socket.recv()
1008
- socket.send_pyobj(None)
1009
- socket.recv()
1010
- req_thread.join()
1011
- dist.barrier()
1012
- socket.close()
1013
- self._p2p_store.unregister_named_tensors([ipc_buffer_name])
1014
- torch.cuda.empty_cache()
1015
-
1016
1138
  def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
1017
1139
  addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
1018
1140
  metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
@@ -1042,38 +1164,63 @@ class ParameterServer:
1042
1164
  self,
1043
1165
  checkpoint_name: str,
1044
1166
  req_func: Callable[[list[tuple[str, str]]], None],
1167
+ ranks: list[int] | None = None,
1045
1168
  ):
1046
- if len(self._current_global_parameter_metas) == 0:
1047
- raise ValueError("parameter metas is empty")
1048
-
1169
+ assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
1049
1170
  assert dist.is_initialized(), "process group is not initialized"
1171
+ # if both ranks is None or [], it will use fully broadcast to update to all ranks
1172
+ if not ranks:
1173
+ logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
1174
+ # if ranks is set, it will use p2p to update to the ranks
1175
+ else:
1176
+ assert self._p2p_store is not None, "p2p store is not initialized"
1177
+ assert ranks, "ranks should be set"
1178
+
1179
+ need_update = self._rank in ranks
1180
+ logger.info(
1181
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
1182
+ f"gpu_count {self._gpu_count}, world_size {self._world_size}"
1183
+ )
1050
1184
 
1051
- logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
1185
+ if not need_update:
1186
+ return
1187
+ # first execute a barrier to avoid subsequent cuda oom
1188
+ dist.barrier()
1052
1189
 
1053
1190
  bucket_size, disable_h2d_buffer = self._detect_bucket_size()
1054
- buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
1191
+ buckets = _gen_h2d_buckets(
1192
+ self._current_global_parameter_metas,
1193
+ bucket_size,
1194
+ self._local_rdma_devices,
1195
+ self._remote_rdma_devices,
1196
+ ranks,
1197
+ )
1055
1198
 
1056
1199
  h2d_buffer: torch.Tensor | None = (
1057
1200
  None
1058
1201
  if disable_h2d_buffer
1059
1202
  else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
1060
1203
  )
1061
-
1062
- owner_rank_buckets: list[H2DBucket] = []
1063
- for owner_rank, bucket in buckets:
1064
- if owner_rank != self._rank:
1204
+ # p2p store need to register h2d_buffer to let other ranks read
1205
+ if ranks:
1206
+ h2d_buffer_name = "__h2d_buffer__"
1207
+ if h2d_buffer is not None and self._p2p_store is not None:
1208
+ self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
1209
+ receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
1210
+ for receiver_rank, owner_rank, bucket in buckets:
1211
+ if receiver_rank != self._rank:
1065
1212
  continue
1066
- owner_rank_buckets.append(bucket)
1213
+ receiver_rank_buckets.append((owner_rank, bucket))
1067
1214
 
1068
1215
  buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
1069
1216
  handle = reduce_tensor(buffer)
1070
1217
 
1071
- buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list)
1218
+ buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
1072
1219
  max_len = 0
1073
- for owner_rank, bucket in buckets:
1074
- buckets_by_owner_rank[owner_rank].append(bucket)
1075
- if len(buckets_by_owner_rank[owner_rank]) > max_len:
1076
- max_len = len(buckets_by_owner_rank[owner_rank])
1220
+ for receiver_rank, _, bucket in buckets:
1221
+ buckets_by_receiver_rank[receiver_rank].append(bucket)
1222
+ if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
1223
+ max_len = len(buckets_by_receiver_rank[receiver_rank])
1077
1224
 
1078
1225
  socket, socket_paths = self._bind_zmq_socket()
1079
1226
  req_thread = threading.Thread(
@@ -1084,11 +1231,16 @@ class ParameterServer:
1084
1231
  socket.send_pyobj(handle)
1085
1232
 
1086
1233
  gidx = 0
1234
+ bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
1087
1235
  for i in range(max_len):
1088
- if i < len(owner_rank_buckets) and not disable_h2d_buffer:
1089
- self._copy_to_buffer(checkpoint_name, owner_rank_buckets[i], h2d_buffer)
1090
-
1091
- for owner_rank, _buckets in buckets_by_owner_rank.items():
1236
+ if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
1237
+ self._copy_to_buffer(
1238
+ checkpoint_name,
1239
+ receiver_rank_buckets[i][1],
1240
+ h2d_buffer,
1241
+ receiver_rank_buckets[i][0] if ranks else None,
1242
+ )
1243
+ for receiver_rank, _buckets in buckets_by_receiver_rank.items():
1092
1244
  if i >= len(_buckets):
1093
1245
  continue
1094
1246
  bucket = _buckets[i]
@@ -1097,18 +1249,19 @@ class ParameterServer:
1097
1249
  torch.cuda.memory_reserved() / 1024 / 1024,
1098
1250
  )
1099
1251
  self._logger_rank0(
1100
- f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1252
+ f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
1101
1253
  f"Current CUDA allocated {alloc:.2f} MB, "
1102
1254
  f"reserved {reserved:.2f} MB."
1103
1255
  )
1104
1256
  start = gidx % 2 * bucket_size
1105
1257
  buffer_b: torch.Tensor = buffer[start : start + bucket.size]
1106
- if owner_rank == self._rank:
1258
+ if receiver_rank == self._rank:
1107
1259
  if disable_h2d_buffer:
1108
1260
  self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
1109
1261
  else:
1110
1262
  buffer_b.data.copy_(h2d_buffer[: bucket.size])
1111
- dist.broadcast(buffer_b, src=owner_rank)
1263
+ brank = bcast_rank_map[receiver_rank]
1264
+ dist.broadcast(buffer_b, src=brank)
1112
1265
  socket.recv()
1113
1266
  dist.barrier()
1114
1267
  socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
@@ -1120,6 +1273,9 @@ class ParameterServer:
1120
1273
  req_thread.join()
1121
1274
  dist.barrier()
1122
1275
  socket.close()
1276
+ if ranks and h2d_buffer is not None:
1277
+ self._p2p_store.unregister_named_tensors([h2d_buffer_name])
1278
+
1123
1279
  torch.cuda.empty_cache()
1124
1280
 
1125
1281
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.2
3
+ Version: 0.2.0
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -38,8 +38,8 @@ updating our [Kimi-K2](https://github.com/MoonshotAI/Kimi-K2) model (1 Trillion
38
38
 
39
39
  The core weight update logic is in the `ParameterServer` class, a service colocated with inference engines. It provides two implementations of weight update: Broadcast and P2P.
40
40
 
41
- - **Broadcast**: Used when a large number of inference instances need to update weights in synchronous. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket`.
42
- - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
41
+ - **Broadcast**: Used when a large number of inference instances need to update weights synchronously. This is the fastest implementation and should be used as the default update method. See `_update_per_bucket` with `ranks` left as `None` or `[]`.
42
+ - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket` with `ranks` specified.
43
43
 
44
44
  ### Optimized Weight Broadcast
45
45
  In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
@@ -60,16 +60,22 @@ It then executes the transfer, where it controls the inference engine through a
60
60
 
61
61
  Pipelining naturally requires more GPU memory. When memory is insufficient, checkpoint-engine will fall back to serial execution.
62
62
 
63
+ ### Optimized P2P Bucket Assignment
64
+ In the *P2P* implementation, checkpoint-engine needs to send weights from existing instances to new instances.
65
+ To minimize the overall transfer time, checkpoint-engine optimizes the bucket assignment for each sender-receiver pair.
66
+ The optimization goal is to make full use of the available network bandwidth for each sender and receiver.
67
+ See [issue #25](https://github.com/MoonshotAI/checkpoint-engine/issues/25) for details.
68
+
63
69
  ## Benchmark
64
70
 
65
71
  | Model | Device Info | GatherMetas | Update (Broadcast) | Update (P2P) |
66
72
  | :----------------------------------- | :----------- | :---------- |:-------------------| :---------------------- |
67
- | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.17s | 3.94s (1.42GiB) | 8.83s (4.77GiB) |
68
- | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.46s | 6.75s (2.69GiB) | 16.47s (4.05GiB) |
69
- | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.44s | 12.22s (2.38GiB) | 25.77s (3.61GiB) |
70
- | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.81s | 15.45s (2.93GiB) | 36.24s (4.46GiB) |
71
- | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 1.40s | 13.88s (2.54GiB) | 33.30s (3.86 GiB) |
72
- | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.88s | 21.50s (2.99GiB) | 34.49s (4.57 GiB) |
73
+ | GLM-4.5-Air (BF16) | 8xH800 TP8 | 0.12s | 3.47s (3.02GiB) | 4.12s (3.02GiB) |
74
+ | Qwen3-235B-A22B-Instruct-2507 (BF16) | 8xH800 TP8 | 0.33s | 6.22s (2.67GiB) | 7.10s (2.68GiB) |
75
+ | DeepSeek-V3.1 (FP8) | 16xH20 TP16 | 1.17s | 10.19s (5.39GiB) | 11.80s (5.41GiB) |
76
+ | Kimi-K2-Instruct (FP8) | 16xH20 TP16 | 1.33s | 14.36s (5.89GiB) | 17.49s (5.91GiB) |
77
+ | DeepSeek-V3.1 (FP8) | 256xH20 TP16 | 0.80s | 11.33s (8.00GiB) | 11.81s (8.00GiB) |
78
+ | Kimi-K2-Instruct (FP8) | 256xH20 TP16 | 1.22s | 16.04s (8.00GiB) | 16.75s (8.00GiB) |
73
79
 
74
80
  All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
75
81
 
@@ -77,6 +83,7 @@ All results above are tested by [`examples/update.py`](./examples/update.py) and
77
83
  * Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
78
84
  * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
79
85
  * The P2P times were measured while updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
86
+ * We bind each GPU to its corresponding NUMA node to ensure stable H2D transfer speeds.
80
87
 
81
88
  ## Installation
82
89
 
@@ -92,7 +99,7 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
92
99
  pip install 'checkpoint-engine[p2p]'
93
100
  ```
94
101
 
95
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
102
+ If the `NCCL_IB_HCA` env variable is set, checkpoint-engine will use it to auto-select network devices for different ranks. The accepted patterns are described in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If it is not set, checkpoint-engine will read all RDMA devices and try to divide them among the ranks.
96
103
 
97
104
  ## Getting Started
98
105
 
@@ -165,11 +172,11 @@ Run a simple correctness test for checkpoint_engine
165
172
  torchrun --nproc-per-node 8 tests/test_update.py
166
173
  ```
167
174
 
175
+ The remaining unit tests can be run with pytest, e.g. `pytest -v -m "not gpu" tests/` for the CPU-only tests.
168
176
  ## Limitations and Future Work
169
177
 
170
178
  - This project is currently only tested with vLLM. But it is easy to integrate with other frameworks like SGLang.
171
179
  - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
172
- - The P2P update method is currently not the optimal implementation since it will receive data only in rank 0 and broadcast to others synchronizely. This is a potential optimization in the future.
173
180
 
174
181
  ## Acknowledgments
175
182
 
@@ -3,6 +3,7 @@
3
3
  LICENCE
4
4
  README.md
5
5
  pyproject.toml
6
+ .github/workflows/cpu-tests.yml
6
7
  .github/workflows/pre-commit.yaml
7
8
  .github/workflows/python-publish.yml
8
9
  checkpoint_engine/__init__.py
@@ -19,4 +20,6 @@ figures/checkpoint-engine.png
19
20
  figures/overlap-update-and-copy.png
20
21
  figures/pipeline.png
21
22
  patches/vllm_fp8.patch
23
+ tests/test_assign_receiver_ranks.py
24
+ tests/test_rdma_parser.py
22
25
  tests/test_update.py
@@ -158,3 +158,8 @@ inline-quotes = "double"
158
158
 
159
159
  [tool.ruff.lint.flake8-tidy-imports]
160
160
  ban-relative-imports = "all"
161
+
162
+ [tool.pytest.ini_options]
163
+ markers = [
164
+ "gpu: marks tests as GPU test (deselect with '-m \"not gpu\"')",
165
+ ]
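The marker pairs with the CPU-only CI job added above; a hypothetical GPU-only test would opt out of that run like this:

```python
import pytest
import torch

@pytest.mark.gpu  # deselected in CI via `pytest -m "not gpu"`
def test_needs_cuda():
    assert torch.cuda.is_available()
```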
@@ -0,0 +1,68 @@
1
+ import pytest
2
+
3
+ from checkpoint_engine.ps import _assign_receiver_ranks
4
+
5
+
6
+ @pytest.mark.parametrize(
7
+ "buckets,local_topo,remote_topo,expected_results",
8
+ [
9
+ (
10
+ [(i % 8, f"bucket{i}") for i in range(80)],
11
+ {f"rdma{i}": {i} for i in range(8)},
12
+ {f"rdma{i}": {i} for i in range(8)},
13
+ [(i % 8, i % 8, f"bucket{i}") for i in range(80)],
14
+ ),
15
+ (
16
+ [(i % 8, f"bucket{i}") for i in range(80)],
17
+ {f"rdma{i}": {i} for i in range(8)},
18
+ {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
19
+ [((i // 2 % 4), i % 8, f"bucket{i}") for i in range(80)],
20
+ ),
21
+ (
22
+ [(i % 8, f"bucket{i}") for i in range(80)],
23
+ {f"rdma{i}": {i, i + 1, i + 2, i + 3} for i in range(0, 8, 4)},
24
+ {f"rdma{i}": {i} for i in range(8)},
25
+ [((i % 2) * 4, i % 8, f"bucket{i}") for i in range(80)],
26
+ ),
27
+ (
28
+ [(i % 8, f"bucket{i}") for i in range(13)],
29
+ {f"rdma{i}": {i} for i in range(8)},
30
+ {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
31
+ [((i // 2 % 4), i % 8, f"bucket{i}") for i in range(13)],
32
+ ),
33
+ (
34
+ [(i % 8, f"bucket{i}") for i in range(13)],
35
+ {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
36
+ {f"rdma{i}": {i} for i in range(8)},
37
+ [((i % 4) * 2, i % 8, f"bucket{i}") for i in range(13)],
38
+ ),
39
+ (
40
+ [(i % 8, f"bucket{i}") for i in range(13)],
41
+ {f"rdma{i}": {i} for i in range(3)},
42
+ {f"rdma{i}": {i, i + 1} for i in range(0, 8, 2)},
43
+ [
44
+ (0, 0, "bucket0"),
45
+ (1, 1, "bucket1"),
46
+ (1, 2, "bucket2"),
47
+ (2, 3, "bucket3"),
48
+ (2, 4, "bucket4"),
49
+ (0, 5, "bucket5"),
50
+ (0, 6, "bucket6"),
51
+ (1, 7, "bucket7"),
52
+ (2, 0, "bucket8"),
53
+ (2, 1, "bucket9"),
54
+ (0, 2, "bucket10"),
55
+ (0, 3, "bucket11"),
56
+ (1, 4, "bucket12"),
57
+ ],
58
+ ),
59
+ ],
60
+ )
61
+ def test_basic_functionality(
62
+ buckets: list[tuple[int, str]],
63
+ local_topo: dict[str, int],
64
+ remote_topo: dict[str, int],
65
+ expected_results: list[tuple[int, int, str]],
66
+ ):
67
+ assert len(expected_results) == len(buckets)
68
+ assert set(expected_results) == set(_assign_receiver_ranks(buckets, local_topo, remote_topo))
@@ -0,0 +1,197 @@
1
+ import os
2
+ from unittest.mock import patch
3
+
4
+ import pytest
5
+
6
+ from checkpoint_engine.ps import (
7
+ _get_my_rdma_device,
8
+ _get_rdma_devices,
9
+ _ibv_get_device_list,
10
+ _parse_NCCL_IB_HCA,
11
+ )
12
+
13
+
14
+ @pytest.fixture
15
+ def mock_available_devices() -> list[str]:
16
+ """Provide mock available device list"""
17
+ return ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]
18
+
19
+
20
+ def test_detect_ibv_list():
21
+ """Test detection of _ibv_get_device_list function"""
22
+ # Skip this test if no real infiniband devices exist
23
+ if not os.path.exists("/sys/class/infiniband"):
24
+ pytest.skip("No infiniband devices found on system")
25
+
26
+ real_ibv_list = sorted(os.listdir("/sys/class/infiniband"))
27
+ if real_ibv_list:
28
+ devices = _ibv_get_device_list()
29
+ assert isinstance(devices, list)
30
+
31
+
32
+ def test_parse_max_hcas_limit():
33
+ """Test maximum HCA quantity limit"""
34
+ # Create mock data with more than 32 devices
35
+ many_devices = [f"device_{i}" for i in range(50)]
36
+ result = _parse_NCCL_IB_HCA("", many_devices)
37
+ assert len(result) == 32
38
+ assert result == many_devices[:32]
39
+
40
+
41
+ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]):
42
+ """Test _get_rdma_devices with no environment variables"""
43
+ with (
44
+ patch.dict(os.environ, clear=True),
45
+ patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices),
46
+ ):
47
+ devices = _get_rdma_devices()
48
+ assert sorted(devices) == sorted(mock_available_devices)
49
+
50
+
51
+ @pytest.mark.parametrize(
52
+ "input_value,expected",
53
+ [
54
+ pytest.param("", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="empty string"),
55
+ pytest.param(" \t\n ", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="whitespace"),
56
+ pytest.param("None", [], id="None string"),
57
+ pytest.param("^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret"),
58
+ pytest.param("^=", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="caret-equals"),
59
+ pytest.param("=^", [], id="equals-caret"),
60
+ pytest.param("^^", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"], id="double-caret"),
61
+ pytest.param("=", [], id="equals"),
62
+ pytest.param("==", [], id="double-equals"),
63
+ ],
64
+ )
65
+ def test_parse_basic_cases(
66
+ input_value: str, expected: list[str], mock_available_devices: list[str]
67
+ ):
68
+ """Test basic parsing cases: empty string, whitespace, None"""
69
+ result = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
70
+ assert result == expected
71
+
72
+
73
+ @pytest.mark.parametrize(
74
+ "input_value,expected",
75
+ [
76
+ # prefix
77
+ ("mlx5_0", ["mlx5_0"]),
78
+ ("mlx5", ["mlx5_0", "mlx5_1"]),
79
+ # exact match
80
+ ("=mlx5_0", ["mlx5_0"]),
81
+ ("=mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
82
+ # ignore ports, whitespace and duplicated commas
83
+ ("mlx5_0:1,mlx5_1:2", ["mlx5_0", "mlx5_1"]),
84
+ ("mlx5_0:1,mlx5_1", ["mlx5_0", "mlx5_1"]),
85
+ (" mlx5_0 , mlx5_1 ", ["mlx5_0", "mlx5_1"]),
86
+ ("mlx5_0,,mlx5_1", ["mlx5_0", "mlx5_1"]),
87
+ # exclusion
88
+ ("^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]),
89
+ ("^mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]),
90
+ ("^mlx5", ["mlx4_0", "mlx4_1"]),
91
+ ("^=mlx5_0,mlx5_1", ["mlx4_0", "mlx4_1"]),
92
+ ("^=mlx4", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
93
+ ],
94
+ )
95
+ def test_parse_various_patterns(
96
+ input_value: str, expected: list[str], mock_available_devices: list[str]
97
+ ):
98
+ """Test various parsing patterns"""
99
+ result = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
100
+ assert result == expected
101
+
102
+
103
+ @pytest.mark.parametrize(
104
+ "input_value,expected_result,expected_warning",
105
+ [
106
+ ("=mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=True."),
107
+ ("mlx5_100", [], "No RDMA device match device_name='mlx5_100' where is_exact_match=False."),
108
+ (
109
+ "^mlx5_100",
110
+ ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"],
111
+ "No RDMA device match device_name='mlx5_100' where is_exact_match=False.",
112
+ ),
113
+ ("mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=False."),
114
+ ("=mlx6", [], "No RDMA device match device_name='mlx6' where is_exact_match=True."),
115
+ ],
116
+ )
117
+ def test_parse_exact_match_with_nonexistent_device(
118
+ input_value: str,
119
+ expected_result: list[str],
120
+ expected_warning: str,
121
+ mock_available_devices: list[str],
122
+ ):
123
+ """Test exact matching with non-existent device"""
124
+ with patch("checkpoint_engine.ps.logger") as mock_logger:
125
+ result = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
126
+ assert result == expected_result
127
+ mock_logger.warning.assert_called_once_with(expected_warning)
128
+
129
+
130
+ @pytest.mark.parametrize(
131
+ "env_var_name,env_var_value,expected_devices",
132
+ [
133
+ ("PS_P2P_STORE_RDMA_DEVICES", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
134
+ ("NCCL_IB_HCA", "mlx5", ["mlx5_0", "mlx5_1"]),
135
+ ("NCCL_IB_HCA", "mlx5_0,mlx5_1", ["mlx5_0", "mlx5_1"]),
136
+ ("NCCL_IB_HCA", "^mlx5_0", ["mlx5_1", "mlx4_0", "mlx4_1"]),
137
+ ("NCCL_IB_HCA", "mlx6", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
138
+ ("NCCL_IB_HCA", "", ["mlx5_0", "mlx5_1", "mlx4_0", "mlx4_1"]),
139
+ ],
140
+ )
141
+ def test_get_rdma_devices_with_env_vars(
142
+ env_var_name: str,
143
+ env_var_value: str,
144
+ expected_devices: list[str],
145
+ mock_available_devices: list[str],
146
+ ):
147
+ """Test _get_rdma_devices with various environment variables"""
148
+ env_dict = {env_var_name: env_var_value}
149
+ with (
150
+ patch.dict(os.environ, env_dict),
151
+ patch("checkpoint_engine.ps._ibv_get_device_list", return_value=mock_available_devices),
152
+ ):
153
+ devices = _get_rdma_devices()
154
+ assert sorted(devices) == sorted(expected_devices)
155
+
156
+
157
+ @pytest.mark.parametrize(
158
+ "local_rank,gpu_count,expected_device",
159
+ [
160
+ (0, 4, "mlx5_0"),
161
+ (3, 4, "mlx5_3"),
162
+ (4, 8, "mlx5_2"),
163
+ (7, 8, "mlx5_3"),
164
+ ],
165
+ )
166
+ def test_get_my_rdma_device_basic(local_rank: int, gpu_count: int, expected_device: str):
167
+ """Test _get_my_rdma_device with basic allocation"""
168
+ # Use fewer devices to match the GPU count constraint
169
+ devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"]
170
+ device = _get_my_rdma_device(local_rank, gpu_count, devices)
171
+ assert device == expected_device
172
+
173
+
174
+ @pytest.mark.parametrize(
175
+ "local_rank,gpu_count,devices,error",
176
+ [
177
+ (
178
+ 0,
179
+ 4,
180
+ ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3", "mlx5_4"],
181
+ AssertionError,
182
+ ), # Too many devices
183
+ (
184
+ 0,
185
+ 8,
186
+ ["mlx5_0", "mlx5_1", "mlx5_2"],
187
+ AssertionError,
188
+ ), # GPU count not divisible by device count
189
+ (0, 8, [], RuntimeError), # No devices
190
+ ],
191
+ )
192
+ def test_get_my_rdma_device_invalid_config(
193
+ local_rank: int, gpu_count: int, devices: list[str], error: type
194
+ ):
195
+ """Test _get_my_rdma_device with invalid configuration"""
196
+ with pytest.raises(error):
197
+ _get_my_rdma_device(local_rank, gpu_count, devices)