checkpoint-engine 0.3.0rc1__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
checkpoint_engine/ps.py CHANGED
@@ -1,143 +1,33 @@
- import argparse
- import concurrent.futures
  import ctypes
- import json
  import os
- import pickle
- import random
  import threading
- import time
  from collections import defaultdict
  from collections.abc import Callable
  from datetime import timedelta
- from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
+ from typing import TYPE_CHECKING

- import httpx
- import numpy as np
  import torch
  import torch.distributed as dist
  import zmq
  from loguru import logger
- from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
- from safetensors.torch import _getdtype, safe_open
  from torch.multiprocessing.reductions import reduce_tensor

+ from checkpoint_engine.data_types import (
+     BucketRange,
+     DataToGather,
+     H2DBucket,
+     MemoryBuffer,
+     MemoryBufferMetaList,
+     MemoryBufferMetas,
+     ParameterMeta,
+ )
  from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
+ from checkpoint_engine.p2p_store import P2PStore
+ from checkpoint_engine.pin_memory import _ALIGN_SIZE, _register_checkpoint


  if TYPE_CHECKING:
-     from typing import TypeVar
-
-     from typing_extensions import TypedDict
-
-     class FileMeta(TypedDict):
-         key: str # parameter name
-         dtype: torch.dtype
-         shape: torch.Size
-         type: type
-         tp_concat_dim: int
-
-     T = TypeVar("T")
-
-
- def _dt_validate(value: Any) -> torch.dtype:
-     if isinstance(value, str):
-         if not value.startswith("torch."):
-             raise ValueError(f"dtype {value} should start with torch.")
-         try:
-             value = getattr(torch, value.split(".")[1])
-         except AttributeError as e:
-             raise ValueError(f"unknown dtype: {value}") from e
-     if not isinstance(value, torch.dtype):
-         raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
-     return value
-
-
- _TorchDtype = Annotated[
-     torch.dtype,
-     PlainValidator(_dt_validate),
-     PlainSerializer(lambda x: str(x), return_type=str),
-     WithJsonSchema({"type": "string"}, mode="serialization"),
- ]
-
-
- def _size_validate(value: Any) -> torch.Size:
-     if isinstance(value, list | tuple):
-         return torch.Size(value)
-     if not isinstance(value, torch.Size):
-         raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
-     return value
-
-
- _TorchSize = Annotated[
-     torch.Size,
-     PlainValidator(_size_validate),
-     PlainSerializer(lambda x: tuple(x), return_type=tuple),
-     WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
- ]
-
-
- def _tensor_validate(value: Any) -> torch.Tensor:
-     if isinstance(value, torch.Tensor):
-         return value
-     raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
-
-
- _TorchTensor = Annotated[
-     torch.Tensor,
-     PlainValidator(_tensor_validate),
- ]
-
-
- class ParameterMeta(BaseModel):
-     name: str
-     dtype: _TorchDtype
-     shape: _TorchSize
-     aligned_size: int
-
-
- class BucketRange(NamedTuple):
-     idx: int # bucket_idx of MemoryBucket in memory_pool
-     offset: int
-     size: int
-
-
- class H2DBucket(BaseModel):
-     size: int
-     ranges: list[BucketRange]
-     items: list[ParameterMeta]
-
-
- class MemoryBufferMetas(BaseModel):
-     metas: list[ParameterMeta]
-     ptr: int
-     size: int
-
-
- class MemoryBuffer(BaseModel):
-     buffer: _TorchTensor
-     size: int
-     metas: list[ParameterMeta]
-     manually_pinned: bool = False
-
-
- class MemoryBufferMetaList(BaseModel):
-     p2p_store_addr: str | None
-     memory_buffer_metas_list: list[MemoryBufferMetas]
-     rdma_device: str
-
-
- class DataToGather(MemoryBufferMetaList):
-     host_ip: str
-     device_uuid: str
-
-
- # 256 bytes alignment when flatten torch tensors to uint8 buffer
- _ALIGN_SIZE = 256
-
-
- def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
-     return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
+     from checkpoint_engine.data_types import T


  def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
@@ -156,107 +46,6 @@ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
      return ret


