checkpoint-engine 0.3.0rc0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpoint_engine/ps.py CHANGED
@@ -1,142 +1,33 @@
- import argparse
- import concurrent.futures
  import ctypes
- import json
  import os
- import pickle
- import random
  import threading
- import time
  from collections import defaultdict
  from collections.abc import Callable
  from datetime import timedelta
- from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
+ from typing import TYPE_CHECKING

- import httpx
- import numpy as np
  import torch
  import torch.distributed as dist
  import zmq
  from loguru import logger
- from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
- from safetensors.torch import _getdtype, safe_open
  from torch.multiprocessing.reductions import reduce_tensor

+ from checkpoint_engine.data_types import (
+     BucketRange,
+     DataToGather,
+     H2DBucket,
+     MemoryBuffer,
+     MemoryBufferMetaList,
+     MemoryBufferMetas,
+     ParameterMeta,
+ )
  from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
+ from checkpoint_engine.p2p_store import P2PStore
+ from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint


  if TYPE_CHECKING:
-     from typing import TypeVar
-
-     from typing_extensions import TypedDict
-
-     class FileMeta(TypedDict):
-         key: str  # parameter name
-         dtype: torch.dtype
-         shape: torch.Size
-         type: type
-         tp_concat_dim: int
-
-     T = TypeVar("T")
-
-
- def _dt_validate(value: Any) -> torch.dtype:
-     if isinstance(value, str):
-         if not value.startswith("torch."):
-             raise ValueError(f"dtype {value} should start with torch.")
-         try:
-             value = getattr(torch, value.split(".")[1])
-         except AttributeError as e:
-             raise ValueError(f"unknown dtype: {value}") from e
-     if not isinstance(value, torch.dtype):
-         raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
-     return value
-
-
- _TorchDtype = Annotated[
-     torch.dtype,
-     PlainValidator(_dt_validate),
-     PlainSerializer(lambda x: str(x), return_type=str),
-     WithJsonSchema({"type": "string"}, mode="serialization"),
- ]
-
-
- def _size_validate(value: Any) -> torch.Size:
-     if isinstance(value, list | tuple):
-         return torch.Size(value)
-     if not isinstance(value, torch.Size):
-         raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
-     return value
-
-
- _TorchSize = Annotated[
-     torch.Size,
-     PlainValidator(_size_validate),
-     PlainSerializer(lambda x: tuple(x), return_type=tuple),
-     WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
- ]
-
-
- def _tensor_validate(value: Any) -> torch.Tensor:
-     if isinstance(value, torch.Tensor):
-         return value
-     raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
-
-
- _TorchTensor = Annotated[
-     torch.Tensor,
-     PlainValidator(_tensor_validate),
- ]
-
-
- class ParameterMeta(BaseModel):
-     name: str
-     dtype: _TorchDtype
-     shape: _TorchSize
-     aligned_size: int
-
-
- class BucketRange(NamedTuple):
-     idx: int  # bucket_idx of MemoryBucket in memory_pool
-     offset: int
-     size: int
-
-
- class H2DBucket(BaseModel):
-     size: int
-     ranges: list[BucketRange]
-     items: list[ParameterMeta]
-
-
- class MemoryBufferMetas(BaseModel):
-     metas: list[ParameterMeta]
-     ptr: int
-     size: int
-
-
- class MemoryBuffer(BaseModel):
-     buffer: _TorchTensor
-     size: int
-     metas: list[ParameterMeta]
-
-
- class MemoryBufferMetaList(BaseModel):
-     p2p_store_addr: str | None
-     memory_buffer_metas_list: list[MemoryBufferMetas]
-     rdma_device: str
-
-
- class DataToGather(MemoryBufferMetaList):
-     host_ip: str
-     device_uuid: str
-
-
- # 256 bytes alignment when flatten torch tensors to uint8 buffer
- _ALIGN_SIZE = 256
-
-
- def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
-     return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
+     from checkpoint_engine.data_types import T


  def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
@@ -155,107 +46,6 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
      return ret


- def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
-     def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
-         ret = {}
-         with safe_open(fn, framework="pt") as f:
-             for name in f.keys():  # noqa: SIM118
-                 weight = f.get_tensor(name)
-                 meta = {
-                     "key": name,
-                     "dtype": weight.dtype,
-                     "shape": weight.shape,
-                     "type": type(weight),
-                     "tp_concat_dim": -1,  # safetensors does not support tp_concat_dim
-                 }
-                 ret[name] = (meta, weight)
-         return ret
-
-     # deprecated, will be removed in the future
-     def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
-         """load *.np file and return memmap and related tensor meta"""
-
-         def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
-             start = fin.tell()
-             major, minor = np.lib.format.read_magic(fin)
-             if major == 1 and minor == 0:
-                 read_header_fn = np.lib.format.read_array_header_1_0
-             elif major == 2 and minor == 0:
-                 read_header_fn = np.lib.format.read_array_header_2_0
-             else:
-                 raise ValueError(
-                     f"unknown version {major}.{minor} when parsing npy header from {fn}"
-                 )
-             shape, is_fortran, dtype = read_header_fn(fin)
-             return {
-                 "shape": shape,
-                 "is_fortran": is_fortran,
-                 "dtype": dtype,
-                 "header_length": fin.tell() - start,
-             }
-
-         meta_fn = fn + ".meta"
-         with open(meta_fn, "rb") as fin:
-             meta_lst = pickle.load(fin)
-
-         tensors = []
-         offset = 0
-         with open(fn, "rb") as fin:
-             fin.seek(0, os.SEEK_END)
-             filesize = fin.tell()
-             fin.seek(0)
-             while fin.tell() < filesize:
-                 tensor_meta = parse_npy_header(fin)
-                 tensor = np.memmap(
-                     fn,
-                     dtype=tensor_meta["dtype"],
-                     mode="c",
-                     offset=offset + tensor_meta["header_length"],
-                     shape=tensor_meta["shape"],
-                 )
-                 offset += tensor_meta["header_length"] + tensor.nbytes
-                 fin.seek(offset)
-                 tensors.append(tensor)
-
-         assert len(meta_lst) == len(tensors)
-         ret = {}
-         for meta, tensor in zip(meta_lst, tensors):
-             if meta["type"] == torch.Tensor:
-                 tensor = torch.from_numpy(tensor)
-             tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
-             ret[meta["key"]] = (meta, tensor)
-         return ret
-
-     tp_rank = 0
-     if file_path.endswith(".npy"):
-         logger.warning("numpy model file is deprecated, will be removed in the future")
-         filename_split = os.path.basename(file_path).split(".")
-         # if using numpy and want to specify tp rank
-         # file should be in model.{layer}.{tp}[.{ep}].npy format
-         tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
-         ret = _fast_np_load(file_path)
-     elif file_path.endswith(".safetensors"):
-         ret = _safetensors_load(file_path)
-     else:
-         raise ValueError(f"unsupported file format: {file_path}")
-     return tp_rank, ret
-
-
- def _concat_tp_weights(
-     tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
- ) -> torch.Tensor:
-     """Concat tp weights with meta info.
-     If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights.
-     Else we will cat weights in concat_dim.
-     """
-     if tp_concat_dim == -1:
-         return tp_weights[0]
-     assert tp_size == len(tp_weights)
-     if len(tp_weights) == 1:
-         return tp_weights[0]
-     return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
-
-
  def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
      try:
          if device_manager.device_type == "npu":
@@ -266,420 +56,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None
          raise ValueError(f"fail to get physical gpu id {device_index}") from e


- def _ibv_get_device_list() -> list[str]:
-     lib = ctypes.CDLL("libibverbs.so.1")
-     lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
-     lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **
-
-     lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
-     lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
-     lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *
-
-     num = ctypes.c_int()
-     dev_array = lib.ibv_get_device_list(ctypes.byref(num))
-     if not dev_array or num.value <= 0:
-         return []
-
-     devices = []
-     for i in range(num.value):
-         dev_ptr = dev_array[i]  # struct ibv_device *
-         name = lib.ibv_get_device_name(dev_ptr)  # const char *
-         devices.append(name.decode())
-     lib.ibv_free_device_list(dev_array)
-     return devices
-
-
- def _get_rdma_devices() -> list[str]:
-     """
-     use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
-     """
-     devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
-     if devices_str:
-         return devices_str.split(",")
-     # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
-     hca = os.getenv("NCCL_IB_HCA", None)
-     return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
-
-
- def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
-     """
-     implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
-     """
-     if not devices:
-         raise RuntimeError("no rdma devices found")
-     try:
-         assert len(devices) <= gpu_count, (
-             f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
-         )
-         assert gpu_count % len(devices) == 0, (
-             f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
-         )
-         return devices[local_rank // (gpu_count // len(devices))]
-     except AssertionError:
-         logger.error(
-             "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
-             "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
-             "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
-         )
-         raise
-
-
- def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
-     """
-     The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
-     The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
-
-     The list is comma-separated; port numbers are NOT supported yet.
-     An optional prefix '^' indicates the list is an exclude list.
-     A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
-     Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
-
-     Examples:
-         `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
-         `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
-         `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
-         `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
-     """
-     max_hcas = 32
-     if not value or value.strip() == "":
-         return available_devices[:max_hcas]
-
-     value = value.strip()
-     result = []
-     is_exclude = value.startswith("^")
-     if is_exclude:
-         value = value.removeprefix("^")
-     is_exact_match = value.startswith("=")
-     if is_exact_match:
-         value = value.removeprefix("=")
-
-     device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
-
-     result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
-     if is_exclude:
-         result = [dev for dev in available_devices if dev not in result]
-     if len(result) > max_hcas:
-         result = result[:max_hcas]
-
-     logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
-
-     return result
-
-
- def _resolve_device_specs(
-     device_specs: list[str], is_exact_match: bool, available_devices: list[str]
- ) -> list[str]:
-     devices = set()
-     for spec in device_specs:
-         parts = spec.split(":", 1)
-         device_name = parts[0].strip()
-         # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
-         # port = parts[1].strip() if len(parts) > 1 else None
-         base_devices = (
-             [device_name]
-             if device_name in available_devices
-             else []
-             if is_exact_match
-             else [dev for dev in available_devices if dev.startswith(device_name)]
-         )
-
-         if not base_devices:
-             logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
-             continue
-
-         for base_dev in base_devices:
-             devices.add(base_dev)
-
-     return sorted(devices)
-
-
- def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
-     class TPMeta(BaseModel):
-         concat_dim: int
-         size: int
-
-     parameters: dict[str, torch.Tensor] = {}
-     parameter_metas: dict[str, ParameterMeta] = {}
-     tp_metas: dict[str, TPMeta] = {}
-     parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {}
-     for file in files:
-         tp_rank, ret = _load_checkpoint_file(file)
-         for parameter_name, (meta, weight) in ret.items():
-             if parameter_name not in parameters_with_tp:
-                 parameters_with_tp[parameter_name] = {}
-             parameters_with_tp[parameter_name][tp_rank] = weight
-             if parameter_name not in tp_metas:
-                 tp_metas[parameter_name] = TPMeta(
-                     concat_dim=meta["tp_concat_dim"],
-                     size=1,
-                 )
-             if parameter_name not in parameter_metas:
-                 assert isinstance(meta["dtype"], torch.dtype), (
-                     f"meta {meta} dtype should be torch.dtype"
-                 )
-                 assert isinstance(meta["shape"], torch.Size), (
-                     f"meta {meta} shape should be torch.Size"
-                 )
-                 parameter_metas[parameter_name] = ParameterMeta(
-                     name=parameter_name,
-                     shape=meta["shape"],
-                     dtype=meta["dtype"],
-                     aligned_size=_align_size(meta["dtype"], meta["shape"]),
-                 )
-             tp_meta = tp_metas[parameter_name]
-             if tp_meta.concat_dim != -1:
-                 tp_meta.size = max(tp_meta.size, tp_rank + 1)
-     for name, tp_meta in tp_metas.items():
-         if tp_meta.concat_dim != -1:
-             shape = list(parameter_metas[name].shape)
-             shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
-             parameter_metas[name] = ParameterMeta(
-                 name=name,
-                 shape=torch.Size(shape),
-                 dtype=parameter_metas[name].dtype,
-                 aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
-             )
-         weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
-         # TODO: here concat is serial, which may be slow
-         # but since tp storage is not used in the future
-         # we ignore this performance issue for now
-         parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size)
-     for name, parameter in parameters.items():
-         assert name in parameter_metas, f"parameter {name} not found in parameter_metas"
-         assert parameter_metas[name].shape == parameter.shape, (
-             f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}"
-         )
-         assert parameter_metas[name].dtype == parameter.dtype, (
-             f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}"
-         )
-     return parameters
-
-
- def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
-     def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
-         """
-         safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
-         We load the safetensors file as bytes, then parse the header manually to get parameter metas.
-         The actual tensor data is in the remaining bytes and is naturally aligned.
-         We pin the remaining bytes as the buffer, making pinning faster.
-         """
-
-         def _pin(t: torch.Tensor):
-             """
-             Pin the memory of tensor in-place.
-             See: https://github.com/pytorch/pytorch/issues/32167
-             """
-             cudart = torch.cuda.cudart()
-             r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
-             assert r == 0, f"pin memory error, error code: {r}"
-
-         # TODO: should only support /dev/shm? but we found files in disk also work?
-         size = os.stat(file_path).st_size
-         flag_size = 8
-         t = torch.from_file(file_path, True, size, dtype=torch.uint8)
-         assert t.nbytes > flag_size, (
-             f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
-         )
-         start_pos = (
-             int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
-             + flag_size
-         )
-         header_tensor = t[flag_size:start_pos]
-         header = json.loads(header_tensor.numpy().tobytes())
-         if "__metadata__" in header:
-             header.pop("__metadata__")
-
-         metas: list[ParameterMeta] = []
-         offset = 0
-         try:
-             for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
-                 start, end = meta["data_offsets"]
-                 # safetensors format ensures offsets are aligned
-                 assert offset == start, f"offset {offset} should be equal to start {start}"
-                 metas.append(
-                     ParameterMeta(
-                         name=name,
-                         dtype=_getdtype(meta["dtype"]),
-                         shape=torch.Size(meta["shape"]),
-                         aligned_size=end - start,
-                     )
-                 )
-                 offset = end
-         except Exception as e:
-             logger.error(f"fail to parse safetensors header from {file_path}: {e}")
-             raise
-
-         buffer = t[start_pos:]
-         assert offset == buffer.nbytes, (
-             f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
-         )
-         # Remove the file after successfully loading. This will avoid doubling the memory usage.
-         # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
-         os.remove(file_path)
-         _pin(buffer)
-         logger.info(
-             f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
-         )
-         return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
-
-     memory_buffers: list[MemoryBuffer] = []
-     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-         memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
-     return memory_buffers
-
-
- def _normal_pin_memory(
-     files: list[str],
-     named_tensors: dict[str, torch.Tensor],
-     rank: int | None = None,
-     shared_pin_memory: list[MemoryBuffer] | None = None,
- ) -> list[MemoryBuffer]:
-     parameters = _load_checkpoint(files)
-     if named_tensors:
-         parameters.update(named_tensors)
-     bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
-
-     class MemoryBucket(BaseModel):
-         size: int
-         metas: list[ParameterMeta]
-
-     buckets: list[MemoryBucket] = []
-     buckets.append(MemoryBucket(size=0, metas=[]))
-     for name, tensor in sorted(parameters.items()):
-         size = _align_size(tensor.dtype, tensor.shape)
-         if buckets[-1].size + size > bucket_size:
-             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
-             buckets.append(MemoryBucket(size=0, metas=[]))
-         buckets[-1].metas.append(
-             ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
-         )
-         buckets[-1].size += size
-
-     memory_buffers = [
-         MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
-         for bucket in buckets
-     ]
-
-     def register_pin_memory(
-         idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
-     ) -> tuple[int, torch.Tensor]:
-         if shared_pin_memory:
-             # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
-             # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
-             assert idx < len(shared_pin_memory), (
-                 f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
-             )
-             assert shared_pin_memory[idx].size == size, (
-                 f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
-             )
-             return idx, shared_pin_memory[idx].buffer
-         else:
-             buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-             return idx, buffer
-
-     def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
-         buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
-
-     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-         futures = [
-             executor.submit(
-                 register_pin_memory,
-                 idx,
-                 bucket.size,
-                 shared_pin_memory,
-             )
-             for idx, bucket in enumerate(buckets)
-         ]
-         new_futures = []
-         for future in concurrent.futures.as_completed(futures):
-             idx, buffer = future.result()
-             assert buffer.numel() == buckets[idx].size, (
-                 f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
-             )
-             memory_buffers[idx].buffer = buffer
-             logger.info(
-                 f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
-                 f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
-             )
-             offset = 0
-             for meta in buckets[idx].metas:
-                 name = meta.name
-                 tensor = parameters[name]
-                 size = _align_size(tensor.dtype, tensor.shape)
-                 assert size == _align_size(meta.dtype, meta.shape), (
-                     f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
-                 )
-                 new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
-                 offset += size
-         for future in concurrent.futures.as_completed(new_futures):
-             future.result()
-     return memory_buffers
-
-
- def _register_checkpoint(
-     *,
-     files: list[str],
-     named_tensors: dict[str, torch.Tensor],
-     rank: int | None = None,
-     shared_pin_memory: list[MemoryBuffer] | None = None,
- ) -> list[MemoryBuffer]:
-     logger.info(
-         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-     )
-     if not files and not named_tensors:
-         return []
-     memory_buffers: list[MemoryBuffer] = []
-     files_to_inplace_pin = [
-         file
-         for file in files
-         if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
-     ]
-     files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
-     if files_to_normal_pin or named_tensors:
-         memory_buffers.extend(
-             _normal_pin_memory(
-                 files=files_to_normal_pin,
-                 named_tensors=named_tensors,
-                 rank=rank,
-                 shared_pin_memory=shared_pin_memory,
-             )
-         )
-     if files_to_inplace_pin:
-         memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
-     return memory_buffers
-
-
- def request_inference_to_update(
-     url: str,
-     socket_paths: dict[str, str],
-     timeout: float = 300.0,
-     uds: str | None = None,
- ):
-     """Send an inference update request to inference server via HTTP or Unix socket.
-
-     Args:
-         url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
-         socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
-         timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
-         uds (str, optional): Path to a Unix domain socket. If provided, the request
-             will be sent via the Unix socket instead of HTTP. Defaults to None.
-
-     Raises:
-         httpx.HTTPStatusError: If the response contains an HTTP error status.
-         httpx.RequestError: If there was an issue while making the request.
-     """
-     resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
-         url,
-         json={
-             "method": "update_weights_from_ipc",
-             "args": [socket_paths],
-             "timeout": timeout,
-         },
-         timeout=timeout,
-     )
-     resp.raise_for_status()
-
-
  def _gen_h2d_buckets(
      global_metas: dict[int, MemoryBufferMetaList],
      bucket_size: int,
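
Note: most of what this hunk removes is relocated rather than dropped. `_register_checkpoint` and `_ALIGN_SIZE` are now imported from `checkpoint_engine.pin_memory`, the Pydantic data models removed in the first hunk come from `checkpoint_engine.data_types`, and the `P2PStore` class removed in the next hunk comes from `checkpoint_engine.p2p_store`; where the RDMA-device parsing and `request_inference_to_update` ended up is not visible in this diff. A minimal import sketch limited to the names the new imports above actually show:

    # Sketch: only symbols that ps.py 0.3.1 itself imports; any other contents of
    # these modules are not visible in this diff.
    from checkpoint_engine.data_types import MemoryBuffer, MemoryBufferMetas, ParameterMeta
    from checkpoint_engine.p2p_store import P2PStore
    from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint
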
@@ -782,84 +158,12 @@ def _get_master_port(master_port: int | None = None) -> int:
      if master_port is None:
          # HACK: use MASTER_PORT + 1 as master_port, avoid conflict with torchrun's rendezvous port
          # TODO: check whether master_port is available or use a more elegant way
-         master_port = int(os.getenv("MASTER_PORT")) + 1
+         master_port_str = os.getenv("MASTER_PORT")
+         assert master_port_str, "MASTER_PORT is required if no master_port is provided."
+         master_port = int(master_port_str) + 1
      return master_port


- class P2PStore:
-     def __init__(self, device_manager: DeviceManager):
-         from mooncake.engine import TransferEngine
-
-         self.rank = int(os.getenv("RANK"))
-         gpu_count = device_manager.device_module.device_count()
-         local_rank = self.rank % gpu_count
-         device_type = device_manager.device_type
-         if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
-             self.device = ""
-         else:
-             self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
-         self.ip = get_ip()
-
-         # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
-         retry_count = 8
-         for i in range(retry_count):
-             self.engine = TransferEngine()
-             ret = self.engine.initialize(
-                 self.ip,
-                 "P2PHANDSHAKE",
-                 "ascend_direct" if device_type == "npu" else "rdma",
-                 self.device,
-             )
-             if ret == 0:
-                 break
-             # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
-             sleep_ms = random.randint(500, 2000)
-             logger.warning(
-                 f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
-             )
-             time.sleep(sleep_ms / 1000)
-         else:
-             raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
-         self.port = self.engine.get_rpc_port()
-         self.named_tensors: dict[str, torch.Tensor] = {}
-         logger.info(
-             f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
-         )
-
-     @property
-     def addr(self) -> str:
-         return f"{self.ip}:{self.port}"
-
-     def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
-         buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
-         capacities = [tensor.nbytes for tensor in named_tensors.values()]
-         self.named_tensors.update(named_tensors)
-         for i, name in enumerate(named_tensors.keys()):
-             logger.info(
-                 f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
-             )
-         assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0
-
-     def unregister_named_tensors(self, names: list[str]) -> int:
-         buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
-         assert self.engine.batch_unregister_memory(buffer_addresses) == 0
-         num_unregistered = 0
-         for i, name in enumerate(names):
-             del self.named_tensors[name]
-             logger.info(
-                 f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
-             )
-             num_unregistered += 1
-         return num_unregistered
-
-     def batch_transfer_sync_read(
-         self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
-     ):
-         assert (
-             self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
-         )
-
-
  class ParameterServer:
      shared_memory_pool_name = "__shared_memory_pool__"
 
@@ -868,7 +172,7 @@ class ParameterServer:
          *,
          rank: int | None = None,
          world_size: int | None = None,
-         auto_pg: bool = False,
+         auto_pg: bool = True,
          gpu_count: int | None = None,
          mem_fraction: float | None = None,
      ):
@@ -877,11 +181,11 @@ class ParameterServer:

          Args:
              auto_pg: Whether to automatically initialize the process group.
-                 Notice that if auto_pg is True, will destroy the process group after update.
+                 Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
              mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
          """
-         self._rank = rank or int(os.environ.get("RANK", None))
-         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
+         self._rank = rank or int(os.environ["RANK"])
+         self._world_size = world_size or int(os.environ["WORLD_SIZE"])
          self.device_manager = DeviceManager()
          self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
          self._local_rank = self._rank % self._gpu_count
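
Two constructor changes land in this hunk: `auto_pg` now defaults to `True`, and `RANK`/`WORLD_SIZE` are read with `os.environ[...]`, so a missing variable fails immediately with a `KeyError` instead of the old `int(None)` `TypeError`. A minimal construction sketch, assuming a torchrun-style environment provides `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`:

    from checkpoint_engine.ps import ParameterServer

    ps = ParameterServer()  # auto_pg defaults to True in 0.3.1
    ps_manual = ParameterServer(rank=0, world_size=8, auto_pg=False)  # opt out explicitly
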
@@ -890,7 +194,7 @@ class ParameterServer:
          self._global_device_uuids: list[str] = []
          self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
          self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
-         self._mem_fraction = mem_fraction or 0.9
+         self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))

          assert self._rank is not None and self._rank >= 0, self._rank
          assert self._world_size and self._world_size > 0, self._world_size
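
`mem_fraction` now falls back to the `PS_MEM_FRACTION` environment variable (default "0.9") when the argument is omitted; an explicit argument still wins because of the `or`. An illustrative sketch:

    import os

    os.environ["PS_MEM_FRACTION"] = "0.8"
    ParameterServer()                    # _mem_fraction becomes 0.8
    ParameterServer(mem_fraction=0.95)   # explicit argument takes precedence
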
@@ -959,11 +263,12 @@ class ParameterServer:
          files: list[str] | None = None,
          named_tensors: dict[str, torch.Tensor] | None = None,
          use_shared_memory_pool: bool = False,
+         use_inplace_pin_memory: bool = True,
      ) -> None:
          """
          Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
-         Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
-         Please make sure to copy the files to disks if you need to keep them.
+         Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
+         Please make sure to copy the files to disks if you need to keep them. NPU does not support inplace pin memory.

          Args:
              checkpoint_name: The name of the checkpoint.
@@ -974,7 +279,14 @@ class ParameterServer:
                  cannot accommodate checkpoints with different memory requirements.
                  To free the actual memory of the shared pool or to modify its shape,
                  please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
+             use_inplace_pin_memory: If True (default), allows inplace pin memory for /dev/shm/ safetensors files.
+                 This option is ignored when ``use_shared_memory_pool`` is True.
          """
+         if self.device_manager.device_type != "cuda" and use_inplace_pin_memory:
+             logger.warning(
+                 f"[rank{self._rank}] Only cuda devices support in-place pin memory, set use_inplace_pin_memory to False"
+             )
+             use_inplace_pin_memory = False
          try:
              if use_shared_memory_pool:
                  logger.info(
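
A hedged usage sketch of the new flag (file paths are hypothetical; `ps` is a `ParameterServer` constructed as sketched earlier). Keep the default only if the `/dev/shm` copies are disposable, since in-place pinning removes them after registration; on non-CUDA devices the check added above forces the flag to False.

    ps.register_checkpoint(
        "ckpt-shm",
        files=["/dev/shm/model-00001-of-00002.safetensors"],  # hypothetical path
        use_inplace_pin_memory=True,   # default: pinned in place, file removed afterwards
    )
    ps.register_checkpoint(
        "ckpt-shm-keep",
        files=["/dev/shm/model-00002-of-00002.safetensors"],  # hypothetical path
        use_inplace_pin_memory=False,  # stage through freshly allocated pinned buffers instead
    )
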
@@ -993,6 +305,7 @@ class ParameterServer:
                      named_tensors=named_tensors or {},
                      rank=self._rank,
                      shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                     inplace_pin=False,  # inplace pin memory is not compatible with shared memory pool
                  )
                  self._current_shared_memory_pool_user = checkpoint_name
                  if self._p2p_store is not None and _is_first_time:
@@ -1002,7 +315,10 @@ class ParameterServer:
                      f"checkpoint {checkpoint_name} already registered"
                  )
                  self._memory_pool[checkpoint_name] = _register_checkpoint(
-                     files=files or [], named_tensors=named_tensors or {}, rank=self._rank
+                     files=files or [],
+                     named_tensors=named_tensors or {},
+                     rank=self._rank,
+                     inplace_pin=use_inplace_pin_memory,
                  )
                  if self._p2p_store is not None:
                      self._register_parameters_to_p2p_store(checkpoint_name)
@@ -1048,6 +364,46 @@ class ParameterServer:
              del self._memory_pool[self.shared_memory_pool_name]
              self._memory_pool[self.shared_memory_pool_name] = []
          else:
+
+             def _unpin(t: torch.Tensor):
+                 """
+                 Un-pin the pinned memory.
+                 """
+                 p_flags = ctypes.c_uint()
+                 try:
+                     libc = ctypes.CDLL(None)  # get all symbols from the current process
+                     cuda_host_get_flags = libc.cudaHostGetFlags
+                     cuda_host_get_flags.argtypes = [ctypes.POINTER(ctypes.c_uint), ctypes.c_void_p]
+                     cuda_host_get_flags.restype = ctypes.c_int
+                 except AttributeError:
+                     logger.error("cudaHostGetFlags not found in libc, cannot unpin memory manually")
+                     raise
+                 r = cuda_host_get_flags(ctypes.byref(p_flags), ctypes.c_void_p(t.data_ptr()))
+                 assert r == 0, f"get pin flags error, error code: {r}"
+                 # p_flags value meaning from cuda/include/driver_types.h
+                 # cudaHostRegisterDefault 0x00 /**< Default host memory registration flag */
+                 # cudaHostRegisterPortable 0x01 /**< Pinned memory accessible by all CUDA contexts */
+                 # cudaHostRegisterMapped 0x02 /**< Map registered memory into device space */
+                 # cudaHostRegisterIoMemory 0x04 /**< Memory-mapped I/O space */
+                 # cudaHostRegisterReadOnly 0x08 /**< Memory-mapped read-only */
+                 assert p_flags.value == 0x02, (
+                     f"pin memory flag error, expected: 0x02 (cudaHostRegisterMapped), got flag: {p_flags.value}"
+                 )
+                 cudart = torch.cuda.cudart()
+                 r = cudart.cudaHostUnregister(t.data_ptr())
+                 assert r == 0, f"unpin memory error, error code: {r}"
+
+             # if the checkpoint is pinned by cudaHostRegister manually, we need to unpin it manually
+             try:
+                 for memory_buffer in self._memory_pool.get(checkpoint_name, []):
+                     if memory_buffer.manually_pinned:
+                         _unpin(memory_buffer.buffer)
+             except Exception as e:
+                 logger.error(
+                     f"[rank{self._rank}] fail to unpin memory for checkpoint {checkpoint_name}: {e}"
+                 )
+                 raise
+             # we won't delete the memory pool if unpinning fails.
              del self._memory_pool[checkpoint_name]
              # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
              # this works by using torch>=2.5.0
@@ -1183,6 +539,8 @@ class ParameterServer:
      ) -> None:
          """
          Update the checkpoint to inference engine. This function should be called after gather_metas.
+         Warning: if _auto_pg is False when initializing ParameterServer, please make sure ALL ranks in the WORLD_SIZE call `update` function,
+         otherwise, it will hang.

          Args:
              checkpoint_name: The name of the checkpoint.
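
The warning added above matters mainly when the process group is managed externally (`auto_pg=False`). A sketch of the documented call order, using the method names that appear elsewhere in this file; the body of the request callback is up to the caller:

    def req_func(socket_paths: list[tuple[str, str]]) -> None:
        # e.g. hand the per-device IPC socket paths to the inference engine
        ...

    ps.register_checkpoint("ckpt", files=checkpoint_files)  # checkpoint_files: list[str]
    ps.gather_metas("ckpt")
    ps.update("ckpt", req_func, ranks=None)  # ranks=None/[] broadcasts to all ranks; with
                                             # auto_pg=False every rank must reach this call
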
@@ -1217,7 +575,7 @@ class ParameterServer:
                  is_master=self._rank == 0,
              )
              # if ranks is None or [], it will use fully broadcast to update to all ranks
-             ranks_group = dist.new_group(ranks if ranks else None)
+             ranks_group = dist.new_group(ranks) if ranks else None
              self._update_per_bucket(checkpoint_name, req_func, ranks_group, ranks)
              self.store_based_barrier(manager_store)
          except Exception as e:
@@ -1248,7 +606,7 @@ class ParameterServer:
          return socket, socket_paths

      def _detect_bucket_size(
-         self, ranks_group: dist.ProcessGroup, *, disable_h2d_buffer: bool = False
+         self, ranks_group: dist.ProcessGroup | None, *, disable_h2d_buffer: bool = False
      ) -> tuple[int, bool]:
          GiB = 1 << 30  # noqa: N806
          # auto detect bucket size
@@ -1291,7 +649,7 @@ class ParameterServer:
                  f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
              )
              disable_h2d_buffer = True
-         max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
+         max_bytes = int(float(os.getenv("PS_MAX_BUCKET_SIZE_GB", "8")) * GiB)
          bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
          logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
          return bucket_size, disable_h2d_buffer
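
`PS_MAX_BUCKET_SIZE_GB` is now parsed as a float before the GiB conversion, so fractional bucket sizes work. A small arithmetic sketch of the difference:

    GiB = 1 << 30
    # 0.3.0rc0: int("7.5") * GiB raised ValueError, so only whole-GiB values were accepted
    # 0.3.1: the value is parsed as a float first, then truncated to whole bytes
    max_bytes = int(float("7.5") * GiB)  # 8053063680
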
@@ -1367,7 +725,7 @@ class ParameterServer:
          self,
          checkpoint_name: str,
          req_func: Callable[[list[tuple[str, str]]], None],
-         ranks_group: dist.ProcessGroup,
+         ranks_group: dist.ProcessGroup | None,
          ranks: list[int] | None = None,
      ):
          assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
@@ -1498,79 +856,8 @@ class ParameterServer:
          self.device_manager.device_module.empty_cache()


- def _init_api(ps: ParameterServer) -> Any:
-     import fastapi
-     from fastapi import Request
-     from fastapi.responses import JSONResponse, Response
-
-     app = fastapi.FastAPI()
-
-     class RegisterRequest(BaseModel):
-         files: list[str]
-
-     class UpdateRequest(BaseModel):
-         ranks: list[int] = []
-         update_url: str | None = None
-         inference_group_ranks: list[int] = []
-         timeout: float = 300.0
-         uds: str | None = None
-
-     def wrap_exception(func: Callable[[], None]) -> Response:
-         try:
-             func()
-         except Exception as e:  # noqa: BLE001
-             logger.exception(f"wrap exception {func} failed")
-             return JSONResponse(content=str(e), status_code=500)
-         return Response(status_code=200)
-
-     @app.post("/v1/checkpoints/{checkpoint_name}/files")
-     async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
-         return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
-
-     @app.delete("/v1/checkpoints/{checkpoint_name}")
-     async def unregister_checkpoint(checkpoint_name: str) -> Response:
-         return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
-
-     @app.get("/v1/healthz")
-     async def healthz() -> Response:
-         return Response(status_code=200)
-
-     @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
-     async def gather_metas(checkpoint_name: str) -> Response:
-         return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
-
-     @app.post("/v1/checkpoints/{checkpoint_name}/update")
-     async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
-         def update_func(socket_paths: list[tuple[str, str]]):
-             if req.update_url is None:
-                 return
-             if req.inference_group_ranks:
-                 socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
-             request_inference_to_update(
-                 req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
-             )
-
-         return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
-
-     return app
-
-
- @logger.catch(reraise=True)
- def run_from_cli():
-     import uvicorn
-
-     parser = argparse.ArgumentParser(description="Parameter Server")
-     parser.add_argument("--uds", type=str)
-
-     args = parser.parse_args()
-     logger.info(
-         f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
-     )
-
-     assert args.uds and len(args.uds) > 0, args.uds
-     ps = ParameterServer(auto_pg=True)
-     uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
-
-
+ # we need this CLI entry point for compatibility with former versions
  if __name__ == "__main__":
+     from .__main__ import run_from_cli
+
      run_from_cli()