checkpoint-engine 0.3.0rc0__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/PKG-INFO +1 -1
- checkpoint_engine-0.3.1/checkpoint_engine/__init__.py +40 -0
- checkpoint_engine-0.3.1/checkpoint_engine/__main__.py +28 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine/_version.py +3 -3
- checkpoint_engine-0.3.1/checkpoint_engine/api.py +95 -0
- checkpoint_engine-0.3.1/checkpoint_engine/data_types.py +111 -0
- checkpoint_engine-0.3.1/checkpoint_engine/p2p_store.py +210 -0
- checkpoint_engine-0.3.1/checkpoint_engine/pin_memory.py +390 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine/ps.py +85 -798
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine/worker.py +18 -9
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine.egg-info/PKG-INFO +1 -1
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine.egg-info/SOURCES.txt +7 -1
- checkpoint_engine-0.3.1/tests/test_inplace_unpin.py +81 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/tests/test_rdma_parser.py +8 -4
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/tests/test_update.py +1 -2
- checkpoint_engine-0.3.0rc0/checkpoint_engine/__init__.py +0 -4
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/.github/workflows/cpu-tests.yml +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/.github/workflows/pre-commit.yaml +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/.github/workflows/python-publish.yml +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/.gitignore +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/.pre-commit-config.yaml +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/LICENCE +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/README.md +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine/device_utils.py +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine.egg-info/requires.txt +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/checkpoint_engine.egg-info/top_level.txt +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/docs/npu_start.md +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/examples/update.py +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/figures/checkpoint-engine.png +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/figures/overlap-update-and-copy.png +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/figures/pipeline.png +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/patches/vllm_fp8.patch +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/pyproject.toml +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/setup.cfg +0 -0
- {checkpoint_engine-0.3.0rc0 → checkpoint_engine-0.3.1}/tests/test_assign_receiver_ranks.py +0 -0
- /checkpoint_engine-0.3.0rc0/tests/test_pin_memory.py → /checkpoint_engine-0.3.1/tests/test_reuse_pin_memory.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: checkpoint-engine
|
|
3
|
-
Version: 0.3.0rc0
|
|
3
|
+
Version: 0.3.1
|
|
4
4
|
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
5
|
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
6
|
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
try:
|
|
2
|
+
from ._version import __version__
|
|
3
|
+
except ImportError:
|
|
4
|
+
__version__ = "dev"
|
|
5
|
+
|
|
6
|
+
from .api import request_inference_to_update
|
|
7
|
+
from .data_types import (
|
|
8
|
+
BucketRange,
|
|
9
|
+
DataToGather,
|
|
10
|
+
H2DBucket,
|
|
11
|
+
MemoryBuffer,
|
|
12
|
+
MemoryBufferMetaList,
|
|
13
|
+
MemoryBufferMetas,
|
|
14
|
+
ParameterMeta,
|
|
15
|
+
)
|
|
16
|
+
from .device_utils import DeviceManager, get_ip, npu_generate_uuid
|
|
17
|
+
from .p2p_store import P2PStore
|
|
18
|
+
from .ps import ParameterServer
|
|
19
|
+
from .worker import FlattenedTensorMetadata, VllmColocateWorkerExtension, update_weights_from_ipc
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Names re-exported as the package's public API (kept in ASCII-sorted order).
__all__ = [
    "BucketRange", "DataToGather", "DeviceManager", "FlattenedTensorMetadata",
    "H2DBucket", "MemoryBuffer", "MemoryBufferMetaList", "MemoryBufferMetas",
    "P2PStore", "ParameterMeta", "ParameterServer", "VllmColocateWorkerExtension",
    "__version__", "get_ip", "npu_generate_uuid", "request_inference_to_update",
    "update_weights_from_ipc",
]
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from loguru import logger
|
|
5
|
+
|
|
6
|
+
from checkpoint_engine.api import _init_api
|
|
7
|
+
from checkpoint_engine.ps import ParameterServer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@logger.catch(reraise=True)
def run_from_cli():
    """Serve the parameter-server HTTP API on a Unix domain socket.

    Parses ``--uds`` from the command line, builds a ``ParameterServer`` with
    ``auto_pg=True`` and runs the FastAPI app returned by ``_init_api`` with
    uvicorn on that socket.

    Raises:
        SystemExit: (via argparse) when ``--uds`` is missing or empty.
    """
    import uvicorn

    parser = argparse.ArgumentParser(description="Parameter Server")
    parser.add_argument("--uds", type=str)

    args = parser.parse_args()
    logger.info(
        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
    )

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let uvicorn start without a socket path.
    if not args.uds:
        parser.error("--uds must be a non-empty Unix domain socket path")

    ps = ParameterServer(auto_pg=True)
    uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)


