checkpoint-engine 0.3.0rc1__py3-none-any.whl → 0.3.1rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/__init__.py +36 -0
- checkpoint_engine/__main__.py +28 -0
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/api.py +95 -0
- checkpoint_engine/data_types.py +111 -0
- checkpoint_engine/p2p_store.py +210 -0
- checkpoint_engine/pin_memory.py +390 -0
- checkpoint_engine/ps.py +23 -797
- checkpoint_engine/worker.py +18 -9
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/METADATA +1 -1
- checkpoint_engine-0.3.1rc0.dist-info/RECORD +15 -0
- checkpoint_engine-0.3.0rc1.dist-info/RECORD +0 -10
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/top_level.txt +0 -0
checkpoint_engine/ps.py
CHANGED
|
@@ -1,143 +1,33 @@
|
|
|
1
|
-
import argparse
|
|
2
|
-
import concurrent.futures
|
|
3
1
|
import ctypes
|
|
4
|
-
import json
|
|
5
2
|
import os
|
|
6
|
-
import pickle
|
|
7
|
-
import random
|
|
8
3
|
import threading
|
|
9
|
-
import time
|
|
10
4
|
from collections import defaultdict
|
|
11
5
|
from collections.abc import Callable
|
|
12
6
|
from datetime import timedelta
|
|
13
|
-
from typing import TYPE_CHECKING
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
14
8
|
|
|
15
|
-
import httpx
|
|
16
|
-
import numpy as np
|
|
17
9
|
import torch
|
|
18
10
|
import torch.distributed as dist
|
|
19
11
|
import zmq
|
|
20
12
|
from loguru import logger
|
|
21
|
-
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
|
|
22
|
-
from safetensors.torch import _getdtype, safe_open
|
|
23
13
|
from torch.multiprocessing.reductions import reduce_tensor
|
|
24
14
|
|
|
15
|
+
from checkpoint_engine.data_types import (
|
|
16
|
+
BucketRange,
|
|
17
|
+
DataToGather,
|
|
18
|
+
H2DBucket,
|
|
19
|
+
MemoryBuffer,
|
|
20
|
+
MemoryBufferMetaList,
|
|
21
|
+
MemoryBufferMetas,
|
|
22
|
+
ParameterMeta,
|
|
23
|
+
)
|
|
25
24
|
from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
|
|
25
|
+
from checkpoint_engine.p2p_store import P2PStore
|
|
26
|
+
from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
if TYPE_CHECKING:
|
|
29
|
-
from
|
|
30
|
-
|
|
31
|
-
from typing_extensions import TypedDict
|
|
32
|
-
|
|
33
|
-
# Per-tensor metadata returned by the checkpoint file loaders
# (see _load_checkpoint_file), declared with the functional TypedDict syntax.
FileMeta = TypedDict(
    "FileMeta",
    {
        "key": str,  # parameter name
        "dtype": torch.dtype,
        "shape": torch.Size,
        "type": type,
        "tp_concat_dim": int,
    },
)

