checkpoint-engine 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- checkpoint_engine/__init__.py +36 -0
- checkpoint_engine/__main__.py +28 -0
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/api.py +95 -0
- checkpoint_engine/data_types.py +111 -0
- checkpoint_engine/p2p_store.py +210 -0
- checkpoint_engine/pin_memory.py +390 -0
- checkpoint_engine/ps.py +85 -798
- checkpoint_engine/worker.py +18 -9
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.1.dist-info}/METADATA +1 -1
- checkpoint_engine-0.3.1.dist-info/RECORD +15 -0
- checkpoint_engine-0.3.0rc0.dist-info/RECORD +0 -10
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.1.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.1.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.0rc0.dist-info → checkpoint_engine-0.3.1.dist-info}/top_level.txt +0 -0
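Most of the change is a split of `checkpoint_engine/ps.py` into smaller modules: the pydantic data models move to `data_types.py`, the host pin-memory helpers to `pin_memory.py`, the mooncake-based `P2PStore` to `p2p_store.py`, and (judging by the new files in the list above) the HTTP API and CLI entry point to `api.py` and `__main__.py`. A minimal import sketch of the new layout, assuming the new modules re-export the names exactly as they appear in the diff below:

    # 0.3.0rc0: these all lived in checkpoint_engine.ps
    # from checkpoint_engine.ps import ParameterServer, P2PStore, ParameterMeta, MemoryBuffer

    # 0.3.1: ps.py itself now imports them from the new modules
    from checkpoint_engine.data_types import MemoryBuffer, ParameterMeta
    from checkpoint_engine.p2p_store import P2PStore
    from checkpoint_engine.pin_memory import _register_checkpoint
    from checkpoint_engine.ps import ParameterServer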
checkpoint_engine/ps.py
CHANGED
@@ -1,142 +1,33 @@
-import argparse
-import concurrent.futures
 import ctypes
-import json
 import os
-import pickle
-import random
 import threading
-import time
 from collections import defaultdict
 from collections.abc import Callable
 from datetime import timedelta
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING

-import httpx
-import numpy as np
 import torch
 import torch.distributed as dist
 import zmq
 from loguru import logger
-from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
-from safetensors.torch import _getdtype, safe_open
 from torch.multiprocessing.reductions import reduce_tensor

+from checkpoint_engine.data_types import (
+    BucketRange,
+    DataToGather,
+    H2DBucket,
+    MemoryBuffer,
+    MemoryBufferMetaList,
+    MemoryBufferMetas,
+    ParameterMeta,
+)
 from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
+from checkpoint_engine.p2p_store import P2PStore
+from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint


 if TYPE_CHECKING:
-    from
-
-    from typing_extensions import TypedDict
-
-    class FileMeta(TypedDict):
-        key: str  # parameter name
-        dtype: torch.dtype
-        shape: torch.Size
-        type: type
-        tp_concat_dim: int
-
-    T = TypeVar("T")
-
-
-def _dt_validate(value: Any) -> torch.dtype:
-    if isinstance(value, str):
-        if not value.startswith("torch."):
-            raise ValueError(f"dtype {value} should start with torch.")
-        try:
-            value = getattr(torch, value.split(".")[1])
-        except AttributeError as e:
-            raise ValueError(f"unknown dtype: {value}") from e
-    if not isinstance(value, torch.dtype):
-        raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
-    return value
-
-
-_TorchDtype = Annotated[
-    torch.dtype,
-    PlainValidator(_dt_validate),
-    PlainSerializer(lambda x: str(x), return_type=str),
-    WithJsonSchema({"type": "string"}, mode="serialization"),
-]
-
-
-def _size_validate(value: Any) -> torch.Size:
-    if isinstance(value, list | tuple):
-        return torch.Size(value)
-    if not isinstance(value, torch.Size):
-        raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
-    return value
-
-
-_TorchSize = Annotated[
-    torch.Size,
-    PlainValidator(_size_validate),
-    PlainSerializer(lambda x: tuple(x), return_type=tuple),
-    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
-]
-
-
-def _tensor_validate(value: Any) -> torch.Tensor:
-    if isinstance(value, torch.Tensor):
-        return value
-    raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
-
-
-_TorchTensor = Annotated[
-    torch.Tensor,
-    PlainValidator(_tensor_validate),
-]
-
-
-class ParameterMeta(BaseModel):
-    name: str
-    dtype: _TorchDtype
-    shape: _TorchSize
-    aligned_size: int
-
-
-class BucketRange(NamedTuple):
-    idx: int  # bucket_idx of MemoryBucket in memory_pool
-    offset: int
-    size: int
-
-
-class H2DBucket(BaseModel):
-    size: int
-    ranges: list[BucketRange]
-    items: list[ParameterMeta]
-
-
-class MemoryBufferMetas(BaseModel):
-    metas: list[ParameterMeta]
-    ptr: int
-    size: int
-
-
-class MemoryBuffer(BaseModel):
-    buffer: _TorchTensor
-    size: int
-    metas: list[ParameterMeta]
-
-
-class MemoryBufferMetaList(BaseModel):
-    p2p_store_addr: str | None
-    memory_buffer_metas_list: list[MemoryBufferMetas]
-    rdma_device: str
-
-
-class DataToGather(MemoryBufferMetaList):
-    host_ip: str
-    device_uuid: str
-
-
-# 256 bytes alignment when flatten torch tensors to uint8 buffer
-_ALIGN_SIZE = 256
-
-
-def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
-    return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
+    from checkpoint_engine.data_types import T


 def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
@@ -155,107 +46,6 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
     return ret


-def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
-    def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
-        ret = {}
-        with safe_open(fn, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                weight = f.get_tensor(name)
-                meta = {
-                    "key": name,
-                    "dtype": weight.dtype,
-                    "shape": weight.shape,
-                    "type": type(weight),
-                    "tp_concat_dim": -1,  # safetensors does not support tp_concat_dim
-                }
-                ret[name] = (meta, weight)
-        return ret
-
-    # deprecated, will be removed in the future
-    def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
-        """load *.np file and return memmap and related tensor meta"""
-
-        def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
-            start = fin.tell()
-            major, minor = np.lib.format.read_magic(fin)
-            if major == 1 and minor == 0:
-                read_header_fn = np.lib.format.read_array_header_1_0
-            elif major == 2 and minor == 0:
-                read_header_fn = np.lib.format.read_array_header_2_0
-            else:
-                raise ValueError(
-                    f"unknown version {major}.{minor} when parsing npy header from {fn}"
-                )
-            shape, is_fortran, dtype = read_header_fn(fin)
-            return {
-                "shape": shape,
-                "is_fortran": is_fortran,
-                "dtype": dtype,
-                "header_length": fin.tell() - start,
-            }
-
-        meta_fn = fn + ".meta"
-        with open(meta_fn, "rb") as fin:
-            meta_lst = pickle.load(fin)
-
-        tensors = []
-        offset = 0
-        with open(fn, "rb") as fin:
-            fin.seek(0, os.SEEK_END)
-            filesize = fin.tell()
-            fin.seek(0)
-            while fin.tell() < filesize:
-                tensor_meta = parse_npy_header(fin)
-                tensor = np.memmap(
-                    fn,
-                    dtype=tensor_meta["dtype"],
-                    mode="c",
-                    offset=offset + tensor_meta["header_length"],
-                    shape=tensor_meta["shape"],
-                )
-                offset += tensor_meta["header_length"] + tensor.nbytes
-                fin.seek(offset)
-                tensors.append(tensor)
-
-        assert len(meta_lst) == len(tensors)
-        ret = {}
-        for meta, tensor in zip(meta_lst, tensors):
-            if meta["type"] == torch.Tensor:
-                tensor = torch.from_numpy(tensor)
-                tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
-            ret[meta["key"]] = (meta, tensor)
-        return ret
-
-    tp_rank = 0
-    if file_path.endswith(".npy"):
-        logger.warning("numpy model file is deprecated, will be removed in the future")
-        filename_split = os.path.basename(file_path).split(".")
-        # if using numpy and want to specify tp rank
-        # file should be in model.{layer}.{tp}[.{ep}].npy format
-        tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
-        ret = _fast_np_load(file_path)
-    elif file_path.endswith(".safetensors"):
-        ret = _safetensors_load(file_path)
-    else:
-        raise ValueError(f"unsupported file format: {file_path}")
-    return tp_rank, ret
-
-
-def _concat_tp_weights(
-    tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
-) -> torch.Tensor:
-    """Concat tp weights with meta info.
-    If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights.
-    Else we will cat weights in concat_dim.
-    """
-    if tp_concat_dim == -1:
-        return tp_weights[0]
-    assert tp_size == len(tp_weights)
-    if len(tp_weights) == 1:
-        return tp_weights[0]
-    return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
-
-
 def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
     try:
         if device_manager.device_type == "npu":
@@ -266,420 +56,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None
         raise ValueError(f"fail to get physical gpu id {device_index}") from e


-def _ibv_get_device_list() -> list[str]:
-    lib = ctypes.CDLL("libibverbs.so.1")
-    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
-    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **
-
-    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
-    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
-    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *
-
-    num = ctypes.c_int()
-    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
-    if not dev_array or num.value <= 0:
-        return []
-
-    devices = []
-    for i in range(num.value):
-        dev_ptr = dev_array[i]  # struct ibv_device *
-        name = lib.ibv_get_device_name(dev_ptr)  # const char *
-        devices.append(name.decode())
-    lib.ibv_free_device_list(dev_array)
-    return devices
-
-
-def _get_rdma_devices() -> list[str]:
-    """
-    use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
-    """
-    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
-    if devices_str:
-        return devices_str.split(",")
-    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
-    hca = os.getenv("NCCL_IB_HCA", None)
-    return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
-
-
-def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
-    """
-    implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
-    """
-    if not devices:
-        raise RuntimeError("no rdma devices found")
-    try:
-        assert len(devices) <= gpu_count, (
-            f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
-        )
-        assert gpu_count % len(devices) == 0, (
-            f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
-        )
-        return devices[local_rank // (gpu_count // len(devices))]
-    except AssertionError:
-        logger.error(
-            "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
-            "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
-            "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
-        )
-        raise
-
-
-def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
-    """
-    The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
-    The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
-
-    The list is comma-separated; port numbers are NOT supported yet.
-    An optional prefix '^' indicates the list is an exclude list.
-    A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
-    Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
-
-    Examples:
-    - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
-    - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
-    - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
-    - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
-    """
-    max_hcas = 32
-    if not value or value.strip() == "":
-        return available_devices[:max_hcas]
-
-    value = value.strip()
-    result = []
-    is_exclude = value.startswith("^")
-    if is_exclude:
-        value = value.removeprefix("^")
-    is_exact_match = value.startswith("=")
-    if is_exact_match:
-        value = value.removeprefix("=")
-
-    device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
-
-    result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
-    if is_exclude:
-        result = [dev for dev in available_devices if dev not in result]
-    if len(result) > max_hcas:
-        result = result[:max_hcas]
-
-    logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
-
-    return result
-
-
-def _resolve_device_specs(
-    device_specs: list[str], is_exact_match: bool, available_devices: list[str]
-) -> list[str]:
-    devices = set()
-    for spec in device_specs:
-        parts = spec.split(":", 1)
-        device_name = parts[0].strip()
-        # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
-        # port = parts[1].strip() if len(parts) > 1 else None
-        base_devices = (
-            [device_name]
-            if device_name in available_devices
-            else []
-            if is_exact_match
-            else [dev for dev in available_devices if dev.startswith(device_name)]
-        )
-
-        if not base_devices:
-            logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
-            continue
-
-        for base_dev in base_devices:
-            devices.add(base_dev)
-
-    return sorted(devices)
-
-
-def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
-    class TPMeta(BaseModel):
-        concat_dim: int
-        size: int
-
-    parameters: dict[str, torch.Tensor] = {}
-    parameter_metas: dict[str, ParameterMeta] = {}
-    tp_metas: dict[str, TPMeta] = {}
-    parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {}
-    for file in files:
-        tp_rank, ret = _load_checkpoint_file(file)
-        for parameter_name, (meta, weight) in ret.items():
-            if parameter_name not in parameters_with_tp:
-                parameters_with_tp[parameter_name] = {}
-            parameters_with_tp[parameter_name][tp_rank] = weight
-            if parameter_name not in tp_metas:
-                tp_metas[parameter_name] = TPMeta(
-                    concat_dim=meta["tp_concat_dim"],
-                    size=1,
-                )
-            if parameter_name not in parameter_metas:
-                assert isinstance(meta["dtype"], torch.dtype), (
-                    f"meta {meta} dtype should be torch.dtype"
-                )
-                assert isinstance(meta["shape"], torch.Size), (
-                    f"meta {meta} shape should be torch.Size"
-                )
-                parameter_metas[parameter_name] = ParameterMeta(
-                    name=parameter_name,
-                    shape=meta["shape"],
-                    dtype=meta["dtype"],
-                    aligned_size=_align_size(meta["dtype"], meta["shape"]),
-                )
-            tp_meta = tp_metas[parameter_name]
-            if tp_meta.concat_dim != -1:
-                tp_meta.size = max(tp_meta.size, tp_rank + 1)
-    for name, tp_meta in tp_metas.items():
-        if tp_meta.concat_dim != -1:
-            shape = list(parameter_metas[name].shape)
-            shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
-            parameter_metas[name] = ParameterMeta(
-                name=name,
-                shape=torch.Size(shape),
-                dtype=parameter_metas[name].dtype,
-                aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
-            )
-        weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
-        # TODO: here concat is serial, which may be slow
-        # but since tp storage is not used in the future
-        # we ignore this performance issue for now
-        parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size)
-    for name, parameter in parameters.items():
-        assert name in parameter_metas, f"parameter {name} not found in parameter_metas"
-        assert parameter_metas[name].shape == parameter.shape, (
-            f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}"
-        )
-        assert parameter_metas[name].dtype == parameter.dtype, (
-            f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}"
-        )
-    return parameters
-
-
-def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
-    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
-        """
-        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
-        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
-        The actual tensor data is in the remaining bytes and is naturally aligned.
-        We pin the remaining bytes as the buffer, making pinning faster.
-        """
-
-        def _pin(t: torch.Tensor):
-            """
-            Pin the memory of tensor in-place.
-            See: https://github.com/pytorch/pytorch/issues/32167
-            """
-            cudart = torch.cuda.cudart()
-            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
-            assert r == 0, f"pin memory error, error code: {r}"
-
-        # TODO: should only support /dev/shm? but we found files in disk also work?
-        size = os.stat(file_path).st_size
-        flag_size = 8
-        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
-        assert t.nbytes > flag_size, (
-            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
-        )
-        start_pos = (
-            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
-            + flag_size
-        )
-        header_tensor = t[flag_size:start_pos]
-        header = json.loads(header_tensor.numpy().tobytes())
-        if "__metadata__" in header:
-            header.pop("__metadata__")
-
-        metas: list[ParameterMeta] = []
-        offset = 0
-        try:
-            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
-                start, end = meta["data_offsets"]
-                # safetensors format ensures offsets are aligned
-                assert offset == start, f"offset {offset} should be equal to start {start}"
-                metas.append(
-                    ParameterMeta(
-                        name=name,
-                        dtype=_getdtype(meta["dtype"]),
-                        shape=torch.Size(meta["shape"]),
-                        aligned_size=end - start,
-                    )
-                )
-                offset = end
-        except Exception as e:
-            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
-            raise
-
-        buffer = t[start_pos:]
-        assert offset == buffer.nbytes, (
-            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
-        )
-        # Remove the file after successfully loading. This will avoid doubling the memory usage.
-        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
-        os.remove(file_path)
-        _pin(buffer)
-        logger.info(
-            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
-        )
-        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
-
-    memory_buffers: list[MemoryBuffer] = []
-    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
-    return memory_buffers
-
-
-def _normal_pin_memory(
-    files: list[str],
-    named_tensors: dict[str, torch.Tensor],
-    rank: int | None = None,
-    shared_pin_memory: list[MemoryBuffer] | None = None,
-) -> list[MemoryBuffer]:
-    parameters = _load_checkpoint(files)
-    if named_tensors:
-        parameters.update(named_tensors)
-    bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
-
-    class MemoryBucket(BaseModel):
-        size: int
-        metas: list[ParameterMeta]
-
-    buckets: list[MemoryBucket] = []
-    buckets.append(MemoryBucket(size=0, metas=[]))
-    for name, tensor in sorted(parameters.items()):
-        size = _align_size(tensor.dtype, tensor.shape)
-        if buckets[-1].size + size > bucket_size:
-            assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
-            buckets.append(MemoryBucket(size=0, metas=[]))
-        buckets[-1].metas.append(
-            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
-        )
-        buckets[-1].size += size
-
-    memory_buffers = [
-        MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
-        for bucket in buckets
-    ]
-
-    def register_pin_memory(
-        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
-    ) -> tuple[int, torch.Tensor]:
-        if shared_pin_memory:
-            # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
-            # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
-            assert idx < len(shared_pin_memory), (
-                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
-            )
-            assert shared_pin_memory[idx].size == size, (
-                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
-            )
-            return idx, shared_pin_memory[idx].buffer
-        else:
-            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-            return idx, buffer
-
-    def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
-        buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-        futures = [
-            executor.submit(
-                register_pin_memory,
-                idx,
-                bucket.size,
-                shared_pin_memory,
-            )
-            for idx, bucket in enumerate(buckets)
-        ]
-        new_futures = []
-        for future in concurrent.futures.as_completed(futures):
-            idx, buffer = future.result()
-            assert buffer.numel() == buckets[idx].size, (
-                f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
-            )
-            memory_buffers[idx].buffer = buffer
-            logger.info(
-                f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
-                f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
-            )
-            offset = 0
-            for meta in buckets[idx].metas:
-                name = meta.name
-                tensor = parameters[name]
-                size = _align_size(tensor.dtype, tensor.shape)
-                assert size == _align_size(meta.dtype, meta.shape), (
-                    f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
-                )
-                new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
-                offset += size
-        for future in concurrent.futures.as_completed(new_futures):
-            future.result()
-    return memory_buffers
-
-
-def _register_checkpoint(
-    *,
-    files: list[str],
-    named_tensors: dict[str, torch.Tensor],
-    rank: int | None = None,
-    shared_pin_memory: list[MemoryBuffer] | None = None,
-) -> list[MemoryBuffer]:
-    logger.info(
-        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-    )
-    if not files and not named_tensors:
-        return []
-    memory_buffers: list[MemoryBuffer] = []
-    files_to_inplace_pin = [
-        file
-        for file in files
-        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
-    ]
-    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
-    if files_to_normal_pin or named_tensors:
-        memory_buffers.extend(
-            _normal_pin_memory(
-                files=files_to_normal_pin,
-                named_tensors=named_tensors,
-                rank=rank,
-                shared_pin_memory=shared_pin_memory,
-            )
-        )
-    if files_to_inplace_pin:
-        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
-    return memory_buffers
-
-
-def request_inference_to_update(
-    url: str,
-    socket_paths: dict[str, str],
-    timeout: float = 300.0,
-    uds: str | None = None,
-):
-    """Send an inference update request to inference server via HTTP or Unix socket.
-
-    Args:
-        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
-        socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
-        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
-        uds (str, optional): Path to a Unix domain socket. If provided, the request
-            will be sent via the Unix socket instead of HTTP. Defaults to None.
-
-    Raises:
-        httpx.HTTPStatusError: If the response contains an HTTP error status.
-        httpx.RequestError: If there was an issue while making the request.
-    """
-    resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
-        url,
-        json={
-            "method": "update_weights_from_ipc",
-            "args": [socket_paths],
-            "timeout": timeout,
-        },
-        timeout=timeout,
-    )
-    resp.raise_for_status()
-
-
 def _gen_h2d_buckets(
     global_metas: dict[int, MemoryBufferMetaList],
     bucket_size: int,
@@ -782,84 +158,12 @@ def _get_master_port(master_port: int | None = None) -> int:
     if master_port is None:
        # HACK: use MASTER_PORT + 1 as master_port, avoid conflict with torchrun's rendezvous port
        # TODO: check whether master_port is available or use a more elegant way
-
+        master_port_str = os.getenv("MASTER_PORT")
+        assert master_port_str, "MASTER_PORT is required if no master_port is provided."
+        master_port = int(master_port_str) + 1
     return master_port


-class P2PStore:
-    def __init__(self, device_manager: DeviceManager):
-        from mooncake.engine import TransferEngine
-
-        self.rank = int(os.getenv("RANK"))
-        gpu_count = device_manager.device_module.device_count()
-        local_rank = self.rank % gpu_count
-        device_type = device_manager.device_type
-        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
-            self.device = ""
-        else:
-            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
-        self.ip = get_ip()
-
-        # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
-        retry_count = 8
-        for i in range(retry_count):
-            self.engine = TransferEngine()
-            ret = self.engine.initialize(
-                self.ip,
-                "P2PHANDSHAKE",
-                "ascend_direct" if device_type == "npu" else "rdma",
-                self.device,
-            )
-            if ret == 0:
-                break
-            # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
-            sleep_ms = random.randint(500, 2000)
-            logger.warning(
-                f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
-            )
-            time.sleep(sleep_ms / 1000)
-        else:
-            raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
-        self.port = self.engine.get_rpc_port()
-        self.named_tensors: dict[str, torch.Tensor] = {}
-        logger.info(
-            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
-        )
-
-    @property
-    def addr(self) -> str:
-        return f"{self.ip}:{self.port}"
-
-    def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
-        buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
-        capacities = [tensor.nbytes for tensor in named_tensors.values()]
-        self.named_tensors.update(named_tensors)
-        for i, name in enumerate(named_tensors.keys()):
-            logger.info(
-                f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
-            )
-        assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0
-
-    def unregister_named_tensors(self, names: list[str]) -> int:
-        buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
-        assert self.engine.batch_unregister_memory(buffer_addresses) == 0
-        num_unregistered = 0
-        for i, name in enumerate(names):
-            del self.named_tensors[name]
-            logger.info(
-                f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
-            )
-            num_unregistered += 1
-        return num_unregistered
-
-    def batch_transfer_sync_read(
-        self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
-    ):
-        assert (
-            self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
-        )
-
-
 class ParameterServer:
     shared_memory_pool_name = "__shared_memory_pool__"

@@ -868,7 +172,7 @@ class ParameterServer:
         *,
         rank: int | None = None,
         world_size: int | None = None,
-        auto_pg: bool =
+        auto_pg: bool = True,
         gpu_count: int | None = None,
         mem_fraction: float | None = None,
     ):
@@ -877,11 +181,11 @@ class ParameterServer:

         Args:
             auto_pg: Whether to automatically initialize the process group.
-                Notice that if auto_pg is True, will destroy the process group after update.
+                Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
             mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
         """
-        self._rank = rank or int(os.environ
-        self._world_size = world_size or int(os.environ
+        self._rank = rank or int(os.environ["RANK"])
+        self._world_size = world_size or int(os.environ["WORLD_SIZE"])
         self.device_manager = DeviceManager()
         self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
         self._local_rank = self._rank % self._gpu_count
@@ -890,7 +194,7 @@ class ParameterServer:
         self._global_device_uuids: list[str] = []
         self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
         self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
-        self._mem_fraction = mem_fraction or 0.9
+        self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))

         assert self._rank is not None and self._rank >= 0, self._rank
         assert self._world_size and self._world_size > 0, self._world_size
@@ -959,11 +263,12 @@ class ParameterServer:
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
         use_shared_memory_pool: bool = False,
+        use_inplace_pin_memory: bool = True,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
-        Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
-        Please make sure to copy the files to disks if you need to keep them.
+        Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+        Please make sure to copy the files to disks if you need to keep them. NPU does not support inplace pin memory.

         Args:
             checkpoint_name: The name of the checkpoint.
@@ -974,7 +279,14 @@ class ParameterServer:
                 cannot accommodate checkpoints with different memory requirements.
                 To free the actual memory of the shared pool or to modify its shape,
                 please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
+            use_inplace_pin_memory: If True (default), allows inplace pin memory for /dev/shm/ safetensors files.
+                This option is ignored when ``use_shared_memory_pool`` is True.
         """
+        if self.device_manager.device_type != "cuda" and use_inplace_pin_memory:
+            logger.warning(
+                f"[rank{self._rank}] Only cuda devices support in-place pin memory, set use_inplace_pin_memory to False"
+            )
+            use_inplace_pin_memory = False
         try:
             if use_shared_memory_pool:
                 logger.info(
@@ -993,6 +305,7 @@ class ParameterServer:
                     named_tensors=named_tensors or {},
                     rank=self._rank,
                     shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                    inplace_pin=False,  # inplace pin memory is not compatible with shared memory pool
                 )
                 self._current_shared_memory_pool_user = checkpoint_name
                 if self._p2p_store is not None and _is_first_time:
@@ -1002,7 +315,10 @@ class ParameterServer:
                     f"checkpoint {checkpoint_name} already registered"
                 )
                 self._memory_pool[checkpoint_name] = _register_checkpoint(
-                    files=files or [],
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    inplace_pin=use_inplace_pin_memory,
                 )
                 if self._p2p_store is not None:
                     self._register_parameters_to_p2p_store(checkpoint_name)
@@ -1048,6 +364,46 @@ class ParameterServer:
                 del self._memory_pool[self.shared_memory_pool_name]
                 self._memory_pool[self.shared_memory_pool_name] = []
             else:
+
+                def _unpin(t: torch.Tensor):
+                    """
+                    Un-pin the pinned memory.
+                    """
+                    p_flags = ctypes.c_uint()
+                    try:
+                        libc = ctypes.CDLL(None)  # get all symbols from the current process
+                        cuda_host_get_flags = libc.cudaHostGetFlags
+                        cuda_host_get_flags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
+                        cuda_host_get_flags.restype = ctypes.c_int
+                    except AttributeError:
+                        logger.error("cudaHostGetFlags not found in libc, cannot unpin memory manually")
+                        raise
+                    r = cuda_host_get_flags(ctypes.byref(p_flags), ctypes.c_void_p(t.data_ptr()))
+                    assert r == 0, f"get pin flags error, error code: {r}"
+                    # p_flags value meaning from cuda/include/driver_types.h
+                    # cudaHostRegisterDefault 0x00 /**< Default host memory registration flag */
+                    # cudaHostRegisterPortable 0x01 /**< Pinned memory accessible by all CUDA contexts */
+                    # cudaHostRegisterMapped 0x02 /**< Map registered memory into device space */
+                    # cudaHostRegisterIoMemory 0x04 /**< Memory-mapped I/O space */
+                    # cudaHostRegisterReadOnly 0x08 /**< Memory-mapped read-only */
+                    assert p_flags.value == 0x02, (
+                        f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
+                    )
+                    cudart = torch.cuda.cudart()
+                    r = cudart.cudaHostUnregister(t.data_ptr())
+                    assert r == 0, f"unpin memory error, error code: {r}"
+
+                # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
+                try:
+                    for memory_buffer in self._memory_pool.get(checkpoint_name, []):
+                        if memory_buffer.manually_pinned:
+                            _unpin(memory_buffer.buffer)
+                except Exception as e:
+                    logger.error(
+                        f"[rank{self._rank}] fail to unpin memory for checkpoint {checkpoint_name}: {e}"
+                    )
+                    raise
+                # we won't delete the memory pool if unpinning fails.
                 del self._memory_pool[checkpoint_name]
                 # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
                 # this works by using torch>=2.5.0
@@ -1183,6 +539,8 @@ class ParameterServer:
     ) -> None:
         """
         Update the checkpoint to inference engine. This function should be called after gather_metas.
+        Warning: if _auto_pg is False when initializing ParameterServer, please make sure ALL ranks in the WORLD_SIZE call `update` function,
+        otherwise, it will hang.

         Args:
             checkpoint_name: The name of the checkpoint.
@@ -1217,7 +575,7 @@ class ParameterServer:
                 is_master=self._rank == 0,
             )
             # if ranks is None or [], it will use fully broadcast to update to all ranks
-            ranks_group = dist.new_group(ranks if ranks else None)
+            ranks_group = dist.new_group(ranks) if ranks else None
             self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
             self.store_based_barrier(manager_store)
         except Exception as e:
@@ -1248,7 +606,7 @@ class ParameterServer:
         return socket, socket_paths

     def _detect_bucket_size(
-        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+        self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
     ) -> tuple[int, bool]:
         GiB = 1 << 30  # noqa: N806
         # auto detect bucket size
@@ -1291,7 +649,7 @@ class ParameterServer:
                 f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
             )
             disable_h2d_buffer = True
-        max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
+        max_bytes = int(float(os.getenv("PS_MAX_BUCKET_SIZE_GB", "8")) * GiB)
         bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
         logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
         return bucket_size, disable_h2d_buffer
@@ -1367,7 +725,7 @@ class ParameterServer:
         self,
         checkpoint_name: str,
         req_func: Callable[[list[tuple[str, str]]], None],
-        ranks_group: dist.ProcessGroup,
+        ranks_group: dist.ProcessGroup | None,
         ranks: list[int] | None = None,
     ):
         assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
@@ -1498,79 +856,8 @@ class ParameterServer:
         self.device_manager.device_module.empty_cache()


-
-    import fastapi
-    from fastapi import Request
-    from fastapi.responses import JSONResponse, Response
-
-    app = fastapi.FastAPI()
-
-    class RegisterRequest(BaseModel):
-        files: list[str]
-
-    class UpdateRequest(BaseModel):
-        ranks: list[int] = []
-        update_url: str | None = None
-        inference_group_ranks: list[int] = []
-        timeout: float = 300.0
-        uds: str | None = None
-
-    def wrap_exception(func: Callable[[], None]) -> Response:
-        try:
-            func()
-        except Exception as e:  # noqa: BLE001
-            logger.exception(f"wrap exception {func} failed")
-            return JSONResponse(content=str(e), status_code=500)
-        return Response(status_code=200)
-
-    @app.post("/v1/checkpoints/{checkpoint_name}/files")
-    async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
-        return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
-
-    @app.delete("/v1/checkpoints/{checkpoint_name}")
-    async def unregister_checkpoint(checkpoint_name: str) -> Response:
-        return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
-
-    @app.get("/v1/healthz")
-    async def healthz() -> Response:
-        return Response(status_code=200)
-
-    @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
-    async def gather_metas(checkpoint_name: str) -> Response:
-        return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
-
-    @app.post("/v1/checkpoints/{checkpoint_name}/update")
-    async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
-        def update_func(socket_paths: list[tuple[str, str]]):
-            if req.update_url is None:
-                return
-            if req.inference_group_ranks:
-                socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
-            request_inference_to_update(
-                req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
-            )
-
-        return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
-
-    return app
-
-
-@logger.catch(reraise=True)
-def run_from_cli():
-    import uvicorn
-
-    parser = argparse.ArgumentParser(description="Parameter Server")
-    parser.add_argument("--uds", type=str)
-
-    args = parser.parse_args()
-    logger.info(
-        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
-    )
-
-    assert args.uds and len(args.uds) > 0, args.uds
-    ps = ParameterServer(auto_pg=True)
-    uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
-
-
+# we need this CLI entry point for compatibility with former versions
 if __name__ == "__main__":
+    from .__main__ import run_from_cli
+
     run_from_cli()