if __name__ == "__main__":
    run_from_cli()
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.3.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 3,
|
|
31
|
+
__version__ = version = '0.3.1'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 3, 1)
|
|
33
33
|
|
|
34
|
-
__commit_id__ = commit_id = '
|
|
34
|
+
__commit_id__ = commit_id = 'g09c543af4'
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import fastapi
|
|
5
|
+
import httpx
|
|
6
|
+
from fastapi import Request
|
|
7
|
+
from fastapi.responses import JSONResponse, Response
|
|
8
|
+
from loguru import logger
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from checkpoint_engine.ps import ParameterServer
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def request_inference_to_update(
    url: str,
    socket_paths: dict[str, str],
    timeout: float = 300.0,
    uds: str | None = None,
):
    """Send an inference update request to inference server via HTTP or Unix socket.

    The request body is ``{"method": "update_weights_from_ipc",
    "args": [socket_paths], "timeout": timeout}``.

    Args:
        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
        socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
        uds (str, optional): Path to a Unix domain socket. If provided, the request
            will be sent via the Unix socket instead of HTTP. Defaults to None.

    Raises:
        httpx.HTTPStatusError: If the response contains an HTTP error status.
        httpx.RequestError: If there was an issue while making the request.
    """
    # Use the client as a context manager so its transport/connections are
    # closed deterministically instead of leaking until garbage collection.
    with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
        resp = client.post(
            url,
            json={
                "method": "update_weights_from_ipc",
                "args": [socket_paths],
                "timeout": timeout,
            },
            timeout=timeout,
        )
    # The non-streaming response body has already been read, so checking the
    # status after the client is closed is safe.
    resp.raise_for_status()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _init_api(ps: ParameterServer) -> Any:
    """Build the FastAPI application that exposes ``ps`` over HTTP.

    Routes:
        POST   /v1/checkpoints/{checkpoint_name}/files         -> ps.register_checkpoint
        DELETE /v1/checkpoints/{checkpoint_name}               -> ps.unregister_checkpoint
        GET    /v1/healthz                                     -> always 200
        POST   /v1/checkpoints/{checkpoint_name}/gather-metas  -> ps.gather_metas
        POST   /v1/checkpoints/{checkpoint_name}/update        -> ps.update

    Wrapped routes answer 200 with an empty body on success, or 500 with the
    stringified exception (as JSON) after logging it.
    """
    app = fastapi.FastAPI()

    class RegisterRequest(BaseModel):
        files: list[str]

    class UpdateRequest(BaseModel):
        ranks: list[int] = []
        update_url: str | None = None
        inference_group_ranks: list[int] = []
        timeout: float = 300.0
        uds: str | None = None

    def wrap_exception(func: Callable[[], None]) -> Response:
        # Run ``func`` and translate any failure into a logged HTTP 500.
        try:
            func()
        except Exception as e:  # noqa: BLE001
            logger.exception(f"wrap exception {func} failed")
            return JSONResponse(content=str(e), status_code=500)
        else:
            return Response(status_code=200)

    @app.post("/v1/checkpoints/{checkpoint_name}/files")
    async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
        return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))

    @app.delete("/v1/checkpoints/{checkpoint_name}")
    async def unregister_checkpoint(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))

    @app.get("/v1/healthz")
    async def healthz() -> Response:
        return Response(status_code=200)

    @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
    async def gather_metas(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.gather_metas(checkpoint_name))

    @app.post("/v1/checkpoints/{checkpoint_name}/update")
    async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
        def update_func(socket_paths: list[tuple[str, str]]):
            # No inference endpoint configured: nothing to notify.
            if req.update_url is None:
                return
            # Optionally narrow the (uuid, path) pairs to the requested ranks.
            selected = (
                [socket_paths[rank] for rank in req.inference_group_ranks]
                if req.inference_group_ranks
                else socket_paths
            )
            request_inference_to_update(
                req.update_url, dict(selected), timeout=req.timeout, uds=req.uds
            )

        return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))

    return app
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Annotated, Any, NamedTuple
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
    # Static-typing-only definitions: nothing in this branch exists at runtime.
    from typing import TypeVar

    from typing_extensions import TypedDict

    T = TypeVar("T")

    class FileMeta(TypedDict):
        # Per-parameter record; only the field names/types are visible here.
        key: str  # parameter name
        dtype: torch.dtype
        shape: torch.Size
        type: type
        tp_concat_dim: int
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _dt_validate(value: Any) -> torch.dtype:
    """Coerce ``value`` into a ``torch.dtype``.

    A ``torch.dtype`` passes through. A string such as ``"torch.float16"`` is
    resolved by looking up the segment after the first dot on ``torch``.

    Raises:
        ValueError: string lacking the ``torch.`` prefix, or naming no attribute of ``torch``.
        TypeError: the (resolved) value is not a ``torch.dtype``.
    """
    if isinstance(value, str):
        if value.startswith("torch."):
            attr_name = value.split(".")[1]
            try:
                value = getattr(torch, attr_name)
            except AttributeError as e:
                raise ValueError(f"unknown dtype: {value}") from e
        else:
            raise ValueError(f"dtype {value} should start with torch.")
    if isinstance(value, torch.dtype):
        return value
    raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Pydantic-aware ``torch.dtype``: validated by ``_dt_validate``, serialized via