# Generic type variable for annotations in this module.
T = TypeVar("T")
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def _dt_validate(value: Any) -> torch.dtype:
|
|
44
|
-
if isinstance(value, str):
|
|
45
|
-
if not value.startswith("torch."):
|
|
46
|
-
raise ValueError(f"dtype {value} should start with torch.")
|
|
47
|
-
try:
|
|
48
|
-
value = getattr(torch, value.split(".")[1])
|
|
49
|
-
except AttributeError as e:
|
|
50
|
-
raise ValueError(f"unknown dtype: {value}") from e
|
|
51
|
-
if not isinstance(value, torch.dtype):
|
|
52
|
-
raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
|
|
53
|
-
return value
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
# Pydantic-compatible annotation for ``torch.dtype`` fields: validated by
# ``_dt_validate`` (accepts dtypes or "torch.<name>" strings), serialized with
# ``str()`` (e.g. "torch.float32"), and described as a string in the
# serialization-mode JSON schema.
_TorchDtype = Annotated[
    torch.dtype,
    PlainValidator(_dt_validate),
    PlainSerializer(lambda x: str(x), return_type=str),
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def _size_validate(value: Any) -> torch.Size:
|
|
65
|
-
if isinstance(value, list | tuple):
|
|
66
|
-
return torch.Size(value)
|
|
67
|
-
if not isinstance(value, torch.Size):
|
|
68
|
-
raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
|
|
69
|
-
return value
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
# Pydantic-compatible annotation for ``torch.Size`` fields: validated by
# ``_size_validate`` (lists/tuples become torch.Size), serialized as a plain
# tuple, and described as an integer array in the serialization-mode JSON schema.
_TorchSize = Annotated[
    torch.Size,
    PlainValidator(_size_validate),
    PlainSerializer(lambda x: tuple(x), return_type=tuple),
    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
]
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def _tensor_validate(value: Any) -> torch.Tensor:
|
|
81
|
-
if isinstance(value, torch.Tensor):
|
|
82
|
-
return value
|
|
83
|
-
raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
# Pydantic-compatible annotation for ``torch.Tensor`` fields: validation only
# checks the type (see ``_tensor_validate``); no serializer is attached.
_TorchTensor = Annotated[
    torch.Tensor,
    PlainValidator(_tensor_validate),
]
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
class ParameterMeta(BaseModel):
    """Metadata for one named tensor stored inside a flat byte buffer."""

    name: str  # parameter name
    dtype: _TorchDtype  # validated by _dt_validate, serialized with str()
    shape: _TorchSize  # validated by _size_validate, serialized as a tuple
    # Bytes reserved for this tensor in its buffer: _align_size(dtype, shape) for
    # normally pinned buckets, or the safetensors data_offsets span for in-place pins.
    aligned_size: int
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
class BucketRange(NamedTuple):
    """A contiguous slice of one memory bucket (fields are positional: idx, offset, size)."""

    idx: int  # bucket_idx of MemoryBucket in memory_pool
    offset: int  # NOTE(review): presumably a byte offset within that bucket — confirm
    size: int  # NOTE(review): presumably a length in bytes — confirm
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
class H2DBucket(BaseModel):
    """NOTE(review): presumably one host-to-device transfer batch (inferred from the name)."""

    size: int  # presumably the total byte size of the batch — confirm
    ranges: list[BucketRange]  # slices of memory buckets included in this batch
    items: list[ParameterMeta]  # tensors described by this batch
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
class MemoryBufferMetas(BaseModel):
    """Tensor-free description of a memory buffer: its tensor metadata plus location/size."""

    metas: list[ParameterMeta]
    ptr: int  # NOTE(review): presumably the buffer's data_ptr() address — confirm
    size: int  # presumably the buffer size in bytes, mirroring MemoryBuffer.size
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
class MemoryBuffer(BaseModel):
    """A flat uint8 host buffer holding several packed tensors, plus their metadata."""

    buffer: _TorchTensor  # backing tensor; may be a torch.empty(0) placeholder until the pinned buffer is assigned
    size: int  # buffer size in bytes
    metas: list[ParameterMeta]  # tensors packed into the buffer, in offset order
    # True when the buffer was pinned in place with cudaHostRegister (in-place pinning of
    # /dev/shm safetensors files) rather than allocated with pin_memory=True.
    manually_pinned: bool = False
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
class MemoryBufferMetaList(BaseModel):
    """Description of one process's registered buffers (presumably one per rank — confirm)."""

    p2p_store_addr: str | None  # "ip:port" of the P2P store, or None (presumably when no store is used)
    memory_buffer_metas_list: list[MemoryBufferMetas]
    rdma_device: str  # RDMA device name used by the P2P store (may be an empty string)
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
class DataToGather(MemoryBufferMetaList):
    """MemoryBufferMetaList plus host/device identity; presumably what each rank contributes to a gather — confirm."""

    host_ip: str
    device_uuid: str
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
# 256 bytes alignment when flatten torch tensors to uint8 buffer
|
|
136
|
-
_ALIGN_SIZE = 256
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
|
|
140
|
-
return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
|
|
30
|
+
from checkpoint_engine.data_types import T
|
|
141
31
|
|
|
142
32
|
|
|
143
33
|
def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
|
|
@@ -156,107 +46,6 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
|
|
|
156
46
|
return ret
|
|
157
47
|
|
|
158
48
|
|
|
159
|
-
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
    """Load one checkpoint file into ``{name: (FileMeta, tensor)}``.

    ``.safetensors`` files are read with ``safe_open``; deprecated ``.npy`` files are
    memory-mapped together with a pickled ``<file>.meta`` metadata list.

    Returns:
        ``(tp_rank, tensors)``. ``tp_rank`` is parsed from ``.npy`` file names of the form
        ``model.{layer}.{tp}[.{ep}].npy`` and is 0 otherwise.

    Raises:
        ValueError: for unsupported extensions or unknown npy header versions.
    """
    def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
        ret = {}
        with safe_open(fn, framework="pt") as f:
            for name in f.keys():  # noqa: SIM118
                weight = f.get_tensor(name)
                meta = {
                    "key": name,
                    "dtype": weight.dtype,
                    "shape": weight.shape,
                    "type": type(weight),
                    "tp_concat_dim": -1,  # safetensors does not support tp_concat_dim
                }
                ret[name] = (meta, weight)
        return ret

    # deprecated, will be removed in the future
    def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
        """load *.np file and return memmap and related tensor meta"""

        def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
            # Reads one npy header at the current position; header_length counts the
            # magic string plus header so the caller can locate the array data.
            start = fin.tell()
            major, minor = np.lib.format.read_magic(fin)
            if major == 1 and minor == 0:
                read_header_fn = np.lib.format.read_array_header_1_0
            elif major == 2 and minor == 0:
                read_header_fn = np.lib.format.read_array_header_2_0
            else:
                raise ValueError(
                    f"unknown version {major}.{minor} when parsing npy header from {fn}"
                )
            shape, is_fortran, dtype = read_header_fn(fin)
            return {
                "shape": shape,
                "is_fortran": is_fortran,
                "dtype": dtype,
                "header_length": fin.tell() - start,
            }

        meta_fn = fn + ".meta"
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted checkpoints.
        with open(meta_fn, "rb") as fin:
            meta_lst = pickle.load(fin)

        tensors = []
        offset = 0
        with open(fn, "rb") as fin:
            fin.seek(0, os.SEEK_END)
            filesize = fin.tell()
            fin.seek(0)
            # The file is a concatenation of npy records (header + data); walk them in order.
            while fin.tell() < filesize:
                tensor_meta = parse_npy_header(fin)
                # mode="c" is copy-on-write: writes never reach the file on disk.
                tensor = np.memmap(
                    fn,
                    dtype=tensor_meta["dtype"],
                    mode="c",
                    offset=offset + tensor_meta["header_length"],
                    shape=tensor_meta["shape"],
                )
                offset += tensor_meta["header_length"] + tensor.nbytes
                fin.seek(offset)
                tensors.append(tensor)

        # The .meta list must describe the records one-to-one, in file order.
        assert len(meta_lst) == len(tensors)
        ret = {}
        for meta, tensor in zip(meta_lst, tensors):
            if meta["type"] == torch.Tensor:
                tensor = torch.from_numpy(tensor)
                # Reinterpret the raw bytes with the recorded torch dtype and shape.
                tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
            ret[meta["key"]] = (meta, tensor)
        return ret

    tp_rank = 0
    if file_path.endswith(".npy"):
        logger.warning("numpy model file is deprecated, will be removed in the future")
        filename_split = os.path.basename(file_path).split(".")
        # if using numpy and want to specify tp rank
        # file should be in model.{layer}.{tp}[.{ep}].npy format
        tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
        ret = _fast_np_load(file_path)
    elif file_path.endswith(".safetensors"):
        ret = _safetensors_load(file_path)
    else:
        raise ValueError(f"unsupported file format: {file_path}")
    return tp_rank, ret
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
def _concat_tp_weights(
|
|
246
|
-
tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
|
|
247
|
-
) -> torch.Tensor:
|
|
248
|
-
"""Concat tp weights with meta info.
|
|
249
|
-
If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights.
|
|
250
|
-
Else we will cat weights in concat_dim.
|
|
251
|
-
"""
|
|
252
|
-
if tp_concat_dim == -1:
|
|
253
|
-
return tp_weights[0]
|
|
254
|
-
assert tp_size == len(tp_weights)
|
|
255
|
-
if len(tp_weights) == 1:
|
|
256
|
-
return tp_weights[0]
|
|
257
|
-
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
|
|
258
|
-
|
|
259
|
-
|
|
260
49
|
def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
|
|
261
50
|
try:
|
|
262
51
|
if device_manager.device_type == "npu":
|
|
@@ -267,426 +56,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None
|
|
|
267
56
|
raise ValueError(f"fail to get physical gpu id {device_index}") from e
|
|
268
57
|
|
|
269
58
|
|
|
270
|
-
def _ibv_get_device_list() -> list[str]:
    """Return the names of the RDMA devices reported by libibverbs.

    Loads ``libibverbs.so.1`` via ctypes and calls ``ibv_get_device_list``,
    ``ibv_get_device_name`` and ``ibv_free_device_list``.

    Returns:
        Device names (e.g. ``"mlx5_0"``), or an empty list when none are found.

    Raises:
        OSError: if ``libibverbs.so.1`` cannot be loaded.
    """
    lib = ctypes.CDLL("libibverbs.so.1")
    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **

    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    lib.ibv_free_device_list.restype = None  # void
    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *

    num = ctypes.c_int()
    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
    if not dev_array:
        # NULL means the call failed; there is nothing to free.
        return []

    devices: list[str] = []
    try:
        for i in range(num.value):  # empty range when num.value <= 0
            name = lib.ibv_get_device_name(dev_array[i])  # const char *
            if name:  # skip a NULL name instead of crashing on None.decode()
                devices.append(name.decode())
    finally:
        # A non-NULL array must be released even when it lists zero devices or
        # decoding raises; the previous early return leaked it.
        lib.ibv_free_device_list(dev_array)
    return devices
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
def _get_rdma_devices() -> list[str]:
|
|
294
|
-
"""
|
|
295
|
-
use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
|
|
296
|
-
"""
|
|
297
|
-
devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
|
|
298
|
-
if devices_str:
|
|
299
|
-
return devices_str.split(",")
|
|
300
|
-
# if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
|
|
301
|
-
hca = os.getenv("NCCL_IB_HCA", None)
|
|
302
|
-
return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
|
|
306
|
-
"""
|
|
307
|
-
implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
|
|
308
|
-
"""
|
|
309
|
-
if not devices:
|
|
310
|
-
raise RuntimeError("no rdma devices found")
|
|
311
|
-
try:
|
|
312
|
-
assert len(devices) <= gpu_count, (
|
|
313
|
-
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
|
|
314
|
-
)
|
|
315
|
-
assert gpu_count % len(devices) == 0, (
|
|
316
|
-
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
|
|
317
|
-
)
|
|
318
|
-
return devices[local_rank // (gpu_count // len(devices))]
|
|
319
|
-
except AssertionError:
|
|
320
|
-
logger.error(
|
|
321
|
-
"Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
|
|
322
|
-
"The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
|
|
323
|
-
"The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
|
|
324
|
-
)
|
|
325
|
-
raise
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
|
|
329
|
-
"""
|
|
330
|
-
The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
|
|
331
|
-
The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
|
|
332
|
-
|
|
333
|
-
The list is comma-separated; port numbers are NOT supported yet.
|
|
334
|
-
An optional prefix '^' indicates the list is an exclude list.
|
|
335
|
-
A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
|
|
336
|
-
Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
|
|
337
|
-
|
|
338
|
-
Examples:
|
|
339
|
-
- `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
|
|
340
|
-
- `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
|
|
341
|
-
- `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
|
|
342
|
-
- `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
|
|
343
|
-
"""
|
|
344
|
-
max_hcas = 32
|
|
345
|
-
if not value or value.strip() == "":
|
|
346
|
-
return available_devices[:max_hcas]
|
|
347
|
-
|
|
348
|
-
value = value.strip()
|
|
349
|
-
result = []
|
|
350
|
-
is_exclude = value.startswith("^")
|
|
351
|
-
if is_exclude:
|
|
352
|
-
value = value.removeprefix("^")
|
|
353
|
-
is_exact_match = value.startswith("=")
|
|
354
|
-
if is_exact_match:
|
|
355
|
-
value = value.removeprefix("=")
|
|
356
|
-
|
|
357
|
-
device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
|
|
358
|
-
|
|
359
|
-
result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
|
|
360
|
-
if is_exclude:
|
|
361
|
-
result = [dev for dev in available_devices if dev not in result]
|
|
362
|
-
if len(result) > max_hcas:
|
|
363
|
-
result = result[:max_hcas]
|
|
364
|
-
|
|
365
|
-
logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
|
|
366
|
-
|
|
367
|
-
return result
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
def _resolve_device_specs(
|
|
371
|
-
device_specs: list[str], is_exact_match: bool, available_devices: list[str]
|
|
372
|
-
) -> list[str]:
|
|
373
|
-
devices = set()
|
|
374
|
-
for spec in device_specs:
|
|
375
|
-
parts = spec.split(":", 1)
|
|
376
|
-
device_name = parts[0].strip()
|
|
377
|
-
# HACK: mooncake transfer engine does not support port specification yet, so we ignore it
|
|
378
|
-
# port = parts[1].strip() if len(parts) > 1 else None
|
|
379
|
-
base_devices = (
|
|
380
|
-
[device_name]
|
|
381
|
-
if device_name in available_devices
|
|
382
|
-
else []
|
|
383
|
-
if is_exact_match
|
|
384
|
-
else [dev for dev in available_devices if dev.startswith(device_name)]
|
|
385
|
-
)
|
|
386
|
-
|
|
387
|
-
if not base_devices:
|
|
388
|
-
logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
|
|
389
|
-
continue
|
|
390
|
-
|
|
391
|
-
for base_dev in base_devices:
|
|
392
|
-
devices.add(base_dev)
|
|
393
|
-
|
|
394
|
-
return sorted(devices)
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
    """Load checkpoint files and merge tensor-parallel shards into full tensors.

    Each file is loaded with ``_load_checkpoint_file``, which also yields its TP rank.
    Shards of a parameter with ``tp_concat_dim != -1`` are concatenated in TP-rank
    order; shared parameters (``tp_concat_dim == -1``) keep their first shard.

    Returns:
        ``{parameter_name: tensor}`` for every parameter found.

    Raises:
        AssertionError: if a loaded dtype/shape has the wrong type, or a merged tensor's
            shape/dtype disagrees with the metadata computed for it.
    """
    class TPMeta(BaseModel):
        concat_dim: int  # dimension shards are concatenated along; -1 means shared across TP ranks
        size: int  # number of TP shards seen so far (max tp_rank + 1)

    parameters: dict[str, torch.Tensor] = {}
    parameter_metas: dict[str, ParameterMeta] = {}
    tp_metas: dict[str, TPMeta] = {}
    parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {}  # name -> {tp_rank: shard}
    for file in files:
        tp_rank, ret = _load_checkpoint_file(file)
        for parameter_name, (meta, weight) in ret.items():
            if parameter_name not in parameters_with_tp:
                parameters_with_tp[parameter_name] = {}
            parameters_with_tp[parameter_name][tp_rank] = weight
            if parameter_name not in tp_metas:
                tp_metas[parameter_name] = TPMeta(
                    concat_dim=meta["tp_concat_dim"],
                    size=1,
                )
            if parameter_name not in parameter_metas:
                # Per-shard metadata from the first file that contains this parameter.
                assert isinstance(meta["dtype"], torch.dtype), (
                    f"meta {meta} dtype should be torch.dtype"
                )
                assert isinstance(meta["shape"], torch.Size), (
                    f"meta {meta} shape should be torch.Size"
                )
                parameter_metas[parameter_name] = ParameterMeta(
                    name=parameter_name,
                    shape=meta["shape"],
                    dtype=meta["dtype"],
                    aligned_size=_align_size(meta["dtype"], meta["shape"]),
                )
            tp_meta = tp_metas[parameter_name]
            if tp_meta.concat_dim != -1:
                tp_meta.size = max(tp_meta.size, tp_rank + 1)
    for name, tp_meta in tp_metas.items():
        if tp_meta.concat_dim != -1:
            # Scale the shard shape along the concat dim to get the full tensor's shape.
            shape = list(parameter_metas[name].shape)
            shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
            parameter_metas[name] = ParameterMeta(
                name=name,
                shape=torch.Size(shape),
                dtype=parameter_metas[name].dtype,
                aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
            )
        # Shards ordered by TP rank.
        weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
        # TODO: here concat is serial, which may be slow
        # but since tp storage is not used in the future
        # we ignore this performance issue for now
        parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size)
    # Sanity check: every merged tensor must match its computed metadata.
    for name, parameter in parameters.items():
        assert name in parameter_metas, f"parameter {name} not found in parameter_metas"
        assert parameter_metas[name].shape == parameter.shape, (
            f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}"
        )
        assert parameter_metas[name].dtype == parameter.dtype, (
            f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}"
        )
    return parameters
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
    """Pin safetensors files in place, without copying tensor data.

    Each file is mapped with ``torch.from_file``, its header is parsed into
    ``ParameterMeta`` entries, the file is deleted, and the data region is registered
    as pinned memory with ``cudaHostRegister``. Files are processed on up to 32 threads.

    Returns:
        One ``MemoryBuffer`` (``manually_pinned=True``) per file, in ``files`` order.
    """
    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
        """
        safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
        The actual tensor data is in the remaining bytes and is naturally aligned.
        We pin the remaining bytes as the buffer, making pinning faster.
        """

        def _pin(t: torch.Tensor):
            """
            Pin the memory of tensor in-place.
            See: https://github.com/pytorch/pytorch/issues/32167
            """
            cudart = torch.cuda.cudart()
            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
            assert r == 0, f"pin memory error, error code: {r}"

        # TODO: should only support /dev/shm? but we found files in disk also work?
        size = os.stat(file_path).st_size
        flag_size = 8  # safetensors files start with a little-endian u64 header length
        # shared=True: the tensor's storage is backed by the file rather than a private copy.
        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
        assert t.nbytes > flag_size, (
            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
        )
        # Byte position where tensor data starts (header length + the 8-byte prefix).
        start_pos = (
            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
            + flag_size
        )
        header_tensor = t[flag_size:start_pos]
        header = json.loads(header_tensor.numpy().tobytes())
        if "__metadata__" in header:
            header.pop("__metadata__")

        metas: list[ParameterMeta] = []
        offset = 0
        try:
            # Walk tensors in data order and require them to be packed back-to-back.
            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
                start, end = meta["data_offsets"]
                # safetensors format ensures offsets are aligned
                assert offset == start, f"offset {offset} should be equal to start {start}"
                metas.append(
                    ParameterMeta(
                        name=name,
                        dtype=_getdtype(meta["dtype"]),
                        shape=torch.Size(meta["shape"]),
                        aligned_size=end - start,
                    )
                )
                offset = end
        except Exception as e:
            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
            raise

        buffer = t[start_pos:]
        assert offset == buffer.nbytes, (
            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
        )
        # Remove the file after successfully loading. This will avoid doubling the memory usage.
        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
        # NOTE(review): presumably relies on POSIX unlink semantics keeping the existing mapping valid.
        os.remove(file_path)
        _pin(buffer)
        logger.info(
            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
        )
        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)

    memory_buffers: list[MemoryBuffer] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
    return memory_buffers
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
def _normal_pin_memory(
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
    """Copy checkpoint tensors into page-locked (pinned) host buffers.

    Tensors loaded from ``files`` (via ``_load_checkpoint``) and ``named_tensors``
    (which override same-named file entries) are sorted by name and packed into
    buckets of at least 4 GiB (larger if a single tensor needs it). Each bucket
    becomes one ``MemoryBuffer`` whose tensors sit at ``_align_size``-aligned offsets.

    Args:
        files: checkpoint files to load.
        named_tensors: extra in-memory tensors to include.
        rank: used only in log messages.
        shared_pin_memory: if non-empty, previously registered buffers that are reused
            instead of allocating new pinned memory; bucket ``i`` must have exactly
            ``shared_pin_memory[i].size`` bytes.

    Returns:
        One ``MemoryBuffer`` per bucket, or an empty list when there is nothing to pin.
    """
    parameters = _load_checkpoint(files)
    if named_tensors:
        parameters.update(named_tensors)
    if not parameters:
        # Nothing to pin. The bucket-size computation below needs at least one tensor
        # (max() of an empty sequence would raise ValueError).
        return []
    bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))

    class MemoryBucket(BaseModel):
        size: int
        metas: list[ParameterMeta]

    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
    for name, tensor in sorted(parameters.items()):
        size = _align_size(tensor.dtype, tensor.shape)
        if buckets[-1].size + size > bucket_size:
            # Check the bucket's contents: a pydantic model instance is always truthy,
            # so the previous `assert buckets[-1]` could never fail.
            assert buckets[-1].metas, f"buckets[{len(buckets) - 1}] should not be empty"
            buckets.append(MemoryBucket(size=0, metas=[]))
        buckets[-1].metas.append(
            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
        )
        buckets[-1].size += size

    # Placeholder buffers; the pinned tensors are assigned once allocated below.
    memory_buffers = [
        MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
        for bucket in buckets
    ]

    def register_pin_memory(
        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
    ) -> tuple[int, torch.Tensor]:
        """Return ``(idx, buffer)``: a reused shared buffer or a freshly pinned one."""
        if shared_pin_memory:
            # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
            # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
            assert idx < len(shared_pin_memory), (
                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
            )
            assert shared_pin_memory[idx].size == size, (
                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
            )
            return idx, shared_pin_memory[idx].buffer
        buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
        return idx, buffer

    def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
        """Copy ``tensor``'s raw bytes into ``buffer`` starting at ``offset``."""
        buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(
                register_pin_memory,
                idx,
                bucket.size,
                shared_pin_memory,
            )
            for idx, bucket in enumerate(buckets)
        ]
        new_futures = []
        # Start copying into each bucket as soon as its buffer is ready.
        for future in concurrent.futures.as_completed(futures):
            idx, buffer = future.result()
            assert buffer.numel() == buckets[idx].size, (
                f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
            )
            memory_buffers[idx].buffer = buffer
            logger.info(
                f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
                f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
            )
            offset = 0
            for meta in buckets[idx].metas:
                name = meta.name
                tensor = parameters[name]
                size = _align_size(tensor.dtype, tensor.shape)
                assert size == _align_size(meta.dtype, meta.shape), (
                    f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
                )
                new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
                offset += size
        # Propagate any copy failure.
        for future in concurrent.futures.as_completed(new_futures):
            future.result()
    return memory_buffers
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
def _register_checkpoint(
    *,
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
    inplace_pin: bool = False,
) -> list[MemoryBuffer]:
    """Load a checkpoint into pinned host memory buffers.

    With ``inplace_pin``, ``/dev/shm/*.safetensors`` files are pinned in place by
    ``_inplace_pin_memory`` (which deletes those files). All other files, plus
    ``named_tensors``, are copied into pinned buckets by ``_normal_pin_memory``.

    Returns:
        The normally pinned buffers followed by the in-place pinned ones; an empty list
        when there are no files and no named tensors.
    """
    logger.info(
        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
    )
    if not files and not named_tensors:
        return []
    memory_buffers: list[MemoryBuffer] = []
    if inplace_pin:
        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
        files_to_inplace_pin = [
            file
            for file in files
            if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
        ]
        # Set membership keeps the partition linear instead of O(n^2) list scans.
        inplace_set = set(files_to_inplace_pin)
        files_to_normal_pin = [file for file in files if file not in inplace_set]
    else:
        files_to_normal_pin = files
        files_to_inplace_pin = []
    if files_to_normal_pin or named_tensors:
        memory_buffers.extend(
            _normal_pin_memory(
                files=files_to_normal_pin,
                named_tensors=named_tensors,
                rank=rank,
                shared_pin_memory=shared_pin_memory,
            )
        )
    if files_to_inplace_pin:
        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
    return memory_buffers
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
def request_inference_to_update(
    url: str,
    socket_paths: dict[str, str],
    timeout: float = 300.0,
    uds: str | None = None,
):
    """Send an inference update request to inference server via HTTP or Unix socket.

    Args:
        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
        socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
        uds (str, optional): Path to a Unix domain socket. If provided, the request
            will be sent via the Unix socket instead of HTTP. Defaults to None.

    Raises:
        httpx.HTTPStatusError: If the response contains an HTTP error status.
        httpx.RequestError: If there was an issue while making the request.
    """
    # Use the client as a context manager so its transport/connection pool is closed
    # after the request; previously a new client was created and leaked on every call.
    with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
        resp = client.post(
            url,
            json={
                "method": "update_weights_from_ipc",
                "args": [socket_paths],
                "timeout": timeout,
            },
            timeout=timeout,
        )
        resp.raise_for_status()