- def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
-     def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
-         ret = {}
-         with safe_open(fn, framework="pt") as f:
-             for name in f.keys(): # noqa: SIM118
-                 weight = f.get_tensor(name)
-                 meta = {
-                     "key": name,
-                     "dtype": weight.dtype,
-                     "shape": weight.shape,
-                     "type": type(weight),
-                     "tp_concat_dim": -1, # safetensors does not support tp_concat_dim
-                 }
-                 ret[name] = (meta, weight)
-         return ret
-
-     # deprecated, will be removed in the future
-     def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
-         """load *.np file and return memmap and related tensor meta"""
-
-         def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
-             start = fin.tell()
-             major, minor = np.lib.format.read_magic(fin)
-             if major == 1 and minor == 0:
-                 read_header_fn = np.lib.format.read_array_header_1_0
-             elif major == 2 and minor == 0:
-                 read_header_fn = np.lib.format.read_array_header_2_0
-             else:
-                 raise ValueError(
-                     f"unknown version {major}.{minor} when parsing npy header from {fn}"
-                 )
-             shape, is_fortran, dtype = read_header_fn(fin)
-             return {
-                 "shape": shape,
-                 "is_fortran": is_fortran,
-                 "dtype": dtype,
-                 "header_length": fin.tell() - start,
-             }
-
-         meta_fn = fn + ".meta"
-         with open(meta_fn, "rb") as fin:
-             meta_lst = pickle.load(fin)
-
-         tensors = []
-         offset = 0
-         with open(fn, "rb") as fin:
-             fin.seek(0, os.SEEK_END)
-             filesize = fin.tell()
-             fin.seek(0)
-             while fin.tell() < filesize:
-                 tensor_meta = parse_npy_header(fin)
-                 tensor = np.memmap(
-                     fn,
-                     dtype=tensor_meta["dtype"],
-                     mode="c",
-                     offset=offset + tensor_meta["header_length"],
-                     shape=tensor_meta["shape"],
-                 )
-                 offset += tensor_meta["header_length"] + tensor.nbytes
-                 fin.seek(offset)
-                 tensors.append(tensor)
-
-         assert len(meta_lst) == len(tensors)
-         ret = {}
-         for meta, tensor in zip(meta_lst, tensors):
-             if meta["type"] == torch.Tensor:
-                 tensor = torch.from_numpy(tensor)
-             tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
-             ret[meta["key"]] = (meta, tensor)
-         return ret
-
-     tp_rank = 0
-     if file_path.endswith(".npy"):
-         logger.warning("numpy model file is deprecated, will be removed in the future")
-         filename_split = os.path.basename(file_path).split(".")
-         # if using numpy and want to specify tp rank
-         # file should be in model.{layer}.{tp}[.{ep}].npy format
-         tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
-         ret = _fast_np_load(file_path)
-     elif file_path.endswith(".safetensors"):
-         ret = _safetensors_load(file_path)
-     else:
-         raise ValueError(f"unsupported file format: {file_path}")
-     return tp_rank, ret
-
-
- def _concat_tp_weights(
-     tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
- ) -> torch.Tensor:
-     """Concat tp weights with meta info.
-     If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights.
-     Else we will cat weights in concat_dim.
-     """
-     if tp_concat_dim == -1:
-         return tp_weights[0]
-     assert tp_size == len(tp_weights)
-     if len(tp_weights) == 1:
-         return tp_weights[0]
-     return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
-
-
  def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
      try:
          if device_manager.device_type == "npu":
@@ -267,426 +56,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None
          raise ValueError(f"fail to get physical gpu id {device_index}") from e


