checkpoint-engine 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,3 +2,39 @@ try:
     from ._version import __version__
 except ImportError:
     __version__ = "dev"
+
+from .api import request_inference_to_update
+from .data_types import (
+    BucketRange,
+    DataToGather,
+    H2DBucket,
+    MemoryBuffer,
+    MemoryBufferMetaList,
+    MemoryBufferMetas,
+    ParameterMeta,
+)
+from .device_utils import DeviceManager, get_ip, npu_generate_uuid
+from .p2p_store import P2PStore
+from .ps import ParameterServer
+from .worker import FlattenedTensorMetadata, VllmColocateWorkerExtension, update_weights_from_ipc
+
+
+__all__ = [
+    "BucketRange",
+    "DataToGather",
+    "DeviceManager",
+    "FlattenedTensorMetadata",
+    "H2DBucket",
+    "MemoryBuffer",
+    "MemoryBufferMetaList",
+    "MemoryBufferMetas",
+    "P2PStore",
+    "ParameterMeta",
+    "ParameterServer",
+    "VllmColocateWorkerExtension",
+    "__version__",
+    "get_ip",
+    "npu_generate_uuid",
+    "request_inference_to_update",
+    "update_weights_from_ipc",
+]
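
With these re-exports, downstream code can pull the public surface straight from the package root. A minimal sketch (the checkpoint name and file path are illustrative, and ParameterServer assumes the usual distributed env such as RANK and MASTER_ADDR is configured):

import checkpoint_engine

# __version__ resolves to "0.3.1" for this release, or "dev" from a source tree
print(checkpoint_engine.__version__)

# auto_pg=True mirrors the new CLI entry point; register_checkpoint is the
# same call the REST layer wraps below.
ps = checkpoint_engine.ParameterServer(auto_pg=True)
ps.register_checkpoint("demo-ckpt", files=["/data/model-00001-of-00002.safetensors"])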
@@ -0,0 +1,28 @@
+import argparse
+import os
+
+from loguru import logger
+
+from checkpoint_engine.api import _init_api
+from checkpoint_engine.ps import ParameterServer
+
+
+@logger.catch(reraise=True)
+def run_from_cli():
+    import uvicorn
+
+    parser = argparse.ArgumentParser(description="Parameter Server")
+    parser.add_argument("--uds", type=str)
+
+    args = parser.parse_args()
+    logger.info(
+        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
+    )
+
+    assert args.uds and len(args.uds) > 0, args.uds
+    ps = ParameterServer(auto_pg=True)
+    uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
+
+
+if __name__ == "__main__":
+    run_from_cli()
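
The entry point binds only a Unix domain socket, so the FastAPI app (see the new api module below) is reached through that socket rather than a TCP port. A hedged smoke test, assuming the server was started with --uds /tmp/ps.sock (the path is illustrative):

import httpx

# httpx routes the request over the UDS; the host in the URL is required
# syntactically but otherwise ignored.
client = httpx.Client(transport=httpx.HTTPTransport(uds="/tmp/ps.sock"))
resp = client.get("http://localhost/v1/healthz")
assert resp.status_code == 200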
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.3.0rc0'
-__version_tuple__ = version_tuple = (0, 3, 0, 'rc0')
+__version__ = version = '0.3.1'
+__version_tuple__ = version_tuple = (0, 3, 1)
 
 __commit_id__ = commit_id = None
@@ -0,0 +1,95 @@
+from collections.abc import Callable
+from typing import Any
+
+import fastapi
+import httpx
+from fastapi import Request
+from fastapi.responses import JSONResponse, Response
+from loguru import logger
+from pydantic import BaseModel
+
+from checkpoint_engine.ps import ParameterServer
+
+
+def request_inference_to_update(
+    url: str,
+    socket_paths: dict[str, str],
+    timeout: float = 300.0,
+    uds: str | None = None,
+):
+    """Send an inference update request to the inference server via HTTP or Unix domain socket.
+
+    Args:
+        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
+        socket_paths (dict[str, str]): A dictionary mapping device UUIDs to IPC socket paths for updating weights.
+        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
+        uds (str, optional): Path to a Unix domain socket. If provided, the request
+            will be sent via the Unix socket instead of HTTP. Defaults to None.
+
+    Raises:
+        httpx.HTTPStatusError: If the response contains an HTTP error status.
+        httpx.RequestError: If there was an issue while making the request.
+    """
+    resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
+        url,
+        json={
+            "method": "update_weights_from_ipc",
+            "args": [socket_paths],
+            "timeout": timeout,
+        },
+        timeout=timeout,
+    )
+    resp.raise_for_status()
+
+
+def _init_api(ps: ParameterServer) -> Any:
+    app = fastapi.FastAPI()
+
+    class RegisterRequest(BaseModel):
+        files: list[str]
+
+    class UpdateRequest(BaseModel):
+        ranks: list[int] = []
+        update_url: str | None = None
+        inference_group_ranks: list[int] = []
+        timeout: float = 300.0
+        uds: str | None = None
+
+    def wrap_exception(func: Callable[[], None]) -> Response:
+        try:
+            func()
+        except Exception as e:  # noqa: BLE001
+            logger.exception(f"wrap exception {func} failed")
+            return JSONResponse(content=str(e), status_code=500)
+        return Response(status_code=200)
+
+    @app.post("/v1/checkpoints/{checkpoint_name}/files")
+    async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
+        return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
+
+    @app.delete("/v1/checkpoints/{checkpoint_name}")
+    async def unregister_checkpoint(checkpoint_name: str) -> Response:
+        return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
+
+    @app.get("/v1/healthz")
+    async def healthz() -> Response:
+        return Response(status_code=200)
+
+    @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
+    async def gather_metas(checkpoint_name: str) -> Response:
+        return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
+
+    @app.post("/v1/checkpoints/{checkpoint_name}/update")
+    async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
+        def update_func(socket_paths: list[tuple[str, str]]):
+            if req.update_url is None:
+                return
+            if req.inference_group_ranks:
+                socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
+            request_inference_to_update(
+                req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
+            )
+
+        return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
+
+    return app
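
Taken together, the routes form a checkpoint lifecycle: register files, gather metas, trigger the update (which calls request_inference_to_update against the inference server), then unregister. A sketch of that sequence; the base URL, checkpoint name, file path, and inference URL are placeholder assumptions:

import httpx

BASE = "http://localhost:8000"  # placeholder; the bundled CLI serves over a UDS instead
NAME = "demo-ckpt"

with httpx.Client(base_url=BASE, timeout=300.0) as client:
    client.post(f"/v1/checkpoints/{NAME}/files",
                json={"files": ["/data/model.safetensors"]}).raise_for_status()
    client.post(f"/v1/checkpoints/{NAME}/gather-metas").raise_for_status()
    client.post(f"/v1/checkpoints/{NAME}/update",
                json={"update_url": "http://localhost:19730/inference"}).raise_for_status()
    client.delete(f"/v1/checkpoints/{NAME}").raise_for_status()

Errors surface as a 500 with the exception text in the body (see wrap_exception), so raise_for_status converts a failed step into an httpx.HTTPStatusError.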
@@ -0,0 +1,111 @@
+from typing import TYPE_CHECKING, Annotated, Any, NamedTuple
+
+import torch
+from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
+
+
+if TYPE_CHECKING:
+    from typing import TypeVar
+
+    from typing_extensions import TypedDict
+
+    class FileMeta(TypedDict):
+        key: str  # parameter name
+        dtype: torch.dtype
+        shape: torch.Size
+        type: type
+        tp_concat_dim: int
+
+    T = TypeVar("T")
+
+
+def _dt_validate(value: Any) -> torch.dtype:
+    if isinstance(value, str):
+        if not value.startswith("torch."):
+            raise ValueError(f"dtype {value} should start with torch.")
+        try:
+            value = getattr(torch, value.split(".")[1])
+        except AttributeError as e:
+            raise ValueError(f"unknown dtype: {value}") from e
+    if not isinstance(value, torch.dtype):
+        raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
+    return value
+
+
+_TorchDtype = Annotated[
+    torch.dtype,
+    PlainValidator(_dt_validate),
+    PlainSerializer(lambda x: str(x), return_type=str),
+    WithJsonSchema({"type": "string"}, mode="serialization"),
+]
+
+
+def _size_validate(value: Any) -> torch.Size:
+    if isinstance(value, list | tuple):
+        return torch.Size(value)
+    if not isinstance(value, torch.Size):
+        raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
+    return value
+
+
+_TorchSize = Annotated[
+    torch.Size,
+    PlainValidator(_size_validate),
+    PlainSerializer(lambda x: tuple(x), return_type=tuple),
+    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
+]
+
+
+def _tensor_validate(value: Any) -> torch.Tensor:
+    if isinstance(value, torch.Tensor):
+        return value
+    raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
+
+
+_TorchTensor = Annotated[
+    torch.Tensor,
+    PlainValidator(_tensor_validate),
+]
+
+
+class ParameterMeta(BaseModel):
+    name: str
+    dtype: _TorchDtype
+    shape: _TorchSize
+    aligned_size: int
+
+
+class BucketRange(NamedTuple):
+    idx: int  # bucket_idx of MemoryBucket in memory_pool
+    offset: int
+    size: int
+
+
+class H2DBucket(BaseModel):
+    size: int
+    ranges: list[BucketRange]
+    items: list[ParameterMeta]
+
+
+class MemoryBufferMetas(BaseModel):
+    metas: list[ParameterMeta]
+    ptr: int
+    size: int
+
+
+class MemoryBuffer(BaseModel):
+    buffer: _TorchTensor
+    size: int
+    metas: list[ParameterMeta]
+    manually_pinned: bool = False
+
+
+class MemoryBufferMetaList(BaseModel):
+    p2p_store_addr: str | None
+    memory_buffer_metas_list: list[MemoryBufferMetas]
+    rdma_device: str
+
+
+class DataToGather(MemoryBufferMetaList):
+    host_ip: str
+    device_uuid: str
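
The Annotated wrappers give these models a clean JSON round-trip: dtypes serialize as "torch.*" strings and shapes as integer arrays, which _dt_validate and _size_validate accept back. A quick illustration (the parameter name and sizes are made up; aligned_size here is just 3072 * 1024 * 2 bytes for bfloat16):

import torch
from checkpoint_engine.data_types import ParameterMeta

meta = ParameterMeta(
    name="layers.0.attn.qkv.weight",  # illustrative name
    dtype=torch.bfloat16,
    shape=torch.Size([3072, 1024]),
    aligned_size=3072 * 1024 * 2,
)
payload = meta.model_dump_json()
# -> {"name":"layers.0.attn.qkv.weight","dtype":"torch.bfloat16","shape":[3072,1024],...}
restored = ParameterMeta.model_validate_json(payload)
assert restored.dtype is torch.bfloat16
assert restored.shape == torch.Size([3072, 1024])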
@@ -0,0 +1,210 @@
+import ctypes
+import os
+import random
+import time
+
+import torch
+from loguru import logger
+
+from checkpoint_engine.device_utils import DeviceManager, get_ip
+
+
+def _ibv_get_device_list() -> list[str]:
+    lib = ctypes.CDLL("libibverbs.so.1")
+    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
+    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **
+
+    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
+    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
+    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *
+
+    num = ctypes.c_int()
+    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
+    if not dev_array or num.value <= 0:
+        return []
+
+    devices = []
+    for i in range(num.value):
+        dev_ptr = dev_array[i]  # struct ibv_device *
+        name = lib.ibv_get_device_name(dev_ptr)  # const char *
+        devices.append(name.decode())
+    lib.ibv_free_device_list(dev_array)
+    return devices
+
+
+def _get_rdma_devices() -> list[str]:
+    """
+    Resolve RDMA devices: PS_P2P_STORE_RDMA_DEVICES takes precedence; otherwise parse NCCL_IB_HCA against _ibv_get_device_list, falling back to all available devices.
+    """
+    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
+    if devices_str:
+        return devices_str.split(",")
+    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
+    hca = os.getenv("NCCL_IB_HCA", None)
+    return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
+
+
+def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
+    """
+    Assign an RDMA NIC to each local rank: with devices "mlx5_0,mlx5_1" on an 8-GPU node, local ranks 0-3 share mlx5_0 and ranks 4-7 share mlx5_1, and so on.
+    """
+    if not devices:
+        raise RuntimeError("no rdma devices found")
+    try:
+        assert len(devices) <= gpu_count, (
+            f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
+        )
+        assert gpu_count % len(devices) == 0, (
+            f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
+        )
+        return devices[local_rank // (gpu_count // len(devices))]
+    except AssertionError:
+        logger.error(
+            "Please set the 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose a proper number of RDMA devices. "
+            "The number of RDMA devices should be less than or equal to the GPU count, and the GPU count should be divisible by the number of RDMA devices. "
+            "The values accepted by NCCL_IB_HCA are documented at 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
+        )
+        raise
+
+
+def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
+    """
+    The values accepted by NCCL_IB_HCA are documented at https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
+    This Python parser follows the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
+
+    The list is comma-separated; port numbers are NOT supported yet.
+    An optional prefix '^' indicates the list is an exclude list.
+    A second optional prefix '=' indicates that the tokens are exact names; otherwise NCCL treats each token as a prefix by default.
+    Note that when '^' and '=' appear together, only '^=' is allowed; '=^' is not supported.
+
+    Examples:
+        - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
+        - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
+        - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
+        - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
+    """
+    max_hcas = 32
+    if not value or value.strip() == "":
+        return available_devices[:max_hcas]
+
+    value = value.strip()
+    result = []
+    is_exclude = value.startswith("^")
+    if is_exclude:
+        value = value.removeprefix("^")
+    is_exact_match = value.startswith("=")
+    if is_exact_match:
+        value = value.removeprefix("=")
+
+    device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
+
+    result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
+    if is_exclude:
+        result = [dev for dev in available_devices if dev not in result]
+    if len(result) > max_hcas:
+        result = result[:max_hcas]
+
+    logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
+
+    return result
+
+
+def _resolve_device_specs(
+    device_specs: list[str], is_exact_match: bool, available_devices: list[str]
+) -> list[str]:
+    devices = set()
+    for spec in device_specs:
+        parts = spec.split(":", 1)
+        device_name = parts[0].strip()
+        # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
+        # port = parts[1].strip() if len(parts) > 1 else None
+        base_devices = (
+            [device_name]
+            if device_name in available_devices
+            else []
+            if is_exact_match
+            else [dev for dev in available_devices if dev.startswith(device_name)]
+        )
+
+        if not base_devices:
+            logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
+            continue
+
+        for base_dev in base_devices:
+            devices.add(base_dev)
+
+    return sorted(devices)
+
+
+class P2PStore:
+    def __init__(self, device_manager: DeviceManager):
+        from mooncake.engine import TransferEngine
+
+        self.rank = int(os.environ["RANK"])  # ENV RANK is required
+        gpu_count = device_manager.device_module.device_count()
+        local_rank = self.rank % gpu_count
+        device_type = device_manager.device_type
+        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
+            self.device = ""
+        else:
+            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
+        self.ip = get_ip()
+
+        # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
+        retry_count = 8
+        for i in range(retry_count):
+            self.engine = TransferEngine()
+            ret = self.engine.initialize(
+                self.ip,
+                "P2PHANDSHAKE",
+                "ascend_direct" if device_type == "npu" else "rdma",
+                self.device,
+            )
+            if ret == 0:
+                break
+            # sleep 0.5 ~ 2.0s to avoid port conflicts when two processes retry at the same time
+            sleep_ms = random.randint(500, 2000)
+            logger.warning(
+                f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
+            )
+            time.sleep(sleep_ms / 1000)
+        else:
+            raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
+        self.port = self.engine.get_rpc_port()
+        self.named_tensors: dict[str, torch.Tensor] = {}
+        logger.info(
+            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
+        )
+
+    @property
+    def addr(self) -> str:
+        return f"{self.ip}:{self.port}"
+
+    def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
+        buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
+        capacities = [tensor.nbytes for tensor in named_tensors.values()]
+        self.named_tensors.update(named_tensors)
+        for i, name in enumerate(named_tensors.keys()):
+            logger.info(
+                f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
+            )
+        assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0
+
+    def unregister_named_tensors(self, names: list[str]) -> int:
+        buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
+        assert self.engine.batch_unregister_memory(buffer_addresses) == 0
+        num_unregistered = 0
+        for i, name in enumerate(names):
+            del self.named_tensors[name]
+            logger.info(
+                f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
+            )
+            num_unregistered += 1
+        return num_unregistered
+
+    def batch_transfer_sync_read(
+        self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
+    ):
+        assert (
+            self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
+        )
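
The docstring cases of _parse_NCCL_IB_HCA map onto concrete outputs; a sketch against an assumed device inventory (the function is private, so this is for illustration only):

from checkpoint_engine.p2p_store import _parse_NCCL_IB_HCA

available = ["mlx5_0", "mlx5_1", "mlx5_2", "bnxt_0"]

assert _parse_NCCL_IB_HCA("mlx5", available) == ["mlx5_0", "mlx5_1", "mlx5_2"]    # prefix match
assert _parse_NCCL_IB_HCA("=mlx5_0,mlx5_1", available) == ["mlx5_0", "mlx5_1"]    # exact names
assert _parse_NCCL_IB_HCA("^mlx5", available) == ["bnxt_0"]                       # exclude by prefix
assert _parse_NCCL_IB_HCA("^=mlx5_0,mlx5_1", available) == ["mlx5_2", "bnxt_0"]   # exclude exact

Note the asymmetry: include results come back sorted (via _resolve_device_specs), while exclude results preserve the order of the available-device list.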