checkpoint-engine 0.2.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/PKG-INFO +54 -4
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/README.md +53 -3
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine/_version.py +3 -3
- checkpoint_engine-0.2.1/checkpoint_engine/device_utils.py +86 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine/ps.py +138 -103
- checkpoint_engine-0.2.1/checkpoint_engine/worker.py +165 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine.egg-info/PKG-INFO +54 -4
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine.egg-info/SOURCES.txt +2 -0
- checkpoint_engine-0.2.1/docs/npu_start.md +91 -0
- checkpoint_engine-0.2.1/tests/test_update.py +234 -0
- checkpoint_engine-0.2.0/checkpoint_engine/worker.py +0 -109
- checkpoint_engine-0.2.0/tests/test_update.py +0 -90
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/.github/workflows/cpu-tests.yml +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/.github/workflows/pre-commit.yaml +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/.github/workflows/python-publish.yml +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/.gitignore +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/.pre-commit-config.yaml +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/LICENCE +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine/__init__.py +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine.egg-info/requires.txt +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine.egg-info/top_level.txt +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/examples/update.py +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/figures/checkpoint-engine.png +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/figures/overlap-update-and-copy.png +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/figures/pipeline.png +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/patches/vllm_fp8.patch +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/pyproject.toml +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/setup.cfg +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/tests/test_assign_receiver_ranks.py +0 -0
- {checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/tests/test_rdma_parser.py +0 -0
{checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: checkpoint-engine
-Version: 0.2.0
+Version: 0.2.1
 Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
 Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
 Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -169,13 +169,63 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM pr
 Run a simple correctness test for checkpoint_engine
 
 ```bash
-
+pytest tests/test_update.py
 ```
 
-
+`test_update.py` is only designed to run with `pytest`. Please don't run it directly with `torchrun`.
+
+The other unit tests can also be run with pytest. Only `test_update.py` requires GPUs; the rest can run on CPUs. To run only the CPU tests, use:
+
+```bash
+pytest tests/ -m "not gpu"
+```
+
+## SGLang Integration
+
+Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
+
+### Quick Start
+
+**1. Install checkpoint-engine:**
+```bash
+pip install 'checkpoint-engine[p2p]'
+```
+
+**2. Launch SGLang server:**
+```bash
+python -m sglang.launch_server \
+    --model-path $MODEL_PATH \
+    --tp 8 \
+    --load-format dummy \
+    --wait-for-initial-weights
+```
+
+**3. Run checkpoint engine:**
+```bash
+python -m sglang.srt.checkpoint_engine.update \
+    --update-method broadcast \
+    --checkpoint-path $MODEL_PATH \
+    --inference-parallel-size 8
+```
+
+### Multi-Node Setup
+
+For a 2-node setup, run the same commands on both nodes with the appropriate `--host` and distributed training parameters.
+
+### Key Options
+
+**SGLang Server:**
+- `--wait-for-initial-weights`: Wait for the checkpoint engine before becoming ready
+- `--load-format dummy`: Enable overlapping initialization tasks
+
+**Checkpoint Engine:**
+- `--update-method`: Choose `broadcast`, `p2p`, or `all`
+- `--inference-parallel-size`: Number of parallel processes
+- `--checkpoint-path`: Model checkpoint directory
+
 ## Limitations and Future Work
 
-- This project is currently
+- This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
 - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
 
 ## Acknowledgments
{checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/README.md

@@ -145,13 +145,63 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM pr
 Run a simple correctness test for checkpoint_engine
 
 ```bash
-
+pytest tests/test_update.py
 ```
 
-
+`test_update.py` is only designed to run with `pytest`. Please don't run it directly with `torchrun`.
+
+The other unit tests can also be run with pytest. Only `test_update.py` requires GPUs; the rest can run on CPUs. To run only the CPU tests, use:
+
+```bash
+pytest tests/ -m "not gpu"
+```
+
+## SGLang Integration
+
+Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
+
+### Quick Start
+
+**1. Install checkpoint-engine:**
+```bash
+pip install 'checkpoint-engine[p2p]'
+```
+
+**2. Launch SGLang server:**
+```bash
+python -m sglang.launch_server \
+    --model-path $MODEL_PATH \
+    --tp 8 \
+    --load-format dummy \
+    --wait-for-initial-weights
+```
+
+**3. Run checkpoint engine:**
+```bash
+python -m sglang.srt.checkpoint_engine.update \
+    --update-method broadcast \
+    --checkpoint-path $MODEL_PATH \
+    --inference-parallel-size 8
+```
+
+### Multi-Node Setup
+
+For a 2-node setup, run the same commands on both nodes with the appropriate `--host` and distributed training parameters.
+
+### Key Options
+
+**SGLang Server:**
+- `--wait-for-initial-weights`: Wait for the checkpoint engine before becoming ready
+- `--load-format dummy`: Enable overlapping initialization tasks
+
+**Checkpoint Engine:**
+- `--update-method`: Choose `broadcast`, `p2p`, or `all`
+- `--inference-parallel-size`: Number of parallel processes
+- `--checkpoint-path`: Model checkpoint directory
+
 ## Limitations and Future Work
 
-- This project is currently
+- This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
 - The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
 
 ## Acknowledgments
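The `pytest tests/ -m "not gpu"` command in the README above implies that GPU-dependent tests carry a `gpu` marker. A minimal sketch of how such a test could be marked so that the CPU-only run skips it (the test body and the marker registration are assumptions, not taken from the package):

```python
# Illustrative only: a GPU-only test marked so that `pytest -m "not gpu"` skips it.
# Assumes a `gpu` marker is registered in the project's pytest configuration.
import pytest
import torch


@pytest.mark.gpu
def test_requires_cuda_device():
    assert torch.cuda.is_available()
```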
{checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine/_version.py

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.0'
-__version_tuple__ = version_tuple = (0, 2, 0)
+__version__ = version = '0.2.1'
+__version_tuple__ = version_tuple = (0, 2, 1)
 
-__commit_id__ = commit_id = '
+__commit_id__ = commit_id = 'g279a908a9'
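The regenerated `_version.py` above exposes the usual setuptools-scm attributes; a quick way to confirm which build is installed (assumes the 0.2.1 release is the installed version):

```python
# Prints the installed checkpoint-engine version and version tuple.
from checkpoint_engine._version import __version__, version_tuple

print(__version__, version_tuple)  # expected with this release: 0.2.1 (0, 2, 1)
```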
checkpoint_engine-0.2.1/checkpoint_engine/device_utils.py (new file)

@@ -0,0 +1,86 @@
+import os
+import re
+import socket
+import subprocess
+from functools import lru_cache
+
+import torch
+from loguru import logger
+
+
+@lru_cache(maxsize=1)
+def get_ip() -> str:
+    try:
+        # try to get ip from network interface
+        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+            s.connect(("8.8.8.8", 80))
+            return s.getsockname()[0]
+    except Exception as e:  # noqa: BLE001
+        # fallback to get ip from hostname
+        logger.warning(
+            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
+        )
+        return socket.gethostbyname(socket.gethostname())
+
+
+def npu_generate_uuid() -> str:
+    str_pid = str(os.getpid())
+    npu_num = 8
+    try:
+        for npu_id in range(npu_num):
+            cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
+            result = subprocess.run(cmd, check=True, capture_output=True, text=True)  # noqa: S603
+            str_result = str(result.stdout)
+            if str_pid in str_result:
+                # In A3 server, one NPU has two chips.
+                match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
+                chip_count = int(match_chip_count.group(1))
+                search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
+                match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
+                chip_id = int(match_chip_id.group(1))
+                return f"{get_ip()}-{npu_id * chip_count + chip_id}"
+        raise ValueError("The current process is not running on the npu device")
+    except subprocess.CalledProcessError as e:
+        raise ValueError("The current process is not running on the npu device") from e
+
+
+class DeviceManager:
+    def __init__(self):
+        self.device_type = self._detect_device_type()
+        self._setup_device_module()
+
+    def _is_torch_npu_available(self) -> bool:
+        try:
+            if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
+                return torch.npu.is_available()
+            else:
+                return False
+        except ImportError:
+            return False
+
+    def _detect_device_type(self) -> str:
+        if self._is_torch_npu_available():
+            return "npu"
+        elif torch.cuda.is_available():
+            return "cuda"
+        else:
+            raise TypeError("The current device type is not supported")
+
+    def _setup_device_module(self):
+        if self.device_type == "npu":
+            import torch_npu
+
+            self.device_module = torch_npu.npu
+        elif self.device_type == "cuda":
+            self.device_module = torch.cuda
+        else:
+            raise TypeError("The current device type is not supported")
+
+    @property
+    def backend(self) -> str:
+        if self.device_type == "npu":
+            return "hccl"
+        elif self.device_type == "cuda":
+            return "nccl"
+        else:
+            raise TypeError("The current device type is not supported")
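The new `DeviceManager` is the abstraction the rest of this release uses to stay device-agnostic between CUDA and Ascend NPU. A minimal usage sketch (assumes checkpoint-engine 0.2.1 is installed and a CUDA or NPU device is present):

```python
# Minimal sketch of the DeviceManager added above.
from checkpoint_engine.device_utils import DeviceManager, get_ip

dm = DeviceManager()               # detects "npu" or "cuda", otherwise raises TypeError
print(dm.device_type, dm.backend)  # e.g. "cuda nccl" on GPU hosts, "npu hccl" on Ascend hosts
dm.device_module.set_device(0)     # same call path ps.py now uses before initializing P2PStore
print(get_ip())                    # cached host IP, used as the P2P handshake address
```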
{checkpoint_engine-0.2.0 → checkpoint_engine-0.2.1}/checkpoint_engine/ps.py

@@ -4,13 +4,11 @@ import ctypes
 import os
 import pickle
 import random
-import socket
 import threading
 import time
 from collections import defaultdict
 from collections.abc import Callable
 from datetime import timedelta
-from functools import lru_cache
 from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
 
 import httpx
@@ -23,6 +21,8 @@ from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
 from safetensors.torch import safe_open
 from torch.multiprocessing.reductions import reduce_tensor
 
+from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
+
 
 if TYPE_CHECKING:
     from typing import TypeVar
@@ -254,28 +254,16 @@ def _concat_tp_weights(
     return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
 
 
-def _get_physical_gpu_id(device_index: int | None = None) -> str:
+def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
     try:
-
+        if device_manager.device_type == "npu":
+            return f"NPU-{npu_generate_uuid()}"
+        else:
+            return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
     except AssertionError as e:
         raise ValueError(f"fail to get physical gpu id {device_index}") from e
 
 
-@lru_cache(maxsize=1)
-def _get_ip() -> str:
-    try:
-        # try to get ip from network interface
-        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
-            s.connect(("8.8.8.8", 80))
-            return s.getsockname()[0]
-    except Exception as e:  # noqa: BLE001
-        # fallback to get ip from hostname
-        logger.warning(
-            f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
-        )
-        return socket.gethostbyname(socket.gethostname())
-
-
 def _ibv_get_device_list() -> list[str]:
     lib = ctypes.CDLL("libibverbs.so.1")
     lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
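On the CUDA path, `_get_physical_gpu_id` simply formats the device UUID reported by the torch device module. A small standalone check of that branch (requires a CUDA build of PyTorch and at least one visible GPU):

```python
# Reproduces the CUDA branch of _get_physical_gpu_id outside the class.
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU-{props.uuid!s}")  # same "GPU-<uuid>" formatting used above
```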
@@ -317,13 +305,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
     """
     if not devices:
         raise RuntimeError("no rdma devices found")
-
-
-
-
-
-
-
+    try:
+        assert len(devices) <= gpu_count, (
+            f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
+        )
+        assert gpu_count % len(devices) == 0, (
+            f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
+        )
+        return devices[local_rank // (gpu_count // len(devices))]
+    except AssertionError:
+        logger.error(
+            "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
+            "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
+            "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
+        )
+        raise
 
 
 def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
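The mapping added above assigns consecutive local ranks to the same RDMA NIC by integer division. A standalone sketch of the same arithmetic with illustrative device names (real names come from libibverbs or `NCCL_IB_HCA`):

```python
# Standalone sketch of the rank -> RDMA device mapping used in _get_my_rdma_device.
def my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    assert devices and len(devices) <= gpu_count and gpu_count % len(devices) == 0
    return devices[local_rank // (gpu_count // len(devices))]


devices = ["mlx5_0", "mlx5_1", "mlx5_2", "mlx5_3"]  # illustrative NIC names
# 8 GPUs sharing 4 NICs: ranks 0-1 -> mlx5_0, 2-3 -> mlx5_1, 4-5 -> mlx5_2, 6-7 -> mlx5_3
print([my_rdma_device(r, 8, devices) for r in range(8)])
```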
@@ -677,20 +673,29 @@ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, i
 
 
 class P2PStore:
-    def __init__(self):
+    def __init__(self, device_manager: DeviceManager):
         from mooncake.engine import TransferEngine
 
         self.rank = int(os.getenv("RANK"))
-        gpu_count =
+        gpu_count = device_manager.device_module.device_count()
         local_rank = self.rank % gpu_count
-
-
+        device_type = device_manager.device_type
+        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
+            self.device = ""
+        else:
+            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
+        self.ip = get_ip()
 
         # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
         retry_count = 8
         for i in range(retry_count):
             self.engine = TransferEngine()
-            ret = self.engine.initialize(
+            ret = self.engine.initialize(
+                self.ip,
+                "P2PHANDSHAKE",
+                "ascend_direct" if device_type == "npu" else "rdma",
+                self.device,
+            )
             if ret == 0:
                 break
             # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
@@ -757,11 +762,12 @@ class ParameterServer:
         Args:
             auto_pg: Whether to automatically initialize the process group.
                 Notice that if auto_pg is True, will destroy the process group after update.
-            mem_fraction: The proportion (as a fraction) of the current free
+            mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
         """
         self._rank = rank or int(os.environ.get("RANK", None))
         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
-        self.
+        self.device_manager = DeviceManager()
+        self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
         self._local_rank = self._rank % self._gpu_count
         self._auto_pg = auto_pg
         self._all_hosts = []
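For context, the constructor documented above is normally driven by environment variables from the launcher. A rough usage sketch (the argument names come from this diff; the launch assumptions are mine):

```python
# Rough sketch: constructing the parameter server under torchrun-style env vars.
# RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are assumed to be set by the launcher.
from checkpoint_engine.ps import ParameterServer

ps = ParameterServer(auto_pg=True, mem_fraction=0.9)
```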
@@ -775,7 +781,7 @@ class ParameterServer:
         assert (
             self._gpu_count is not None
             and self._gpu_count > 0
-            and self._gpu_count <=
+            and self._gpu_count <= self.device_manager.device_module.device_count()
         ), self._gpu_count
         assert (
             self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -787,15 +793,16 @@ class ParameterServer:
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
+        # NPU transfer engine initialization requires prior set_device.
+        device_index = self._local_rank
+        self.device_manager.device_module.set_device(device_index)
         try:
-            self._p2p_store = P2PStore()
+            self._p2p_store = P2PStore(self.device_manager)
         except ImportError as e:
             logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
             self._p2p_store = None
 
-
-        torch.cuda.set_device(device_index)
-        self._device_uuid = _get_physical_gpu_id(device_index)
+        self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
     def _logger_rank0(self, msg: str):
@@ -885,13 +892,15 @@ class ParameterServer:
                 for x in self._memory_pool.get(checkpoint_name, [])
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
-            host_ip=
+            host_ip=get_ip(),
             device_uuid=self._device_uuid,
             rdma_device=self._rdma_device or "",
         )
 
         dist.all_gather_object(metas_lst, metas)
 
+        self._current_global_parameter_metas = {}
+
         num_parameters = 0
         all_hosts: list[str] = []
         global_device_uuids: list[str] = []
@@ -948,7 +957,7 @@ class ParameterServer:
             is_master=self._rank == 0,
         )
         dist.init_process_group(
-            backend=
+            backend=self.device_manager.backend,
             world_size=self._world_size,
             rank=self._rank,
             timeout=timeout,
@@ -991,21 +1000,22 @@ class ParameterServer:
             if self._rank not in ranks:
                 return
             self._update_per_bucket(checkpoint_name, req_func, ranks)
-            if self._auto_pg:
-                dist.destroy_process_group()
-
-            torch.cuda.empty_cache()
 
-            logger.info(
-                f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
-                f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
-                f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
-            )
         except Exception as e:
             logger.exception(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
+        finally:
+            if self._auto_pg and (not ranks or self._rank in ranks):
+                dist.destroy_process_group()
+
+            self.device_manager.device_module.empty_cache()
+            logger.info(
+                f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
+                f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
+                f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
+            )
 
     def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
         def zmq_handle(device_uuid: str) -> str:
@@ -1022,14 +1032,16 @@ class ParameterServer:
         # auto detect bucket size
         tensor = torch.tensor(
             [
-                # proportion of current
-                int(
+                # proportion of current device free memory bytes
+                int(
+                    float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
+                ),
                 # we use negative value to reuse allreduce min operation
                 # for getting the max value of zmq_addr_counter in all ranks
                 -self._zmq_addr_counter,
             ],
             dtype=torch.int64,
-            device=
+            device=self.device_manager.device_type,
         )
         dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
         tensor = tensor.cpu()
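The comment above relies on a small collective trick: negating a value lets one `ReduceOp.MIN` all-reduce return both a global minimum (free memory) and, via the negated entry, a global maximum (the zmq address counter). A self-contained sketch with a single-process `gloo` group:

```python
# Sketch of packing a global MIN and a global MAX into one all-reduce.
import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29555")
dist.init_process_group("gloo", rank=0, world_size=1)

free_bytes = 1_000_000  # per-rank free memory (illustrative)
counter = 3             # per-rank counter whose global max we want
t = torch.tensor([free_bytes, -counter], dtype=torch.int64)
dist.all_reduce(t, op=dist.ReduceOp.MIN)  # min(free_bytes) and min(-counter) == -max(counter)
print(int(t[0]), int(-t[1]))              # -> global min free bytes, global max counter
dist.destroy_process_group()
```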
@@ -1092,7 +1104,7 @@ class ParameterServer:
         assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
         if owner_rank is not None:
             self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
-
+        self.device_manager.device_module.synchronize()
 
     def init_process_group_for_ranks(
         self,
@@ -1132,7 +1144,11 @@ class ParameterServer:
             master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
         )
         dist.init_process_group(
-            backend=
+            backend=self.device_manager.backend,
+            world_size=len(ranks),
+            rank=rank,
+            timeout=timeout,
+            store=store,
         )
 
     def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
@@ -1184,7 +1200,7 @@ class ParameterServer:
 
         if not need_update:
             return
-        # first execute a barrier to avoid subsequent
+        # first execute a barrier to avoid subsequent device oom
         dist.barrier()
 
         bucket_size, disable_h2d_buffer = self._detect_bucket_size()
@@ -1199,7 +1215,7 @@ class ParameterServer:
         h2d_buffer: torch.Tensor | None = (
             None
             if disable_h2d_buffer
-            else torch.empty(bucket_size, dtype=torch.uint8, device=
+            else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
         )
         # p2p store need to register h2d_buffer to let other ranks read
         if ranks:
@@ -1212,7 +1228,9 @@ class ParameterServer:
                 continue
             receiver_rank_buckets.append((owner_rank, bucket))
 
-        buffer = torch.empty(
+        buffer = torch.empty(
+            bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
+        )
         handle = reduce_tensor(buffer)
 
         buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -1231,52 +1249,66 @@ class ParameterServer:
             socket.send_pyobj(handle)
 
         gidx = 0
+        ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
         bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
- [45 removed lines (old 1235-1279) are not rendered in this diff view]
+        try:
+            for i in range(max_len):
+                if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
+                    self._copy_to_buffer(
+                        checkpoint_name,
+                        receiver_rank_buckets[i][1],
+                        h2d_buffer,
+                        receiver_rank_buckets[i][0] if ranks else None,
+                    )
+                for receiver_rank, _buckets in buckets_by_receiver_rank.items():
+                    if i >= len(_buckets):
+                        continue
+                    bucket = _buckets[i]
+                    alloc, reserved = (
+                        self.device_manager.device_module.memory_allocated() / 1024 / 1024,
+                        self.device_manager.device_module.memory_reserved() / 1024 / 1024,
+                    )
+                    self._logger_rank0(
+                        f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
+                        f"Current device allocated {alloc:.2f} MB, "
+                        f"reserved {reserved:.2f} MB."
+                    )
+                    start = gidx % 2 * bucket_size
+                    buffer_b: torch.Tensor = buffer[start : start + bucket.size]
+                    if receiver_rank == self._rank:
+                        if disable_h2d_buffer:
+                            self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
+                        else:
+                            buffer_b.data.copy_(h2d_buffer[: bucket.size])
+                    brank = bcast_rank_map[receiver_rank]
+                    dist.broadcast(buffer_b, src=brank)
+                    resp = socket.recv()
+                    if resp != b"":
+                        exception_obj = pickle.loads(resp)
+                        logger.error(
+                            f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
+                        )
+                        ret_code.fill_(1)
+                    dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
+                    self.device_manager.device_module.synchronize()
+                    if ret_code.item() != 0:
+                        # quit early if any rank failed
+                        socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
+                        raise RuntimeError("Failed to update weights due to remote errors")
+                    socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
+                    gidx += 1
+
+            socket.recv()
+            socket.send_pyobj(None)
+            socket.recv()
+        finally:
+            req_thread.join()
+            dist.barrier()
+            socket.close()
+            if ranks and h2d_buffer is not None:
+                self._p2p_store.unregister_named_tensors([h2d_buffer_name])
+
+        self.device_manager.device_module.empty_cache()
 
 
 def _init_api(ps: ParameterServer) -> Any:
@@ -1294,6 +1326,7 @@ def _init_api(ps: ParameterServer) -> Any:
         update_url: str | None = None
         inference_group_ranks: list[int] = []
         timeout: float = 300.0
+        uds: str | None = None
 
     def wrap_exception(func: Callable[[], None]) -> Response:
         try:
@@ -1326,7 +1359,9 @@ def _init_api(ps: ParameterServer) -> Any:
             return
         if req.inference_group_ranks:
             socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
-        request_inference_to_update(
+        request_inference_to_update(
+            req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
+        )
 
         return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
 
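The new `uds` field threaded into `request_inference_to_update` above suggests the inference endpoint can also be reached over a Unix domain socket. A sketch of what such a client call can look like with httpx (the endpoint path, socket path, and payload below are placeholders, not the package's actual API):

```python
# Sketch only: an httpx request over a Unix domain socket, the transport the
# new `uds` option implies. All names below are hypothetical placeholders.
import httpx

transport = httpx.HTTPTransport(uds="/tmp/inference_server.sock")  # hypothetical socket path
with httpx.Client(transport=transport, base_url="http://localhost") as client:
    resp = client.post("/update_weights", json={"socket_paths": {}}, timeout=300.0)
    resp.raise_for_status()
```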