checkpoint-engine 0.3.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/__init__.py +4 -0
- checkpoint_engine/_version.py +34 -0
- checkpoint_engine/device_utils.py +86 -0
- checkpoint_engine/ps.py +1576 -0
- checkpoint_engine/worker.py +168 -0
- checkpoint_engine-0.3.0rc0.dist-info/METADATA +236 -0
- checkpoint_engine-0.3.0rc0.dist-info/RECORD +10 -0
- checkpoint_engine-0.3.0rc0.dist-info/WHEEL +5 -0
- checkpoint_engine-0.3.0rc0.dist-info/licenses/LICENCE +21 -0
- checkpoint_engine-0.3.0rc0.dist-info/top_level.txt +1 -0
checkpoint_engine/ps.py
ADDED
@@ -0,0 +1,1576 @@
import argparse
import concurrent.futures
import ctypes
import json
import os
import pickle
import random
import threading
import time
from collections import defaultdict
from collections.abc import Callable
from datetime import timedelta
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple

import httpx
import numpy as np
import torch
import torch.distributed as dist
import zmq
from loguru import logger
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
from safetensors.torch import _getdtype, safe_open
from torch.multiprocessing.reductions import reduce_tensor

from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid


if TYPE_CHECKING:
    from typing import TypeVar

    from typing_extensions import TypedDict

    class FileMeta(TypedDict):
        key: str  # parameter name
        dtype: torch.dtype
        shape: torch.Size
        type: type
        tp_concat_dim: int

    T = TypeVar("T")


def _dt_validate(value: Any) -> torch.dtype:
    if isinstance(value, str):
        if not value.startswith("torch."):
            raise ValueError(f"dtype {value} should start with torch.")
        try:
            value = getattr(torch, value.split(".")[1])
        except AttributeError as e:
            raise ValueError(f"unknown dtype: {value}") from e
    if not isinstance(value, torch.dtype):
        raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
    return value


_TorchDtype = Annotated[
    torch.dtype,
    PlainValidator(_dt_validate),
    PlainSerializer(lambda x: str(x), return_type=str),
    WithJsonSchema({"type": "string"}, mode="serialization"),
]


def _size_validate(value: Any) -> torch.Size:
    if isinstance(value, list | tuple):
        return torch.Size(value)
    if not isinstance(value, torch.Size):
        raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
    return value


_TorchSize = Annotated[
    torch.Size,
    PlainValidator(_size_validate),
    PlainSerializer(lambda x: tuple(x), return_type=tuple),
    WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
]


def _tensor_validate(value: Any) -> torch.Tensor:
    if isinstance(value, torch.Tensor):
        return value
    raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")


_TorchTensor = Annotated[
    torch.Tensor,
    PlainValidator(_tensor_validate),
]


class ParameterMeta(BaseModel):
    name: str
    dtype: _TorchDtype
    shape: _TorchSize
    aligned_size: int
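
# Illustrative round-trip (hypothetical values): the Annotated types above let
# pydantic validate and serialize torch values through JSON-friendly forms:
#
#   >>> m = ParameterMeta(name="w", dtype="torch.bfloat16", shape=(4, 8), aligned_size=256)
#   >>> m.dtype, m.shape
#   (torch.bfloat16, torch.Size([4, 8]))
#   >>> m.model_dump()
#   {'name': 'w', 'dtype': 'torch.bfloat16', 'shape': (4, 8), 'aligned_size': 256}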


class BucketRange(NamedTuple):
    idx: int  # bucket_idx of MemoryBucket in memory_pool
    offset: int
    size: int


class H2DBucket(BaseModel):
    size: int
    ranges: list[BucketRange]
    items: list[ParameterMeta]


class MemoryBufferMetas(BaseModel):
    metas: list[ParameterMeta]
    ptr: int
    size: int


class MemoryBuffer(BaseModel):
    buffer: _TorchTensor
    size: int
    metas: list[ParameterMeta]


class MemoryBufferMetaList(BaseModel):
    p2p_store_addr: str | None
    memory_buffer_metas_list: list[MemoryBufferMetas]
    rdma_device: str


class DataToGather(MemoryBufferMetaList):
    host_ip: str
    device_uuid: str


# 256-byte alignment when flattening torch tensors into a uint8 buffer
_ALIGN_SIZE = 256


def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
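
# Illustrative example: a bfloat16 tensor with 1000 elements occupies 2000 bytes,
# which rounds up to the next 256-byte boundary:
#
#   >>> _align_size(torch.bfloat16, torch.Size([1000]))
#   2048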


def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
    ret = []
    for meta in metas:
        size = meta.aligned_size
        ret.append(
            {
                "name": meta.name,
                "dtype": meta.dtype,
                "shape": meta.shape,
                "offset": offset,
            }
        )
        offset += size
    return ret


def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
    def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
        ret = {}
        with safe_open(fn, framework="pt") as f:
            for name in f.keys():  # noqa: SIM118
                weight = f.get_tensor(name)
                meta = {
                    "key": name,
                    "dtype": weight.dtype,
                    "shape": weight.shape,
                    "type": type(weight),
                    "tp_concat_dim": -1,  # safetensors does not support tp_concat_dim
                }
                ret[name] = (meta, weight)
        return ret

    # deprecated, will be removed in the future
    def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
        """load *.np file and return memmap and related tensor meta"""

        def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
            start = fin.tell()
            major, minor = np.lib.format.read_magic(fin)
            if major == 1 and minor == 0:
                read_header_fn = np.lib.format.read_array_header_1_0
            elif major == 2 and minor == 0:
                read_header_fn = np.lib.format.read_array_header_2_0
            else:
                raise ValueError(
                    f"unknown version {major}.{minor} when parsing npy header from {fn}"
                )
            shape, is_fortran, dtype = read_header_fn(fin)
            return {
                "shape": shape,
                "is_fortran": is_fortran,
                "dtype": dtype,
                "header_length": fin.tell() - start,
            }

        meta_fn = fn + ".meta"
        with open(meta_fn, "rb") as fin:
            meta_lst = pickle.load(fin)

        tensors = []
        offset = 0
        with open(fn, "rb") as fin:
            fin.seek(0, os.SEEK_END)
            filesize = fin.tell()
            fin.seek(0)
            while fin.tell() < filesize:
                tensor_meta = parse_npy_header(fin)
                tensor = np.memmap(
                    fn,
                    dtype=tensor_meta["dtype"],
                    mode="c",
                    offset=offset + tensor_meta["header_length"],
                    shape=tensor_meta["shape"],
                )
                offset += tensor_meta["header_length"] + tensor.nbytes
                fin.seek(offset)
                tensors.append(tensor)

        assert len(meta_lst) == len(tensors)
        ret = {}
        for meta, tensor in zip(meta_lst, tensors):
            if meta["type"] == torch.Tensor:
                tensor = torch.from_numpy(tensor)
                tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
            ret[meta["key"]] = (meta, tensor)
        return ret

    tp_rank = 0
    if file_path.endswith(".npy"):
        logger.warning("numpy model file is deprecated, will be removed in the future")
        filename_split = os.path.basename(file_path).split(".")
        # if using numpy and want to specify tp rank
        # file should be in model.{layer}.{tp}[.{ep}].npy format
        tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
        ret = _fast_np_load(file_path)
    elif file_path.endswith(".safetensors"):
        ret = _safetensors_load(file_path)
    else:
        raise ValueError(f"unsupported file format: {file_path}")
    return tp_rank, ret


def _concat_tp_weights(
    tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
) -> torch.Tensor:
    """Concat tp weights with meta info.
    If meta.concat_dim is -1, these are shared tp weights, so just use the first one.
    Otherwise we concat the weights along concat_dim.
    """
    if tp_concat_dim == -1:
        return tp_weights[0]
    assert tp_size == len(tp_weights)
    if len(tp_weights) == 1:
        return tp_weights[0]
    return torch.cat([w for w in tp_weights], dim=tp_concat_dim)


def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
    try:
        if device_manager.device_type == "npu":
            return f"NPU-{npu_generate_uuid()}"
        else:
            return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
    except AssertionError as e:
        raise ValueError(f"fail to get physical gpu id {device_index}") from e


def _ibv_get_device_list() -> list[str]:
    lib = ctypes.CDLL("libibverbs.so.1")
    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **

    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *

    num = ctypes.c_int()
    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
    if not dev_array or num.value <= 0:
        return []

    devices = []
    for i in range(num.value):
        dev_ptr = dev_array[i]  # struct ibv_device *
        name = lib.ibv_get_device_name(dev_ptr)  # const char *
        devices.append(name.decode())
    lib.ibv_free_device_list(dev_array)
    return devices


def _get_rdma_devices() -> list[str]:
    """
    Use _ibv_get_device_list to get RDMA devices; if NCCL_IB_HCA has multiple values, just return them.
    """
    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
    if devices_str:
        return devices_str.split(",")
    # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
    hca = os.getenv("NCCL_IB_HCA", None)
    return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()


def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    """
    Implement network card allocation: if the cards are "mlx5_0,mlx5_1" on an 8-GPU node,
    local ranks 0-3 will share mlx5_0, ranks 4-7 will share mlx5_1, etc.
    """
    if not devices:
        raise RuntimeError("no rdma devices found")
    try:
        assert len(devices) <= gpu_count, (
            f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
        )
        assert gpu_count % len(devices) == 0, (
            f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
        )
        return devices[local_rank // (gpu_count // len(devices))]
    except AssertionError:
        logger.error(
            "Please set the 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose a proper number of RDMA devices. "
            "The number of RDMA devices should be less than or equal to the GPU count, and the GPU count should be divisible by the number of RDMA devices. "
            "The values accepted by NCCL_IB_HCA are documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
        )
        raise
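
# Illustration of the allocation rule above on a hypothetical 8-GPU node with two HCAs:
#
#   >>> _get_my_rdma_device(local_rank=2, gpu_count=8, devices=["mlx5_0", "mlx5_1"])
#   'mlx5_0'
#   >>> _get_my_rdma_device(local_rank=5, gpu_count=8, devices=["mlx5_0", "mlx5_1"])
#   'mlx5_1'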


def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
    """
    The values accepted by NCCL_IB_HCA are documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
    This Python parser follows the C++ parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.

    The list is comma-separated; port numbers are NOT supported yet.
    An optional prefix '^' indicates the list is an exclude list.
    A second optional prefix '=' indicates that the tokens are exact names; otherwise, by default, NCCL treats each token as a prefix.
    Please note that when '^' and '=' appear together, only '^=' is allowed; '=^' is not supported.

    Examples:
    - `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
    - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
    - `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
    - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
    """
    max_hcas = 32
    if not value or value.strip() == "":
        return available_devices[:max_hcas]

    value = value.strip()
    result = []
    is_exclude = value.startswith("^")
    if is_exclude:
        value = value.removeprefix("^")
    is_exact_match = value.startswith("=")
    if is_exact_match:
        value = value.removeprefix("=")

    device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]

    result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
    if is_exclude:
        result = [dev for dev in available_devices if dev not in result]
    if len(result) > max_hcas:
        result = result[:max_hcas]

    logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")

    return result
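
# The parsing rules above, illustrated with a hypothetical device list:
#
#   >>> avail = ["mlx5_0", "mlx5_1", "mlx5_2", "bnxt_0"]
#   >>> _parse_NCCL_IB_HCA("mlx5", avail)             # prefix match
#   ['mlx5_0', 'mlx5_1', 'mlx5_2']
#   >>> _parse_NCCL_IB_HCA("^=mlx5_0,mlx5_1", avail)  # exact-name exclude list
#   ['mlx5_2', 'bnxt_0']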


def _resolve_device_specs(
    device_specs: list[str], is_exact_match: bool, available_devices: list[str]
) -> list[str]:
    devices = set()
    for spec in device_specs:
        parts = spec.split(":", 1)
        device_name = parts[0].strip()
        # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
        # port = parts[1].strip() if len(parts) > 1 else None
        base_devices = (
            [device_name]
            if device_name in available_devices
            else []
            if is_exact_match
            else [dev for dev in available_devices if dev.startswith(device_name)]
        )

        if not base_devices:
            logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
            continue

        for base_dev in base_devices:
            devices.add(base_dev)

    return sorted(devices)


def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
    class TPMeta(BaseModel):
        concat_dim: int
        size: int

    parameters: dict[str, torch.Tensor] = {}
    parameter_metas: dict[str, ParameterMeta] = {}
    tp_metas: dict[str, TPMeta] = {}
    parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {}
    for file in files:
        tp_rank, ret = _load_checkpoint_file(file)
        for parameter_name, (meta, weight) in ret.items():
            if parameter_name not in parameters_with_tp:
                parameters_with_tp[parameter_name] = {}
            parameters_with_tp[parameter_name][tp_rank] = weight
            if parameter_name not in tp_metas:
                tp_metas[parameter_name] = TPMeta(
                    concat_dim=meta["tp_concat_dim"],
                    size=1,
                )
            if parameter_name not in parameter_metas:
                assert isinstance(meta["dtype"], torch.dtype), (
                    f"meta {meta} dtype should be torch.dtype"
                )
                assert isinstance(meta["shape"], torch.Size), (
                    f"meta {meta} shape should be torch.Size"
                )
                parameter_metas[parameter_name] = ParameterMeta(
                    name=parameter_name,
                    shape=meta["shape"],
                    dtype=meta["dtype"],
                    aligned_size=_align_size(meta["dtype"], meta["shape"]),
                )
            tp_meta = tp_metas[parameter_name]
            if tp_meta.concat_dim != -1:
                tp_meta.size = max(tp_meta.size, tp_rank + 1)
    for name, tp_meta in tp_metas.items():
        if tp_meta.concat_dim != -1:
            shape = list(parameter_metas[name].shape)
            shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
            parameter_metas[name] = ParameterMeta(
                name=name,
                shape=torch.Size(shape),
                dtype=parameter_metas[name].dtype,
                aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
            )
        weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
        # TODO: here concat is serial, which may be slow
        # but since tp storage is not used in the future
        # we ignore this performance issue for now
        parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size)
    for name, parameter in parameters.items():
        assert name in parameter_metas, f"parameter {name} not found in parameter_metas"
        assert parameter_metas[name].shape == parameter.shape, (
            f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}"
        )
        assert parameter_metas[name].dtype == parameter.dtype, (
            f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}"
        )
    return parameters


def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
        """
        For the safetensors format, see https://huggingface.co/docs/safetensors/en/index#format.
        We load the safetensors file as bytes, then parse the header manually to get parameter metas.
        The actual tensor data is in the remaining bytes and is naturally aligned.
        We pin the remaining bytes as the buffer, making pinning faster.
        """

        def _pin(t: torch.Tensor):
            """
            Pin the memory of tensor in-place.
            See: https://github.com/pytorch/pytorch/issues/32167
            """
            cudart = torch.cuda.cudart()
            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
            assert r == 0, f"pin memory error, error code: {r}"

        # TODO: should only support /dev/shm? but we found files in disk also work?
        size = os.stat(file_path).st_size
        flag_size = 8
        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
        assert t.nbytes > flag_size, (
            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
        )
        start_pos = (
            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
            + flag_size
        )
        header_tensor = t[flag_size:start_pos]
        header = json.loads(header_tensor.numpy().tobytes())
        if "__metadata__" in header:
            header.pop("__metadata__")

        metas: list[ParameterMeta] = []
        offset = 0
        try:
            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
                start, end = meta["data_offsets"]
                # safetensors format ensures offsets are aligned
                assert offset == start, f"offset {offset} should be equal to start {start}"
                metas.append(
                    ParameterMeta(
                        name=name,
                        dtype=_getdtype(meta["dtype"]),
                        shape=torch.Size(meta["shape"]),
                        aligned_size=end - start,
                    )
                )
                offset = end
        except Exception as e:
            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
            raise

        buffer = t[start_pos:]
        assert offset == buffer.nbytes, (
            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
        )
        # Remove the file after successfully loading. This will avoid doubling the memory usage.
        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
        os.remove(file_path)
        _pin(buffer)
        logger.info(
            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
        )
        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)

    memory_buffers: list[MemoryBuffer] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
    return memory_buffers
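
# For reference, the on-disk layout that _parse_and_pin_from_safetensors relies on,
# as a standalone reader sketch (illustrative; assumes only the documented safetensors
# format: an 8-byte little-endian header length N, N bytes of JSON, then raw tensor data):
#
#   import json
#   import struct
#
#   def read_safetensors_header(path: str) -> dict:
#       with open(path, "rb") as f:
#           (n,) = struct.unpack("<Q", f.read(8))
#           return json.loads(f.read(n))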


def _normal_pin_memory(
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
    parameters = _load_checkpoint(files)
    if named_tensors:
        parameters.update(named_tensors)
    bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))

    class MemoryBucket(BaseModel):
        size: int
        metas: list[ParameterMeta]

    buckets: list[MemoryBucket] = []
    buckets.append(MemoryBucket(size=0, metas=[]))
    for name, tensor in sorted(parameters.items()):
        size = _align_size(tensor.dtype, tensor.shape)
        if buckets[-1].size + size > bucket_size:
            assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
            buckets.append(MemoryBucket(size=0, metas=[]))
        buckets[-1].metas.append(
            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
        )
        buckets[-1].size += size

    memory_buffers = [
        MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
        for bucket in buckets
    ]

    def register_pin_memory(
        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
    ) -> tuple[int, torch.Tensor]:
        if shared_pin_memory:
            # If shared_pin_memory is provided, reuse the pin memory buffer; do not allocate a new one.
            # Reusing pin memory only supports a fixed shape of checkpoints, which is registered the first time.
            assert idx < len(shared_pin_memory), (
                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
            )
            assert shared_pin_memory[idx].size == size, (
                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
            )
            return idx, shared_pin_memory[idx].buffer
        else:
            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
            return idx, buffer

    def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
        buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(
                register_pin_memory,
                idx,
                bucket.size,
                shared_pin_memory,
            )
            for idx, bucket in enumerate(buckets)
        ]
        new_futures = []
        for future in concurrent.futures.as_completed(futures):
            idx, buffer = future.result()
            assert buffer.numel() == buckets[idx].size, (
                f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
            )
            memory_buffers[idx].buffer = buffer
            logger.info(
                f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
                f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
            )
            offset = 0
            for meta in buckets[idx].metas:
                name = meta.name
                tensor = parameters[name]
                size = _align_size(tensor.dtype, tensor.shape)
                assert size == _align_size(meta.dtype, meta.shape), (
                    f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
                )
                new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
                offset += size
        for future in concurrent.futures.as_completed(new_futures):
            future.result()
    return memory_buffers


def _register_checkpoint(
    *,
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
    logger.info(
        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
    )
    if not files and not named_tensors:
        return []
    memory_buffers: list[MemoryBuffer] = []
    files_to_inplace_pin = [
        file
        for file in files
        if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
    ]
    files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
    if files_to_normal_pin or named_tensors:
        memory_buffers.extend(
            _normal_pin_memory(
                files=files_to_normal_pin,
                named_tensors=named_tensors,
                rank=rank,
                shared_pin_memory=shared_pin_memory,
            )
        )
    if files_to_inplace_pin:
        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
    return memory_buffers


def request_inference_to_update(
    url: str,
    socket_paths: dict[str, str],
    timeout: float = 300.0,
    uds: str | None = None,
):
    """Send an inference update request to the inference server via HTTP or Unix socket.

    Args:
        url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
        socket_paths (dict[str, str]): A dictionary mapping device uuids to IPC socket paths for updating weights.
        timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
        uds (str, optional): Path to a Unix domain socket. If provided, the request
            will be sent via the Unix socket instead of HTTP. Defaults to None.

    Raises:
        httpx.HTTPStatusError: If the response contains an HTTP error status.
        httpx.RequestError: If there was an issue while making the request.
    """
    resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
        url,
        json={
            "method": "update_weights_from_ipc",
            "args": [socket_paths],
            "timeout": timeout,
        },
        timeout=timeout,
    )
    resp.raise_for_status()
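
# Typical call (illustrative values), telling an inference server to pull weights
# over the per-device ZMQ IPC sockets bound by ParameterServer:
#
#   >>> request_inference_to_update(
#   ...     "http://localhost:19730/inference",
#   ...     socket_paths={"GPU-<uuid>": "ipc://@checkpoint-engine-GPU-<uuid>-0.sock"},
#   ...     timeout=300.0,
#   ... )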


def _gen_h2d_buckets(
    global_metas: dict[int, MemoryBufferMetaList],
    bucket_size: int,
    local_topo: dict[str, set[int]],
    remote_topo: dict[str, set[int]],
    ranks: list[int] | None = None,
) -> list[tuple[int, int, H2DBucket]]:
    buckets: list[tuple[int, H2DBucket]] = []

    for owner_rank, items in global_metas.items():
        buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
        for idx, metas in enumerate(items.memory_buffer_metas_list):
            start_offset, offset = 0, 0
            for meta in metas.metas:
                s = meta.aligned_size
                if buckets[-1][1].size + s > bucket_size:
                    if offset - start_offset > 0:
                        buckets[-1][1].ranges.append(
                            BucketRange(idx, start_offset, offset - start_offset)
                        )
                        start_offset = offset
                    buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
                offset += s
                buckets[-1][1].size += s
                buckets[-1][1].items.append(meta)
            buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
    assert buckets[-1][1].size > 0, (
        f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
    )
    ranks_set = set(ranks) if ranks else set()
    actual_local_topo = (
        {k: v & ranks_set for k, v in local_topo.items() if v & ranks_set} if ranks else local_topo
    )
    # if ranks is empty, assign the owner_rank as receiver_rank; this is used for the colocated architecture
    if not ranks:
        return [(owner_rank, owner_rank, bucket) for owner_rank, bucket in buckets]
    else:
        return _assign_receiver_ranks(buckets, actual_local_topo, remote_topo)


def _assign_receiver_ranks(
    buckets: list[tuple[int, "T"]],
    local_topo: dict[str, set[int]],
    remote_topo: dict[str, set[int]],
) -> list[tuple[int, int, "T"]]:
    """
    (owner_rank, bucket) -> (receiver_rank, owner_rank, bucket)

    Assign receiver ranks to buckets. If ranks is empty, assign the owner_rank as receiver_rank.
    The GPU-rdma_device topology is taken into account to make full use of the bandwidth.
    """
    if not buckets:
        logger.warning("bucket list is empty, no need to assign receiver ranks")
        return []
    rank_to_rdma_device = {
        rank: rdma_device for rdma_device, ranks in remote_topo.items() for rank in ranks
    }

    # group buckets by owner RDMA devices
    buckets_by_rdma_device = defaultdict(list)
    for owner_rank, bucket in buckets:
        owner_rdma_device = rank_to_rdma_device[owner_rank]
        buckets_by_rdma_device[owner_rdma_device].append((owner_rank, bucket))

    buckets_matrix = list(buckets_by_rdma_device.values())
    assert buckets_matrix, "buckets_matrix should not be empty"

    # Select receiver ranks. We use the minimum rank in each local RDMA device group as the receiver rank.
    num_receivers = min(len(local_topo), len(buckets_by_rdma_device))
    receiver_list = [min(ranks) for ranks in list(local_topo.values())[:num_receivers]]

    flattened_buckets = [
        buckets_matrix[row][col]
        for col in range(
            max(len(matrix_row) for matrix_row in buckets_matrix) if buckets_matrix else 0
        )
        for row in range(len(buckets_matrix))
        if col < len(buckets_matrix[row])
    ]

    buckets_with_receiver = []
    assigned_cnt = 0
    while assigned_cnt < len(flattened_buckets):
        occupied_devices = set()
        for receiver_rank in receiver_list:
            if assigned_cnt >= len(flattened_buckets):
                break
            owner_rank, bucket = flattened_buckets[assigned_cnt]
            rdma_device = rank_to_rdma_device[owner_rank]
            if rdma_device in occupied_devices:
                break
            buckets_with_receiver.append((receiver_rank, owner_rank, bucket))
            occupied_devices.add(rdma_device)
            assigned_cnt += 1

    return buckets_with_receiver
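
# A small traced example (toy payloads in place of H2DBucket): buckets are flattened
# column-first across owner HCAs, then handed out round-robin so that within one pass
# each receiver reads from a distinct owner device:
#
#   >>> _assign_receiver_ranks(
#   ...     [(8, "b0"), (8, "b1"), (9, "b2")],
#   ...     local_topo={"mlx5_0@10.0.0.1": {0}, "mlx5_1@10.0.0.1": {4}},
#   ...     remote_topo={"mlx5_0@10.0.0.2": {8}, "mlx5_1@10.0.0.2": {9}},
#   ... )
#   [(0, 8, 'b0'), (4, 9, 'b2'), (0, 8, 'b1')]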


def _get_master_port(master_port: int | None = None) -> int:
    if master_port is None:
        # HACK: use MASTER_PORT + 1 as master_port, avoid conflict with torchrun's rendezvous port
        # TODO: check whether master_port is available or use a more elegant way
        master_port = int(os.getenv("MASTER_PORT")) + 1
    return master_port
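
# Illustration: with MASTER_PORT=29500 in the environment, _get_master_port()
# returns 29501, while an explicit _get_master_port(12345) returns 12345 unchanged.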


class P2PStore:
    def __init__(self, device_manager: DeviceManager):
        from mooncake.engine import TransferEngine

        self.rank = int(os.getenv("RANK"))
        gpu_count = device_manager.device_module.device_count()
        local_rank = self.rank % gpu_count
        device_type = device_manager.device_type
        if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
            self.device = ""
        else:
            self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
        self.ip = get_ip()

        # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
        retry_count = 8
        for i in range(retry_count):
            self.engine = TransferEngine()
            ret = self.engine.initialize(
                self.ip,
                "P2PHANDSHAKE",
                "ascend_direct" if device_type == "npu" else "rdma",
                self.device,
            )
            if ret == 0:
                break
            # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
            sleep_ms = random.randint(500, 2000)
            logger.warning(
                f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
            )
            time.sleep(sleep_ms / 1000)
        else:
            raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
        self.port = self.engine.get_rpc_port()
        self.named_tensors: dict[str, torch.Tensor] = {}
        logger.info(
            f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
        )

    @property
    def addr(self) -> str:
        return f"{self.ip}:{self.port}"

    def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
        buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
        capacities = [tensor.nbytes for tensor in named_tensors.values()]
        self.named_tensors.update(named_tensors)
        for i, name in enumerate(named_tensors.keys()):
            logger.info(
                f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
            )
        assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0

    def unregister_named_tensors(self, names: list[str]) -> int:
        buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
        assert self.engine.batch_unregister_memory(buffer_addresses) == 0
        num_unregistered = 0
        for i, name in enumerate(names):
            del self.named_tensors[name]
            logger.info(
                f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
            )
            num_unregistered += 1
        return num_unregistered

    def batch_transfer_sync_read(
        self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
    ):
        assert (
            self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
        )
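
# Read-path sketch (illustrative; `dst`, `meta` and the address are hypothetical):
# a receiver pulls a remote rank's registered buffer into its own registered memory
# with a one-sided batched read, using the ptr/size pairs exchanged via gather_metas:
#
#   >>> dst = torch.empty(meta.size, dtype=torch.uint8, pin_memory=True)
#   >>> p2p_store.register_named_tensors({"staging": dst})
#   >>> p2p_store.batch_transfer_sync_read(
#   ...     "10.0.0.2:12345",   # remote P2PStore.addr
#   ...     [dst.data_ptr()],   # local destination pointer(s)
#   ...     [meta.ptr],         # remote source pointer(s)
#   ...     [meta.size],        # byte counts
#   ... )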


class ParameterServer:
    shared_memory_pool_name = "__shared_memory_pool__"

    def __init__(
        self,
        *,
        rank: int | None = None,
        world_size: int | None = None,
        auto_pg: bool = False,
        gpu_count: int | None = None,
        mem_fraction: float | None = None,
    ):
        """
        Initialize the parameter server. The env vars RANK, WORLD_SIZE and MASTER_ADDR must be set.

        Args:
            auto_pg: Whether to automatically initialize the process group.
                Note that if auto_pg is True, the process group will be destroyed after update.
            mem_fraction: The proportion (as a fraction) of the current free device memory used for allocation.
        """
        self._rank = rank or int(os.environ.get("RANK", None))
        self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
        self.device_manager = DeviceManager()
        self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
        self._local_rank = self._rank % self._gpu_count
        self._auto_pg = auto_pg
        self._all_hosts = []
        self._global_device_uuids: list[str] = []
        self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
        self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
        self._mem_fraction = mem_fraction or 0.9

        assert self._rank is not None and self._rank >= 0, self._rank
        assert self._world_size and self._world_size > 0, self._world_size
        assert (
            self._gpu_count is not None
            and self._gpu_count > 0
            and self._gpu_count <= self.device_manager.device_module.device_count()
        ), self._gpu_count
        assert (
            self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
        ), self._mem_fraction

        self._zmq_ctx = zmq.Context()
        self._zmq_addr_counter = 0

        # stores the name of the checkpoint currently using the shared memory pool, or an empty string if none
        self._current_shared_memory_pool_user: str = ""
        self._memory_pool: dict[str, list[MemoryBuffer]] = {}
        self._memory_pool[self.shared_memory_pool_name] = []
        # dict key is owner_rank, value is a bucket metas list in owner_rank
        self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
        # NPU transfer engine initialization requires a prior set_device.
        device_index = self._local_rank
        self.device_manager.device_module.set_device(device_index)
        try:
            self._p2p_store = P2PStore(self.device_manager)
        except ImportError as e:
            logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
            self._p2p_store = None

        self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
        self._rdma_device = None if self._p2p_store is None else self._p2p_store.device

    def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
        if checkpoint_name == self._current_shared_memory_pool_user:
            assert self._memory_pool[self.shared_memory_pool_name], (
                f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
            )
            return self._memory_pool[self.shared_memory_pool_name]
        elif checkpoint_name in self._memory_pool:
            return self._memory_pool[checkpoint_name]
        else:
            raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")

    def _logger_rank0(self, msg: str):
        if self._local_rank == 0:
            logger.info(msg)

    def get_metas(self) -> dict[int, MemoryBufferMetaList]:
        return self._current_global_parameter_metas

    def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
        self._current_global_parameter_metas = metas
        self._remote_rdma_devices = defaultdict(set)
        for i, meta in self._current_global_parameter_metas.items():
            assert meta.rdma_device is not None, "meta.rdma_device should not be None"
            assert meta.p2p_store_addr is not None, "meta.p2p_store_addr should not be None"
            self._remote_rdma_devices[
                meta.rdma_device + "@" + meta.p2p_store_addr.split(":")[0]
            ].add(i)

    def register_checkpoint(
        self,
        checkpoint_name: str,
        *,
        files: list[str] | None = None,
        named_tensors: dict[str, torch.Tensor] | None = None,
        use_shared_memory_pool: bool = False,
    ) -> None:
        """
        Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
        Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
        Please make sure to copy the files to disk if you need to keep them.

        Args:
            checkpoint_name: The name of the checkpoint.
            files: The safetensors files to register.
            named_tensors: The named tensors to register.
            use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
                Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
                cannot accommodate checkpoints with different memory requirements.
                To free the actual memory of the shared pool or to modify its shape,
                please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
        """
        try:
            if use_shared_memory_pool:
                logger.info(
                    f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
                )
                assert self._current_shared_memory_pool_user == "", (
                    f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
                    f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
                    f"This registration may cause unexpected conflicts."
                )
                # Since we set the uninitialized shared memory pool to an empty list,
                # we can check whether this is the first time the shared memory pool is used.
                _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
                self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
                    files=files or [],
                    named_tensors=named_tensors or {},
                    rank=self._rank,
                    shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
                )
                self._current_shared_memory_pool_user = checkpoint_name
                if self._p2p_store is not None and _is_first_time:
                    self._register_parameters_to_p2p_store(checkpoint_name)
            else:
                assert checkpoint_name not in self._memory_pool, (
                    f"checkpoint {checkpoint_name} already registered"
                )
                self._memory_pool[checkpoint_name] = _register_checkpoint(
                    files=files or [], named_tensors=named_tensors or {}, rank=self._rank
                )
                if self._p2p_store is not None:
                    self._register_parameters_to_p2p_store(checkpoint_name)
        except Exception:
            logger.exception(
                f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
            )
            if self._p2p_store is not None and not use_shared_memory_pool:
                self._unregister_parameters_from_p2p_store(checkpoint_name)
            self.unregister_checkpoint(checkpoint_name)
            raise
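
    # Registration sketch (illustrative paths): each rank registers its local shard;
    # /dev/shm safetensors are pinned in place (and deleted), everything else is staged
    # through freshly allocated pinned host buffers:
    #
    #   >>> ps.register_checkpoint(
    #   ...     "my-ckpt-step100",
    #   ...     files=["/dev/shm/model-00001-of-00008.safetensors"],
    #   ... )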

    def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
        """
        Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
        from the p2p store if the p2p store is initialized.
        Args:
            checkpoint_name: The name of the checkpoint.
            force: This flag is designed for the shared memory pool user. If True, the memory of the shared memory pool itself will be freed.
                If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
        """
        if (
            checkpoint_name not in self._memory_pool
            and checkpoint_name != self._current_shared_memory_pool_user
        ):
            logger.warning(
                f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
            )
            return

        if checkpoint_name == self._current_shared_memory_pool_user and not force:
            self._current_shared_memory_pool_user = ""
            return

        if self._p2p_store is not None:
            num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
            logger.info(
                f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
            )

        if checkpoint_name == self._current_shared_memory_pool_user:
            self._current_shared_memory_pool_user = ""
            del self._memory_pool[self.shared_memory_pool_name]
            self._memory_pool[self.shared_memory_pool_name] = []
        else:
            del self._memory_pool[checkpoint_name]
        # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
        # this works by using torch>=2.5.0
        torch._C._host_emptyCache()

    def gather_metas(self, checkpoint_name: str):
        """
        Gather the parameter metas from all ranks. This gathers the memory_buffer metas and other metadata.
        This function should be called before update; it initializes a new value for `self._current_global_parameter_metas`,
        which can be exported via the `self.get_metas` function.
        """
        if self._auto_pg and not dist.is_initialized():
            self.init_process_group()
        assert dist.is_initialized(), "process group is not initialized"
        metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)]  # type: ignore
        try:
            memory_pool = self._get_memory_pool(checkpoint_name)
        except RuntimeError:
            memory_pool = []
        metas = DataToGather(
            memory_buffer_metas_list=[
                MemoryBufferMetas(
                    metas=x.metas,
                    ptr=x.buffer.data_ptr(),
                    size=x.size,
                )
                for x in memory_pool
            ],
            p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
            host_ip=get_ip(),
            device_uuid=self._device_uuid,
            rdma_device=self._rdma_device or "",
        )

        dist.all_gather_object(metas_lst, metas)

        self._current_global_parameter_metas = {}

        num_parameters = 0
        all_hosts: list[str] = []
        global_device_uuids: list[str] = []
        for i, metas_buckets in enumerate(metas_lst):
            assert metas_buckets is not None, f"metas_buckets {i} should not be None"
            if i % self._gpu_count == 0 and not self._all_hosts:
                all_hosts.append(metas_buckets.host_ip)
            if not self._global_device_uuids:
                global_device_uuids.append(metas_buckets.device_uuid)
            if metas_buckets.memory_buffer_metas_list:
                self._current_global_parameter_metas[i] = MemoryBufferMetaList(
                    memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
                    p2p_store_addr=metas_buckets.p2p_store_addr,
                    rdma_device=metas_buckets.rdma_device,
                )
                num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
            self._local_rdma_devices[
                metas_buckets.rdma_device + "@" + metas_buckets.p2p_store_addr.split(":")[0]
                if metas_buckets.p2p_store_addr
                else metas_buckets.host_ip
            ].add(i)
        if not self._all_hosts:
            self._all_hosts = all_hosts
        if not self._global_device_uuids:
            self._global_device_uuids = global_device_uuids
        # By default, the sender node and receiver node are assumed to have the same GPU-rdma_device topology.
        # The sender's topology (_remote_rdma_devices) can be rewritten by calling load_metas.
        self._remote_rdma_devices = self._local_rdma_devices.copy()
        logger.info(
            f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
        )

    def init_process_group(
        self,
        *,
        master_addr: str | None = None,
        master_port: int | None = None,
        timeout: timedelta = timedelta(minutes=10),
    ):
        """
        Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.

        Args:
            master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
            timeout: The timeout of the process group.
        """
        master_addr = master_addr or os.getenv("MASTER_ADDR")
        assert master_addr, "master_addr is required"
        store = dist.TCPStore(
            master_addr,
            _get_master_port(master_port),
            self._world_size,
            timeout=timeout,
            is_master=self._rank == 0,
        )
        dist.init_process_group(
            backend=self.device_manager.backend,
            world_size=self._world_size,
            rank=self._rank,
            timeout=timeout,
            store=store,
        )
        logger.info(f"[rank{self._rank}] init process group successfully.")
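
    # Rendezvous sketch (illustrative): the inputs come from the standard
    # torchrun-style environment variables, e.g.
    #
    #   RANK=3 WORLD_SIZE=16 MASTER_ADDR=10.0.0.1 MASTER_PORT=29500 ...
    #
    # in which case the group itself is formed on 10.0.0.1:29501
    # (MASTER_PORT + 1, via _get_master_port above).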

    def store_based_barrier(
        self, store: dist.TCPStore, timeout: timedelta = timedelta(minutes=5)
    ) -> None:
        """
        Perform a store-based barrier synchronization across all ranks.

        This barrier uses a TCP store directly rather than a process group,
        allowing all ranks to synchronize regardless of which process group
        they belong to.

        Args:
            store: The TCPStore instance to use for synchronization.
        """
        dist.distributed_c10d._store_based_barrier(
            rank=self._rank,
            store=store,
            group_name="parameter_server_barrier",
            rendezvous_count=self._world_size,
            timeout=timeout,
        )

    def update(
        self,
        checkpoint_name: str,
        req_func: Callable[[list[tuple[str, str]]], None],
        *,
        timeout: timedelta = timedelta(minutes=10),
        ranks: list[int] | None = None,
        master_addr: str | None = None,
        master_port: int | None = None,
    ) -> None:
        """
        Update the checkpoint to the inference engine. This function should be called after gather_metas.

        Args:
            checkpoint_name: The name of the checkpoint.
            req_func: The function used to send the update request to the inference engine.
            ranks: The ranks to update. If not set, a full broadcast is used to update all ranks,
                which is the fastest way to update weights, especially in a colocated architecture.
                If set, p2p is used to update just those ranks; this is flexible for updating a group of ranks,
                which is useful in a disaggregated architecture.
            master_addr: The master address for process group initialization. If not set, will use env MASTER_ADDR.
            master_port: The master port for process group initialization. If not set, will use _get_master_port to get the port, which will use MASTER_PORT+1.
            timeout: The timeout of the barrier operation.
        """
        assert req_func is not None, "req_func is required"
        ranks_group = None
        try:
            master_addr = os.getenv("MASTER_ADDR") or master_addr
            assert master_addr, "master_addr is required"
            if self._auto_pg:
                if not dist.is_initialized():
                    self.init_process_group(
                        timeout=timeout, master_addr=master_addr, master_port=master_port
                    )
                manager_store = dist.distributed_c10d._get_default_store()
            else:
                # HACK: MASTER_PORT+2 for the barrier store if master_port is not provided; _get_master_port() returns MASTER_PORT+1
                # If master_port is provided, use master_port+1 for the barrier store
                manager_store = dist.TCPStore(
                    master_addr,
                    _get_master_port(master_port) + 1,
                    self._world_size,
                    timeout=timeout,
                    is_master=self._rank == 0,
                )
            # if ranks is None or [], use a full broadcast to update all ranks
            ranks_group = dist.new_group(ranks if ranks else None)
            self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
            self.store_based_barrier(manager_store)
        except Exception as e:
            logger.exception(
                f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
            )
            raise
        finally:
            if ranks_group:
                dist.destroy_process_group(ranks_group)
            if self._auto_pg and dist.is_initialized():
                dist.destroy_process_group()
            self.device_manager.device_module.empty_cache()
            logger.info(
                f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
                f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
                f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
            )
|
|
1239
|
+
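    # Illustrative usage sketch (not part of this module): a broadcast update that
    # notifies the inference engine once the ZMQ sockets are bound. `ps`,
    # `notify_engine`, the checkpoint name, and the URL are assumed names/values.
    #
    #   def notify_engine(socket_paths: list[tuple[str, str]]) -> None:
    #       request_inference_to_update(
    #           "http://127.0.0.1:8000", dict(socket_paths), timeout=300.0, uds=None
    #       )
    #
    #   ps.gather_metas("ckpt-0")
    #   ps.update("ckpt-0", notify_engine)  # ranks=None -> full broadcast
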
    def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
        def zmq_handle(device_uuid: str) -> str:
            return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"

        socket_paths = [(uid, zmq_handle(uid)) for uid in self._global_device_uuids]
        socket = self._zmq_ctx.socket(zmq.REQ)
        socket.bind(zmq_handle(self._device_uuid))
        self._zmq_addr_counter += 1
        return socket, socket_paths

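    # Sketch of the receiving side (an assumption about the worker process, not
    # code from this module): since the parameter server binds a REQ socket per
    # device on a Linux abstract-namespace IPC address, the inference worker is
    # expected to connect a matching REP socket and reply to each message.
    #
    #   ctx = zmq.Context()
    #   sock = ctx.socket(zmq.REP)
    #   sock.connect("ipc://@checkpoint-engine-<device_uuid>-0.sock")
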
    def _detect_bucket_size(
        self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
    ) -> tuple[int, bool]:
        GiB = 1 << 30  # noqa: N806
        # auto detect bucket size
        tensor = torch.tensor(
            [
                # proportion of the current device's free memory bytes
                int(
                    float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
                ),
                # negate the value so the allreduce MIN below doubles as a MAX,
                # yielding the largest zmq_addr_counter across all ranks
                -self._zmq_addr_counter,
            ],
            dtype=torch.int64,
            device=self.device_manager.device_type,
        )
        dist.all_reduce(tensor, op=dist.ReduceOp.MIN, group=ranks_group)
        tensor = tensor.cpu()
        free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
        max_tensor_bytes = 0
        for items in self._current_global_parameter_metas.values():
            for metas_list in items.memory_buffer_metas_list:
                for meta in metas_list.metas:
                    max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
        free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
        if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
            self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
            # the h2d buffer lets all ranks run their host-to-device copies in
            # parallel, at the cost of allocating extra GPU memory for the buffer
            free_bytes = free_bytes_divided_3
        else:
            # if memory is not sufficient, fall back to disable_h2d_buffer mode:
            # bandwidth is then limited by the h2d throughput of a single machine,
            # but GPU memory is saved
            self._logger_rank0(
                f"[rank{self._rank}] disable h2d buffer when max_tensor_bytes {max_tensor_bytes} is larger than free_bytes {free_bytes} // 3"
            )
            free_bytes = free_bytes // (2 * _ALIGN_SIZE) * _ALIGN_SIZE
            assert max_tensor_bytes <= free_bytes, (
                f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
            )
            disable_h2d_buffer = True
        max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
        bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
        logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
        return bucket_size, disable_h2d_buffer

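    # Worked example (numbers assumed for illustration): with 30 GiB free after
    # applying _mem_fraction, the h2d-buffer path caps the bucket at ~10 GiB
    # (free_bytes // 3, rounded down to _ALIGN_SIZE), leaving room for the h2d
    # buffer plus the double-buffered broadcast buffer allocated later. With
    # PS_MAX_BUCKET_SIZE_GB=8, the result is
    # min(max(8 GiB, max_tensor_bytes), 10 GiB) = 8 GiB for typical tensor sizes.
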
    def _copy_to_buffer(
        self,
        checkpoint_name: str,
        bucket: H2DBucket,
        buffer: torch.Tensor,
        owner_rank: int | None = None,
    ):
        # Gather the bucket's ranges into `buffer`: from the local memory pool when
        # owner_rank is None, otherwise via a batched p2p read from owner_rank.
        offset = 0
        if owner_rank is not None:
            buf_ptrs, remote_ptrs, lens = [], [], []
            ptr_base = buffer.data_ptr()
            target_addr, ptrs = self._get_addr_ptrs(owner_rank)
        for b in bucket.ranges:
            assert offset + b.size <= bucket.size, (
                f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
            )
            if owner_rank is not None:
                buf_ptrs.append(ptr_base + offset)
                remote_ptrs.append(ptrs[b.idx][0] + b.offset)
                lens.append(b.size)
            else:
                pool = self._get_memory_pool(checkpoint_name)[b.idx]
                buffer[offset : offset + b.size].data.copy_(
                    pool.buffer[b.offset : b.offset + b.size],
                    non_blocking=True,
                )
            offset += b.size
        assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
        if owner_rank is not None:
            self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
        self.device_manager.device_module.synchronize()

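    # Pointer-arithmetic sketch (values assumed for illustration): a range with
    # idx=1, offset=4096, size=8192 reads 8192 bytes from remote address
    # ptrs[1][0] + 4096 into buffer.data_ptr() + offset, so the bucket is
    # reassembled contiguously on the receiver even when its pieces are scattered
    # across the owner's memory pools.
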
    def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
        addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr
        metas_list = self._current_global_parameter_metas[owner_rank].memory_buffer_metas_list
        return addr, [(metas.ptr, metas.size) for metas in metas_list]

    def _register_parameters_to_p2p_store(self, checkpoint_name: str):
        assert self._p2p_store is not None, "p2p store is not initialized"
        pool = self._get_memory_pool(checkpoint_name)
        if len(pool) == 0:
            return
        named_tensors, tensor_ptrs = {}, []
        register_name = (
            checkpoint_name
            if checkpoint_name != self._current_shared_memory_pool_user
            else self.shared_memory_pool_name
        )
        for idx, memory_buffer in enumerate(pool):
            named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
            tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
        self._p2p_store.register_named_tensors(named_tensors)

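    # Naming sketch (assumed values): a checkpoint "ckpt-0" backed by two memory
    # pools registers the tensors "memory_pool_ckpt-0_0" and "memory_pool_ckpt-0_1".
    # If the checkpoint currently owns the shared pool, shared_memory_pool_name is
    # substituted instead, so register and unregister below resolve to the same keys.
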
    def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
        assert self._p2p_store is not None, "p2p store is not initialized"
        pool = self._get_memory_pool(checkpoint_name)
        if len(pool) == 0:
            return 0
        unregister_name = (
            checkpoint_name
            if checkpoint_name != self._current_shared_memory_pool_user
            else self.shared_memory_pool_name
        )
        return self._p2p_store.unregister_named_tensors(
            [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
        )

    def _update_per_bucket(
        self,
        checkpoint_name: str,
        req_func: Callable[[list[tuple[str, str]]], None],
        ranks_group: dist.ProcessGroup,
        ranks: list[int] | None = None,
    ):
        assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
        assert dist.is_initialized(), "process group is not initialized"

        # if ranks is None or [], use a full broadcast to update all ranks
        if not ranks:
            logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
        # if ranks is set, use p2p to update only those ranks
        else:
            assert self._p2p_store is not None, "p2p store is not initialized"
            assert ranks, "ranks should be set"

            need_update = self._rank in ranks
            logger.info(
                f"[rank{self._rank}] update checkpoint {checkpoint_name} p2p, {need_update=} with {ranks=}, "
                f"gpu_count {self._gpu_count}, world_size {self._world_size}"
            )

            if not need_update:
                return
        # execute a barrier first to avoid a subsequent device OOM
        dist.barrier(group=ranks_group)

        bucket_size, disable_h2d_buffer = self._detect_bucket_size(ranks_group)
        buckets = _gen_h2d_buckets(
            self._current_global_parameter_metas,
            bucket_size,
            self._local_rdma_devices,
            self._remote_rdma_devices,
            ranks,
        )

        h2d_buffer: torch.Tensor | None = (
            None
            if disable_h2d_buffer
            else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
        )
        # the p2p store needs to register h2d_buffer so that other ranks can read it
        if ranks:
            h2d_buffer_name = "__h2d_buffer__"
            if h2d_buffer is not None and self._p2p_store is not None:
                self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
        receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
        for receiver_rank, owner_rank, bucket in buckets:
            if receiver_rank != self._rank:
                continue
            receiver_rank_buckets.append((owner_rank, bucket))

        buffer = torch.empty(
            bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
        )
        handle = reduce_tensor(buffer)

        buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
        max_len = 0
        for receiver_rank, _, bucket in buckets:
            buckets_by_receiver_rank[receiver_rank].append(bucket)
            if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
                max_len = len(buckets_by_receiver_rank[receiver_rank])

        socket, socket_paths = self._bind_zmq_socket()
        req_thread = threading.Thread(
            target=req_func,
            args=(socket_paths,),
        )
        req_thread.start()
        socket.send_pyobj(handle)

        gidx = 0
        ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
        try:
            for i in range(max_len):
                if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
                    self._copy_to_buffer(
                        checkpoint_name,
                        receiver_rank_buckets[i][1],
                        h2d_buffer,
                        receiver_rank_buckets[i][0] if ranks else None,
                    )
                for receiver_rank, _buckets in buckets_by_receiver_rank.items():
                    if i >= len(_buckets):
                        continue
                    bucket = _buckets[i]
                    alloc, reserved = (
                        self.device_manager.device_module.memory_allocated() / 1024 / 1024,
                        self.device_manager.device_module.memory_reserved() / 1024 / 1024,
                    )
                    self._logger_rank0(
                        f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
                        f"Current device allocated {alloc:.2f} MB, "
                        f"reserved {reserved:.2f} MB."
                    )
                    start = gidx % 2 * bucket_size
                    buffer_b: torch.Tensor = buffer[start : start + bucket.size]
                    if receiver_rank == self._rank:
                        if disable_h2d_buffer:
                            self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
                        else:
                            buffer_b.data.copy_(h2d_buffer[: bucket.size])
                    dist.broadcast(buffer_b, src=receiver_rank, group=ranks_group)
                    resp = socket.recv()
                    if resp != b"":
                        msg = resp.decode("utf-8")
                        logger.error(
                            f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
                        )
                        ret_code.fill_(1)
                    dist.all_reduce(ret_code, op=dist.ReduceOp.SUM, group=ranks_group)
                    self.device_manager.device_module.synchronize()
                    if ret_code.item() != 0:
                        # quit early if any rank failed
                        socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
                        raise RuntimeError("Failed to update weights due to remote errors")
                    socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
                    gidx += 1

            socket.recv()
            socket.send_pyobj(None)
            socket.recv()
        finally:
            req_thread.join()
            dist.barrier(group=ranks_group)
            socket.close()
            if ranks and h2d_buffer is not None:
                self._p2p_store.unregister_named_tensors([h2d_buffer_name])

            self.device_manager.device_module.empty_cache()


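# Wire-protocol sketch of the REQ/REP exchange in _update_per_bucket above; the
# worker-side behavior is an assumption inferred from this module's send/recv
# pattern, not code from this file:
#
#   PS (REQ)                                worker (REP)
#   send_pyobj(handle)           ->         rebuild the shared device buffer
#   per bucket:
#     recv()                     <-         b"" ack (or an error string)
#     send_pyobj(named tensors)  ->         copy weights out of the bucket
#   recv()                       <-         ack for the last bucket
#   send_pyobj(None)             ->         signals the update is complete
#   recv()                       <-         final ack
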
def _init_api(ps: ParameterServer) -> Any:
    import fastapi
    from fastapi import Request
    from fastapi.responses import JSONResponse, Response

    app = fastapi.FastAPI()

    class RegisterRequest(BaseModel):
        files: list[str]

    class UpdateRequest(BaseModel):
        ranks: list[int] = []
        update_url: str | None = None
        inference_group_ranks: list[int] = []
        timeout: float = 300.0
        uds: str | None = None

    def wrap_exception(func: Callable[[], None]) -> Response:
        try:
            func()
        except Exception as e:  # noqa: BLE001
            logger.exception(f"wrap exception {func} failed")
            return JSONResponse(content=str(e), status_code=500)
        return Response(status_code=200)

    @app.post("/v1/checkpoints/{checkpoint_name}/files")
    async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
        return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))

    @app.delete("/v1/checkpoints/{checkpoint_name}")
    async def unregister_checkpoint(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))

    @app.get("/v1/healthz")
    async def healthz() -> Response:
        return Response(status_code=200)

    @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
    async def gather_metas(checkpoint_name: str) -> Response:
        return wrap_exception(lambda: ps.gather_metas(checkpoint_name))

    @app.post("/v1/checkpoints/{checkpoint_name}/update")
    async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
        def update_func(socket_paths: list[tuple[str, str]]):
            if req.update_url is None:
                return
            if req.inference_group_ranks:
                socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
            request_inference_to_update(
                req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
            )

        return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))

    return app


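# Illustrative HTTP walkthrough (assumed socket path, checkpoint name, and file
# paths), driving the endpoints defined above over the unix domain socket served
# by run_from_cli():
#
#   curl --unix-socket /tmp/ps.sock -X POST \
#     http://localhost/v1/checkpoints/ckpt-0/files \
#     -H 'Content-Type: application/json' -d '{"files": ["/data/model.safetensors"]}'
#   curl --unix-socket /tmp/ps.sock -X POST \
#     http://localhost/v1/checkpoints/ckpt-0/gather-metas
#   curl --unix-socket /tmp/ps.sock -X POST \
#     http://localhost/v1/checkpoints/ckpt-0/update \
#     -H 'Content-Type: application/json' -d '{"ranks": []}'
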
@logger.catch(reraise=True)
def run_from_cli():
    import uvicorn

    parser = argparse.ArgumentParser(description="Parameter Server")
    parser.add_argument("--uds", type=str, help="unix domain socket path to serve the HTTP API on")

    args = parser.parse_args()
    logger.info(
        f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
    )

    assert args.uds and len(args.uds) > 0, args.uds
    ps = ParameterServer(auto_pg=True)
    uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)


if __name__ == "__main__":
    run_from_cli()
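
# Launch sketch (assumed values): one parameter-server process per node, with the
# rendezvous address exported before start. MASTER_ADDR/MASTER_PORT are read by
# this module; the RANK/WORLD_SIZE variables are an assumption about how
# ParameterServer discovers its position in the job.
#
#   MASTER_ADDR=10.0.0.1 MASTER_PORT=29500 RANK=0 WORLD_SIZE=8 \
#     python -m checkpoint_engine.ps --uds /tmp/ps.sock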