- def _ibv_get_device_list() -> list[str]:
-     lib = ctypes.CDLL("libibverbs.so.1")
-     lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
-     lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **
-
-     lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
-     lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
-     lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *
-
-     num = ctypes.c_int()
-     dev_array = lib.ibv_get_device_list(ctypes.byref(num))
-     if not dev_array or num.value <= 0:
-         return []
-
-     devices = []
-     for i in range(num.value):
-         dev_ptr = dev_array[i] # struct ibv_device *
-         name = lib.ibv_get_device_name(dev_ptr) # const char *
-         devices.append(name.decode())
-     lib.ibv_free_device_list(dev_array)
-     return devices
-
-
- def _get_rdma_devices() -> list[str]:
-     """
-     use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
-     """
-     devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
-     if devices_str:
-         return devices_str.split(",")
-     # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
-     hca = os.getenv("NCCL_IB_HCA", None)
-     return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
-
-
- def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
-     """
-     implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
-     """
-     if not devices:
-         raise RuntimeError("no rdma devices found")
-     try:
-         assert len(devices) <= gpu_count, (
-             f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
-         )
-         assert gpu_count % len(devices) == 0, (
-             f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
-         )
-         return devices[local_rank // (gpu_count // len(devices))]
-     except AssertionError:
-         logger.error(
-             "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
-             "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
-             "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
-         )
-         raise
-
-
- def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
-     """
-     The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
-     The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
-
-     The list is comma-separated; port numbers are NOT supported yet.
-     An optional prefix '^' indicates the list is an exclude list.
-     A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
-     Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
-
-     Examples:
-     `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
-     `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
-     `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
-     `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
-     """
-     max_hcas = 32
-     if not value or value.strip() == "":
-         return available_devices[:max_hcas]
-
-     value = value.strip()
-     result = []
-     is_exclude = value.startswith("^")
-     if is_exclude:
-         value = value.removeprefix("^")
-     is_exact_match = value.startswith("=")
-     if is_exact_match:
-         value = value.removeprefix("=")
-
-     device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
-
-     result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
-     if is_exclude:
-         result = [dev for dev in available_devices if dev not in result]
-     if len(result) > max_hcas:
-         result = result[:max_hcas]
-
-     logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
-
-     return result
-
-
- def _resolve_device_specs(
-     device_specs: list[str], is_exact_match: bool, available_devices: list[str]
- ) -> list[str]:
-     devices = set()
-     for spec in device_specs:
-         parts = spec.split(":", 1)
-         device_name = parts[0].strip()
-         # HACK: mooncake transfer engine does not support port specification yet, so we ignore it
-         # port = parts[1].strip() if len(parts) > 1 else None
-         base_devices = (
-             [device_name]
-             if device_name in available_devices
-             else []
-             if is_exact_match
-             else [dev for dev in available_devices if dev.startswith(device_name)]
-         )
-
-         if not base_devices:
-             logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
-             continue
-
-         for base_dev in base_devices:
-             devices.add(base_dev)
-
-     return sorted(devices)
-
-
- def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
-     class TPMeta(BaseModel):
-         concat_dim: int
-         size: int
-
-     parameters: dict[str, torch.Tensor] = {}
-     parameter_metas: dict[str, ParameterMeta] = {}
-     tp_metas: dict[str, TPMeta] = {}
-     parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {}
-     for file in files:
-         tp_rank, ret = _load_checkpoint_file(file)
-         for parameter_name, (meta, weight) in ret.items():
-             if parameter_name not in parameters_with_tp:
-                 parameters_with_tp[parameter_name] = {}
-             parameters_with_tp[parameter_name][tp_rank] = weight
-             if parameter_name not in tp_metas:
-                 tp_metas[parameter_name] = TPMeta(
-                     concat_dim=meta["tp_concat_dim"],
-                     size=1,
-                 )
-             if parameter_name not in parameter_metas:
-                 assert isinstance(meta["dtype"], torch.dtype), (
-                     f"meta {meta} dtype should be torch.dtype"
-                 )
-                 assert isinstance(meta["shape"], torch.Size), (
-                     f"meta {meta} shape should be torch.Size"
-                 )
-                 parameter_metas[parameter_name] = ParameterMeta(
-                     name=parameter_name,
-                     shape=meta["shape"],
-                     dtype=meta["dtype"],
-                     aligned_size=_align_size(meta["dtype"], meta["shape"]),
-                 )
-             tp_meta = tp_metas[parameter_name]
-             if tp_meta.concat_dim != -1:
-                 tp_meta.size = max(tp_meta.size, tp_rank + 1)
-     for name, tp_meta in tp_metas.items():
-         if tp_meta.concat_dim != -1:
-             shape = list(parameter_metas[name].shape)
-             shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
-             parameter_metas[name] = ParameterMeta(
-                 name=name,
-                 shape=torch.Size(shape),
-                 dtype=parameter_metas[name].dtype,
-                 aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
-             )
-         weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
-         # TODO: here concat is serial, which may be slow
-         # but since tp storage is not used in the future
-         # we ignore this performance issue for now
-         parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size)
-     for name, parameter in parameters.items():
-         assert name in parameter_metas, f"parameter {name} not found in parameter_metas"
-         assert parameter_metas[name].shape == parameter.shape, (
-             f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}"
-         )
-         assert parameter_metas[name].dtype == parameter.dtype, (
-             f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}"
-         )
-     return parameters
-
-
- def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
-     def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
-         """
-         safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
-         We load the safetensors file as bytes, then parse the header manually to get parameter metas.
-         The actual tensor data is in the remaining bytes and is naturally aligned.
-         We pin the remaining bytes as the buffer, making pinning faster.
-         """
-
-         def _pin(t: torch.Tensor):
-             """
-             Pin the memory of tensor in-place.
-             See: https://github.com/pytorch/pytorch/issues/32167
-             """
-             cudart = torch.cuda.cudart()
-             r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
-             assert r == 0, f"pin memory error, error code: {r}"
-
-         # TODO: should only support /dev/shm? but we found files in disk also work?
-         size = os.stat(file_path).st_size
-         flag_size = 8
-         t = torch.from_file(file_path, True, size, dtype=torch.uint8)
-         assert t.nbytes > flag_size, (
-             f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
-         )
-         start_pos = (
-             int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
-             + flag_size
-         )
-         header_tensor = t[flag_size:start_pos]
-         header = json.loads(header_tensor.numpy().tobytes())
-         if "__metadata__" in header:
-             header.pop("__metadata__")
-
-         metas: list[ParameterMeta] = []
-         offset = 0
-         try:
-             for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
-                 start, end = meta["data_offsets"]
-                 # safetensors format ensures offsets are aligned
-                 assert offset == start, f"offset {offset} should be equal to start {start}"
-                 metas.append(
-                     ParameterMeta(
-                         name=name,
-                         dtype=_getdtype(meta["dtype"]),
-                         shape=torch.Size(meta["shape"]),
-                         aligned_size=end - start,
-                     )
-                 )
-                 offset = end
-         except Exception as e:
-             logger.error(f"fail to parse safetensors header from {file_path}: {e}")
-             raise
-
-         buffer = t[start_pos:]
-         assert offset == buffer.nbytes, (
-             f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
-         )
-         # Remove the file after successfully loading. This will avoid doubling the memory usage.
-         # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
-         os.remove(file_path)
-         _pin(buffer)
-         logger.info(
-             f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
-         )
-         return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)
-
-     memory_buffers: list[MemoryBuffer] = []
-     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-         memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
-     return memory_buffers
-
-
- def _normal_pin_memory(
-     files: list[str],
-     named_tensors: dict[str, torch.Tensor],
-     rank: int | None = None,
-     shared_pin_memory: list[MemoryBuffer] | None = None,
- ) -> list[MemoryBuffer]:
-     parameters = _load_checkpoint(files)
-     if named_tensors:
-         parameters.update(named_tensors)
-     bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
-
-     class MemoryBucket(BaseModel):
-         size: int
-         metas: list[ParameterMeta]
-
-     buckets: list[MemoryBucket] = []
-     buckets.append(MemoryBucket(size=0, metas=[]))
-     for name, tensor in sorted(parameters.items()):
-         size = _align_size(tensor.dtype, tensor.shape)
-         if buckets[-1].size + size > bucket_size:
-             assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
-             buckets.append(MemoryBucket(size=0, metas=[]))
-         buckets[-1].metas.append(
-             ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
-         )
-         buckets[-1].size += size
-
-     memory_buffers = [
-         MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
-         for bucket in buckets
-     ]
-
-     def register_pin_memory(
-         idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
-     ) -> tuple[int, torch.Tensor]:
-         if shared_pin_memory:
-             # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
-             # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
-             assert idx < len(shared_pin_memory), (
-                 f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
-             )
-             assert shared_pin_memory[idx].size == size, (
-                 f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
-             )
-             return idx, shared_pin_memory[idx].buffer
-         else:
-             buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-             return idx, buffer
-
-     def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
-         buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
-
-     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
-         futures = [
-             executor.submit(
-                 register_pin_memory,
-                 idx,
-                 bucket.size,
-                 shared_pin_memory,
-             )
-             for idx, bucket in enumerate(buckets)
-         ]
-         new_futures = []
-         for future in concurrent.futures.as_completed(futures):
-             idx, buffer = future.result()
-             assert buffer.numel() == buckets[idx].size, (
-                 f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
-             )
-             memory_buffers[idx].buffer = buffer
-             logger.info(
-                 f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
-                 f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
-             )
-             offset = 0
-             for meta in buckets[idx].metas:
-                 name = meta.name
-                 tensor = parameters[name]
-                 size = _align_size(tensor.dtype, tensor.shape)
-                 assert size == _align_size(meta.dtype, meta.shape), (
-                     f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
-                 )
-                 new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
-                 offset += size
-         for future in concurrent.futures.as_completed(new_futures):
-             future.result()
-     return memory_buffers
-
-
- def _register_checkpoint(
-     *,
-     files: list[str],
-     named_tensors: dict[str, torch.Tensor],
-     rank: int | None = None,
-     shared_pin_memory: list[MemoryBuffer] | None = None,
-     inplace_pin: bool = False,
- ) -> list[MemoryBuffer]:
-     logger.info(
-         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
-     )
-     if not files and not named_tensors:
-         return []
-     memory_buffers: list[MemoryBuffer] = []
-     if inplace_pin:
-         logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
-         files_to_inplace_pin = [
-             file
-             for file in files
-             if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
-         ]
-         files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
-     else:
-         files_to_normal_pin = files
-         files_to_inplace_pin = []
-     if files_to_normal_pin or named_tensors:
-         memory_buffers.extend(
-             _normal_pin_memory(
-                 files=files_to_normal_pin,
-                 named_tensors=named_tensors,
-                 rank=rank,
-                 shared_pin_memory=shared_pin_memory,
-             )
-         )
-     if files_to_inplace_pin:
-         memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
-     return memory_buffers
-
-
- def request_inference_to_update(
-     url: str,
-     socket_paths: dict[str, str],
-     timeout: float = 300.0,
-     uds: str | None = None,
- ):
-     """Send an inference update request to inference server via HTTP or Unix socket.
-
-     Args:
-         url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
-         socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
-         timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
-         uds (str, optional): Path to a Unix domain socket. If provided, the request
-             will be sent via the Unix socket instead of HTTP. Defaults to None.
-
-     Raises:
-         httpx.HTTPStatusError: If the response contains an HTTP error status.
-         httpx.RequestError: If there was an issue while making the request.
-     """
-     resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
-         url,
-         json={
-             "method": "update_weights_from_ipc",
-             "args": [socket_paths],
-             "timeout": timeout,
-         },
-         timeout=timeout,
-     )
-     resp.raise_for_status()
-
-
  def _gen_h2d_buckets(
      global_metas: dict[int, MemoryBufferMetaList],
      bucket_size: int,
@@ -789,84 +158,12 @@ def _get_master_port(master_port: int | None = None) -> int:
      if master_port is None:
          # HACK: use MASTER_PORT + 1 as master_port, avoid conflict with torchrun's rendezvous port
          # TODO: check whether master_port is available or use a more elegant way
-         master_port = int(os.getenv("MASTER_PORT")) + 1
+         master_port_str = os.getenv("MASTER_PORT")
+         assert master_port_str, "MASTER_PORT is required if no master_port is provided."
+         master_port = int(master_port_str) + 1
      return master_port


- class P2PStore:
-     def __init__(self, device_manager: DeviceManager):
-         from mooncake.engine import TransferEngine
-
-         self.rank = int(os.getenv("RANK"))
-         gpu_count = device_manager.device_module.device_count()
-         local_rank = self.rank % gpu_count
-         device_type = device_manager.device_type
-         if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
-             self.device = ""
-         else:
-             self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
-         self.ip = get_ip()
-
-         # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
-         retry_count = 8
-         for i in range(retry_count):
-             self.engine = TransferEngine()
-             ret = self.engine.initialize(
-                 self.ip,
-                 "P2PHANDSHAKE",
-                 "ascend_direct" if device_type == "npu" else "rdma",
-                 self.device,
-             )
-             if ret == 0:
-                 break
-             # sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
-             sleep_ms = random.randint(500, 2000)
-             logger.warning(
-                 f"[rank{self.rank}] fail to initialize transfer engine, ret {ret}, retry {i + 1}/{retry_count} in {sleep_ms}ms"
-             )
-             time.sleep(sleep_ms / 1000)
-         else:
-             raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
-         self.port = self.engine.get_rpc_port()
-         self.named_tensors: dict[str, torch.Tensor] = {}
-         logger.info(
-             f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
-         )
-
-     @property
-     def addr(self) -> str:
-         return f"{self.ip}:{self.port}"
-
-     def register_named_tensors(self, named_tensors: dict[str, torch.Tensor]):
-         buffer_addresses = [tensor.data_ptr() for tensor in named_tensors.values()]
-         capacities = [tensor.nbytes for tensor in named_tensors.values()]
-         self.named_tensors.update(named_tensors)
-         for i, name in enumerate(named_tensors.keys()):
-             logger.info(
-                 f"[rank{self.rank}] p2p store register tensor {name} with addr {hex(buffer_addresses[i])} and capacity {capacities[i]}"
-             )
-         assert self.engine.batch_register_memory(buffer_addresses, capacities) == 0
-
-     def unregister_named_tensors(self, names: list[str]) -> int:
-         buffer_addresses = [self.named_tensors[name].data_ptr() for name in names]
-         assert self.engine.batch_unregister_memory(buffer_addresses) == 0
-         num_unregistered = 0
-         for i, name in enumerate(names):
-             del self.named_tensors[name]
-             logger.info(
-                 f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
-             )
-             num_unregistered += 1
-         return num_unregistered
-
-     def batch_transfer_sync_read(
-         self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
-     ):
-         assert (
-             self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
-         )
-
-
  class ParameterServer:
      shared_memory_pool_name = "__shared_memory_pool__"

@@ -887,8 +184,8 @@ class ParameterServer:
              Notice that if auto_pg is True, will destroy the process group after update. It is recommended to set auto_pg to True!
              mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
          """
-         self._rank = rank or int(os.environ.get("RANK", None))
-         self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
+         self._rank = rank or int(os.environ["RANK"])
+         self._world_size = world_size or int(os.environ["WORLD_SIZE"])
          self.device_manager = DeviceManager()
          self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
          self._local_rank = self._rank % self._gpu_count
@@ -897,7 +194,7 @@ class ParameterServer:
          self._global_device_uuids: list[str] = []
          self._local_rdma_devices: dict[str, set[int]] = defaultdict(set)
          self._remote_rdma_devices: dict[str, set[int]] = defaultdict(set)
-         self._mem_fraction = mem_fraction or 0.9
+         self._mem_fraction = mem_fraction or float(os.getenv("PS_MEM_FRACTION", "0.9"))

          assert self._rank is not None and self._rank >= 0, self._rank
          assert self._world_size and self._world_size > 0, self._world_size
@@ -1352,7 +649,7 @@ class ParameterServer:
                  f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
              )
              disable_h2d_buffer = True
-         max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
+         max_bytes = int(float(os.getenv("PS_MAX_BUCKET_SIZE_GB", "8")) * GiB)
          bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
          logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
          return bucket_size, disable_h2d_buffer
@@ -1559,79 +856,8 @@ class ParameterServer:
          self.device_manager.device_module.empty_cache()


- def _init_api(ps: ParameterServer) -> Any:
-     import fastapi
-     from fastapi import Request
-     from fastapi.responses import JSONResponse, Response
-
-     app = fastapi.FastAPI()
-
-     class RegisterRequest(BaseModel):
-         files: list[str]
-
-     class UpdateRequest(BaseModel):
-         ranks: list[int] = []
-         update_url: str | None = None
-         inference_group_ranks: list[int] = []
-         timeout: float = 300.0
-         uds: str | None = None
-
-     def wrap_exception(func: Callable[[], None]) -> Response:
-         try:
-             func()
-         except Exception as e: # noqa: BLE001
-             logger.exception(f"wrap exception {func} failed")
-             return JSONResponse(content=str(e), status_code=500)
-         return Response(status_code=200)
-
-     @app.post("/v1/checkpoints/{checkpoint_name}/files")
-     async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
-         return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
-
-     @app.delete("/v1/checkpoints/{checkpoint_name}")
-     async def unregister_checkpoint(checkpoint_name: str) -> Response:
-         return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
-
-     @app.get("/v1/healthz")
-     async def healthz() -> Response:
-         return Response(status_code=200)
-
-     @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
-     async def gather_metas(checkpoint_name: str) -> Response:
-         return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
-
-     @app.post("/v1/checkpoints/{checkpoint_name}/update")
-     async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
-         def update_func(socket_paths: list[tuple[str, str]]):
-             if req.update_url is None:
-                 return
-             if req.inference_group_ranks:
-                 socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
-             request_inference_to_update(
-                 req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
-             )
-
-         return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
-
-     return app
-
-
- @logger.catch(reraise=True)
- def run_from_cli():
-     import uvicorn
-
-     parser = argparse.ArgumentParser(description="Parameter Server")
-     parser.add_argument("--uds", type=str)
-
-     args = parser.parse_args()
-     logger.info(
-         f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
-     )
-
-     assert args.uds and len(args.uds) > 0, args.uds
-     ps = ParameterServer(auto_pg=True)
-     uvicorn.run(_init_api(ps), uds=args.uds, timeout_keep_alive=60)
-
-
+ # we need this CLI entry point for compatibility with former versions
  if __name__ == "__main__":
+     from .__main__ import run_from_cli
+
      run_from_cli()