|
|
688
|
-
|
|
689
|
-
|
|
690
59
|
def _gen_h2d_buckets(
|
|
691
60
|
global_metas: dict[int, MemoryBufferMetaList],
|
|
692
61
|
bucket_size: int,
|
|
@@ -789,84 +158,12 @@ def _get_master_port(master_port: int | None = None) -> int:
|
|
|
789
158
|
if master_port is None:
|
|
790
159
|
# HACK: use MASTER_PORT + 1 as master_port, avoid conflict with torchrun's rendezvous port
|
|
791
160
|
# TODO: check whether master_port is available or use a more elegant way
|
|
792
|
-
|
|
161
|
+
master_port_str = os.getenv("MASTER_PORT")
|
|
162
|
+
assert master_port_str, "MASTER_PORT is required if no master_port is provided."
|
|
163
|
+
master_port = int(master_port_str) + 1
|
|
793
164
|
return master_port
|
|
794
165
|
|
|
795
166
|
|
|
796
|
-
class P2PStore:
|
|
797
|
-
def __init__(self, device_manager: DeviceManager):
|
|
798
|
-
from mooncake.engine import TransferEngine
|
|
799
|
-
|
|
800
|
-
self.rank = int(os.getenv("RANK"))
|
|
801
|
-
gpu_count = device_manager.device_module.device_count()
|
|
802
|
-
local_rank = self.rank % gpu_count
|
|
803
|
-
device_type = device_manager.device_type
|
|
804
|
-
if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
|
|
805
|
-
self.device = ""
|
|
806
|
-
else:
|
|
807
|
-
self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
|
|
808
|
-
self.ip = get_ip()
|
|
809
|
-
|
|
810
|
-
# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
|
|
811
|
-
retry_count = 8
|
|
812
|
-
for i in range(retry_count):
|
|
813
|
-
self.engine = TransferEngine()
|
|
814
|
-
ret = self.engine.initialize(
|
|
815
|
-
self.ip,
|
|
816
|
-
"P2PHANDSHAKE",
|
|
817
|
-
"ascend_direct" if device_type == "npu" else "rdma",
|
|
818
|
-
self.device,
|
|
819
|
-
)
|
|
820
|
-
if ret == 0:
|
|
821
|
-
break
|
|
822
|
-
# sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
|
|
823
|
-
sleep_ms = random.randint(500, 2000)
|
|
824
|
-
logger.warning(
|
|
825
|
-
f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
|
|
826
|
-
)
|
|
827
|
-
time.sleep(sleep_ms / 1000)
|
|
828
|
-
else:
|
|
829
|
-
raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
|
|
830
|
-
self.port = self.engine.get_rpc_port()
|
|
831
|
-
self.named_tensors: dict[str, torch.Tensor] = {}
|
|
832
|
-
logger.info(
|
|
833
|
-
f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
|
|
834
|
-
)
|
|
835
|
-
|
|
836
|
-
@property
|
|
837
|
-
def addr(self) -> str:
|
|
838
|
-
return f"{self.ip}:{self.port}"
|
|
839
|
-
|
|
840
|
-
def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
|
|
841
|
-
buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
|
|
842
|
-
capacities = [tensor.nbytes for tensor in named_tensors.values()]
|
|
843
|
-
self.named_tensors.update(named_tensors)
|
|
844
|
-
for i, name in enumerate(named_tensors.keys()):
|
|
845
|
-
logger.info(
|
|
846
|
-
f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
|
|
847
|
-
)
|
|
848
|
-
assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0
|
|
849
|
-
|
|
850
|
-
def unregister_named_tensors(self, names: list[str]) -> int:
|
|
851
|
-
buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
|
|
852
|
-
assert self.engine.batch_unregister_memory(buffer_addresses) == 0
|
|
853
|
-
num_unregistered = 0
|
|
854
|
-
for i, name in enumerate(names):
|
|
855
|
-
del self.named_tensors[name]
|
|
856
|
-
logger.info(
|
|
857
|
-
f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
|
|
858
|
-
)
|
|
859
|
-
num_unregistered += 1
|
|
860
|
-
return num_unregistered
|
|
861
|
-
|
|
862
|
-
def batch_transfer_sync_read(
|
|
863
|
-
self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
|
|
864
|
-
):
|
|
865
|
-
assert (
|
|
866
|
-
self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
|
|
867
|
-
)
|
|
868
|
-
|
|
869
|
-
|
|
870
167
|
class ParameterServer:
|
|
871
168
|
shared_memory_pool_name = "__shared_memory_pool__"
|
|
872
169
|
|
|
@@ -887,8 +184,8 @@ class ParameterServer:
|
|
|
887
184
|
Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
|
|
888
185
|
mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
|
|
889
186
|
"""
|
|
890
|
-
self._rank = rank or int(os.environ
|
|
891
|
-
self._world_size = world_size or int(os.environ
|
|
187
|
+
self._rank = rank or int(os.environ["RANK"])
|
|
188
|
+
self._world_size = world_size or int(os.environ["WORLD_SIZE"])
|
|
892
189
|
self.device_manager = DeviceManager()
|
|
893
190
|
self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
|
|
894
191
|
self._local_rank = self._rank % self._gpu_count
|
|
@@ -897,7 +194,7 @@ class ParameterServer:
|
|
|
897
194
|
self._global_device_uuids: list[str] = []
|
|
898
195
|
self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
|
|
899
196
|
self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
|
|
900
|
-
self._mem_fraction = mem_fraction or 0.9
|
|
197
|
+
self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))
|
|
901
198
|
|
|
902
199
|
assert self._rank is not None and self._rank >= 0, self._rank
|
|
903
200
|
assert self._world_size and self._world_size > 0, self._world_size
|
|
@@ -1352,7 +649,7 @@ class ParameterServer:
|
|
|
1352
649
|
f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
|
|
1353
650
|
)
|
|
1354
651
|
disable_h2d_buffer = True
|
|
1355
|
-
max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
|
|
652
|
+
max_bytes = int(float(os.getenv("PS_MAX_BUCKET_SIZE_GB", "8")) * GiB)
|
|
1356
653
|
bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
|
|
1357
654
|
logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
|
|
1358
655
|
return bucket_size, disable_h2d_buffer
|
|
@@ -1559,79 +856,8 @@ class ParameterServer:
|
|
|
1559
856
|
self.device_manager.device_module.empty_cache()
|
|
1560
857
|
|
|
1561
858
|
|
|
1562
|
-
|
|
1563
|
-
import fastapi
|
|
1564
|
-
from fastapi import Request
|
|
1565
|
-
from fastapi.responses import JSONResponse, Response
|
|
1566
|
-
|
|
1567
|
-
app = fastapi.FastAPI()
|
|
1568
|
-
|
|
1569
|
-
class RegisterRequest(BaseModel):
|
|
1570
|
-
files: list[str]
|
|
1571
|
-
|
|
1572
|
-
class UpdateRequest(BaseModel):
|
|
1573
|
-
ranks: list[int] = []
|
|
1574
|
-
update_url: str | None = None
|
|
1575
|
-
inference_group_ranks: list[int] = []
|
|
1576
|
-
timeout: float = 300.0
|
|
1577
|
-
uds: str | None = None
|
|
1578
|
-
|
|
1579
|
-
def wrap_exception(func: Callable[[], None]) -> Response:
|
|
1580
|
-
try:
|
|
1581
|
-
func()
|
|
1582
|
-
except Exception as e: # noqa: BLE001
|
|
1583
|
-
logger.exception(f"wrap exception {func} failed")
|
|
1584
|
-
return JSONResponse(content=str(e), status_code=500)
|
|
1585
|
-
return Response(status_code=200)
|
|
1586
|
-
|
|
1587
|
-
@app.post("/v1/checkpoints/{checkpoint_name}/files")
|
|
1588
|
-
async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
|
|
1589
|
-
return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
|
|
1590
|
-
|
|
1591
|
-
@app.delete("/v1/checkpoints/{checkpoint_name}")
|
|
1592
|
-
async def unregister_checkpoint(checkpoint_name: str) -> Response:
|
|
1593
|
-
return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
|
|
1594
|
-
|
|
1595
|
-
@app.get("/v1/healthz")
|
|
1596
|
-
async def healthz() -> Response:
|
|
1597
|
-
return Response(status_code=200)
|
|
1598
|
-
|
|
1599
|
-
@app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
|
|
1600
|
-
async def gather_metas(checkpoint_name: str) -> Response:
|
|
1601
|
-
return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
|
|
1602
|
-
|
|
1603
|
-
@app.post("/v1/checkpoints/{checkpoint_name}/update")
|
|
1604
|
-
async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
|
|
1605
|
-
def update_func(socket_paths: list[tuple[str, str]]):
|
|
1606
|
-
if req.update_url is None:
|
|
1607
|
-
return
|
|
1608
|
-
if req.inference_group_ranks:
|
|
1609
|
-
socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
|
|
1610
|
-
request_inference_to_update(
|
|
1611
|
-
req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
|
|
1612
|
-
)
|
|
1613
|
-
|
|
1614
|
-
return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
|
|
1615
|
-
|
|
1616
|
-
return app
|
|
1617
|
-
|
|
1618
|
-
|
|
1619
|
-
@logger.catch(reraise=True)
|
|
1620
|
-
def run_from_cli():
|
|
1621
|
-
import uvicorn
|
|
1622
|
-
|
|
1623
|
-
parser = argparse.ArgumentParser(description="Parameter Server")
|
|
1624
|
-
parser.add_argument("--uds", type=str)
|
|
1625
|
-
|
|
1626
|
-
args = parser.parse_args()
|
|
1627
|
-
logger.info(
|
|
1628
|
-
f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
|
|
1629
|
-
)
|
|
1630
|
-
|
|
1631
|
-
assert args.uds and len(args.uds) > 0, args.uds
|
|
1632
|
-
ps = ParameterServer(auto_pg=True)
|
|
1633
|
-
uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
|
|
1634
|
-
|
|
1635
|
-
|
|
859
|
+
# we need this CLI entry point for compatibility with former versions
|
|
1636
860
|
if __name__ == "__main__":
|
|
861
|
+
from .__main__ import run_from_cli
|
|
862
|
+
|
|
1637
863
|
run_from_cli()
|