# ``str()`` (e.g. "torch.float32") and advertised as a JSON string.
_TorchDtype = Annotated[
    torch.dtype,
    PlainValidator(_dt_validate),
    PlainSerializer(lambda dtype: str(dtype), return_type=str),
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _size_validate(value: Any) -> torch.Size:
|
|
44
|
+
if isinstance(value, list | tuple):
|
|
45
|
+
return torch.Size(value)
|
|
46
|
+
if not isinstance(value, torch.Size):
|
|
47
|
+
raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
|
|
48
|
+
return value
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Pydantic-aware ``torch.Size``: validated by ``_size_validate``, serialized as a
# plain tuple and advertised as a JSON array of integers.
_TorchSize = Annotated[
    torch.Size,
    PlainValidator(_size_validate),
    PlainSerializer(lambda size: tuple(size), return_type=tuple),
    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _tensor_validate(value: Any) -> torch.Tensor:
    """Return ``value`` unchanged when it is a ``torch.Tensor``; otherwise raise ``TypeError``."""
    if not isinstance(value, torch.Tensor):
        raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
    return value
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Pydantic-aware ``torch.Tensor``: only the instance check in ``_tensor_validate``
# is applied; no custom serializer or JSON schema is attached.
_TorchTensor = Annotated[
    torch.Tensor,
    PlainValidator(_tensor_validate),
]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ParameterMeta(BaseModel):
    # Identity and layout of a single parameter.
    name: str
    dtype: _TorchDtype  # accepts a torch.dtype or its "torch.xxx" string form
    shape: _TorchSize  # accepts a list/tuple; serialized as a tuple
    aligned_size: int  # presumably the padded byte size in a buffer — TODO confirm


class BucketRange(NamedTuple):
    # A contiguous span inside one memory-pool bucket.
    idx: int  # bucket_idx of MemoryBucket in memory_pool
    offset: int
    size: int


class H2DBucket(BaseModel):
    # A host-to-device transfer unit: total size, source spans and the
    # parameters they carry.
    size: int
    ranges: list[BucketRange]
    items: list[ParameterMeta]


class MemoryBufferMetas(BaseModel):
    # Metadata-only description of a buffer: parameter metas plus the raw
    # address (``ptr``) and byte size of the backing memory.
    metas: list[ParameterMeta]
    ptr: int
    size: int


class MemoryBuffer(BaseModel):
    # A buffer together with its backing tensor.
    buffer: _TorchTensor
    size: int
    metas: list[ParameterMeta]
    manually_pinned: bool = False  # NOTE(review): pinning semantics live outside this file


class MemoryBufferMetaList(BaseModel):
    # Buffer metadata advertised by one process, with its optional P2P store
    # address and the RDMA device it uses.
    p2p_store_addr: str | None
    memory_buffer_metas_list: list[MemoryBufferMetas]
    rdma_device: str


class DataToGather(MemoryBufferMetaList):
    # MemoryBufferMetaList extended with the host IP and device UUID.
    host_ip: str
    device_uuid: str
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import ctypes
|
|
2
|
+
import os
|
|
3
|
+
import random
|
|
4
|
+
import time
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
from checkpoint_engine.device_utils import DeviceManager, get_ip
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _ibv_get_device_list() -> list[str]:
    """Return the names of the RDMA devices reported by libibverbs.

    Loads ``libibverbs.so.1`` through ctypes, calls ``ibv_get_device_list``
    and ``ibv_get_device_name`` for each entry, and always releases the
    returned array with ``ibv_free_device_list``.

    Returns:
        Device names (possibly empty).

    Raises:
        OSError: if ``libibverbs.so.1`` cannot be loaded.
    """
    lib = ctypes.CDLL("libibverbs.so.1")
    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **

    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    lib.ibv_free_device_list.restype = None  # void
    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *

    num = ctypes.c_int()
    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
    if not dev_array:
        # NULL return: nothing was allocated, so there is nothing to free.
        return []

    devices: list[str] = []
    try:
        for i in range(num.value):
            name = lib.ibv_get_device_name(dev_array[i])  # const char *
            # ibv_get_device_name may return NULL; skip instead of crashing on decode.
            if name is not None:
                devices.append(name.decode())
    finally:
        # Release the array even when it lists zero devices or the loop raises.
        lib.ibv_free_device_list(dev_array)
    return devices
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _get_rdma_devices() -> list[str]:
    """Return the candidate RDMA device names for this host.

    Resolution order:
      1. ``PS_P2P_STORE_RDMA_DEVICES`` if set and non-empty, split on ",".
      2. ``NCCL_IB_HCA`` filtered by ``_parse_NCCL_IB_HCA`` against the devices
         reported by libibverbs.
      3. If that yields nothing, every device reported by libibverbs.
    """
    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
    if devices_str:
        return devices_str.split(",")
    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices.
    # Query libibverbs once and reuse the list for both the filter and the fallback.
    available_devices = _ibv_get_device_list()
    hca = os.getenv("NCCL_IB_HCA", None)
    return _parse_NCCL_IB_HCA(hca or "", available_devices) or available_devices
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    """Pick the RDMA device for ``local_rank`` by splitting GPUs evenly across devices.

    Consecutive local ranks share a device: e.g. with 8 GPUs and devices
    "mlx5_0,mlx5_1", ranks 0-3 use mlx5_0 and ranks 4-7 use mlx5_1.

    Raises:
        RuntimeError: if ``devices`` is empty.
        AssertionError: if there are more devices than GPUs, or ``gpu_count``
            is not divisible by the number of devices.
    """
    if not devices:
        raise RuntimeError("no rdma devices found")
    # Explicit checks instead of ``assert`` so validation survives ``python -O``;
    # AssertionError is kept as the exception type for existing callers.
    if len(devices) > gpu_count:
        error = f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
    elif gpu_count % len(devices) != 0:
        error = f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
    else:
        return devices[local_rank // (gpu_count // len(devices))]
    logger.error(
        "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices. "
        "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices. "
        "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
    )
    raise AssertionError(error)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
    """
    The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
    The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.

    The list is comma-separated; port numbers are NOT supported yet.
    An optional prefix '^' indicates the list is an exclude list.
    A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
    Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.

    An empty or all-whitespace value selects the first 32 available devices.
    The result is capped at 32 devices.

    Examples:
    - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
    - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
    - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
    - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
    """
    max_hcas = 32
    spec = value.strip() if value else ""
    if not spec:
        return available_devices[:max_hcas]

    # Strip the optional '^' (exclude) then '=' (exact) prefixes, in that order.
    is_exclude = spec.startswith("^")
    if is_exclude:
        spec = spec[1:]
    is_exact_match = spec.startswith("=")
    if is_exact_match:
        spec = spec[1:]

    tokens = (token.strip() for token in spec.split(","))
    device_specs = [token for token in tokens if token]

    matched = _resolve_device_specs(device_specs, is_exact_match, available_devices)
    if is_exclude:
        excluded = matched
        matched = [dev for dev in available_devices if dev not in excluded]
    result = matched[:max_hcas]

    logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")

    return result
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _resolve_device_specs(
    device_specs: list[str], is_exact_match: bool, available_devices: list[str]
) -> list[str]:
    """Map NCCL_IB_HCA-style tokens onto ``available_devices``.

    Any ``:port`` suffix on a token is dropped. A token equal to an available
    device selects exactly that device. Otherwise, in exact mode it matches
    nothing, and in prefix mode it matches every device starting with it.
    Tokens that match nothing are logged and skipped.

    Returns:
        The de-duplicated matches, sorted by name.
    """
    matched: set[str] = set()
    for spec in device_specs:
        head, _, _port = spec.partition(":")
        device_name = head.strip()
        # HACK: mooncake transfer engine does not support port specification yet, so we ignore it

        if device_name in available_devices:
            # NOTE(review): this wins even in prefix mode, so "mlx5_1" does NOT
            # also select "mlx5_10"; NCCL itself would treat it as a prefix.
            # Confirm this is intended.
            candidates = [device_name]
        elif is_exact_match:
            candidates = []
        else:
            candidates = [dev for dev in available_devices if dev.startswith(device_name)]

        if not candidates:
            logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
            continue

        matched.update(candidates)

    return sorted(matched)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class P2PStore:
    """Wrapper around a mooncake ``TransferEngine`` for peer-to-peer reads.

    Registers named tensors' memory with the engine and performs synchronous
    batched reads from remote stores addressed as ``"ip:port"`` (see ``addr``).
    """

    def __init__(self, device_manager: DeviceManager):
        """Initialize the transfer engine for this process.

        Reads the ``RANK`` environment variable (required). On NPU without
        ``PS_P2P_STORE_RDMA_DEVICES`` no RDMA device is used; otherwise one is
        chosen via ``_get_my_rdma_device``. Engine initialization is retried up
        to 8 times with a random 0.5-2.0 s sleep between attempts.

        Raises:
            KeyError: if ``RANK`` is not set.
            RuntimeError: if the transfer engine cannot be initialized.
        """
        from mooncake.engine import TransferEngine

        self.rank = int(os.environ["RANK"])  # ENV RANK is required
        gpu_count = device_manager.device_module.device_count()
        local_rank = self.rank % gpu_count
        device_type = device_manager.device_type
        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
            self.device = ""
        else:
            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
        self.ip = get_ip()

        # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
        retry_count = 8
        for i in range(retry_count):
            self.engine = TransferEngine()
            ret = self.engine.initialize(
                self.ip,
                "P2PHANDSHAKE",
                "ascend_direct" if device_type == "npu" else "rdma",
                self.device,
            )
            if ret == 0:
                break
            # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
            sleep_ms = random.randint(500, 2000)
            logger.warning(
                f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
            )
            time.sleep(sleep_ms / 1000)
        else:
            raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
        self.port = self.engine.get_rpc_port()
        self.named_tensors: dict[str, torch.Tensor] = {}
        logger.info(
            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
        )

    @property
    def addr(self) -> str:
        """``"ip:port"`` of this store's transfer engine RPC endpoint."""
        return f"{self.ip}:{self.port}"

    def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
        """Register each tensor's memory region with the transfer engine.

        Tensors are tracked in ``self.named_tensors`` only after the engine
        accepted the registration.

        Raises:
            AssertionError: if the engine returns a non-zero status.
        """
        buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
        capacities = [tensor.nbytes for tensor in named_tensors.values()]
        for name, address, capacity in zip(named_tensors, buffer_addresses, capacities):
            logger.info(
                f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(address)} and capacity {capacity}"
            )
        # Call outside of ``assert``: an asserted call is removed entirely under
        # ``python -O``, which would silently skip the registration.
        ret = self.engine.batch_register_memory(buffer_addresses, capacities)
        if ret != 0:
            raise AssertionError(f"[rank{self.rank}] batch_register_memory failed, ret {ret}")
        self.named_tensors.update(named_tensors)

    def unregister_named_tensors(self, names: list[str]) -> int:
        """Unregister previously registered tensors and stop tracking them.

        Returns:
            The number of tensors unregistered.

        Raises:
            KeyError: if a name was never registered.
            AssertionError: if the engine returns a non-zero status.
        """
        buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
        ret = self.engine.batch_unregister_memory(buffer_addresses)
        if ret != 0:
            raise AssertionError(f"[rank{self.rank}] batch_unregister_memory failed, ret {ret}")
        for name, address in zip(names, buffer_addresses):
            del self.named_tensors[name]
            logger.info(
                f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(address)}"
            )
        return len(names)

    def batch_transfer_sync_read(
        self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
    ):
        """Synchronously read ``lens[i]`` bytes from ``remote_ptrs[i]`` on ``target_hostname`` into ``buf_ptrs[i]``.

        Raises:
            AssertionError: if the engine returns a non-zero status.
        """
        # Not wrapped in ``assert`` so the transfer still runs under ``python -O``.
        ret = self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens)
        if ret != 0:
            raise AssertionError(
                f"[rank{self.rank}] batch_transfer_sync_read from {target_hostname} failed, ret {ret}"
            )
|