checkpoint-engine 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.1'
32
- __version_tuple__ = version_tuple = (0, 1, 1)
31
+ __version__ = version = '0.1.2'
32
+ __version_tuple__ = version_tuple = (0, 1, 2)
33
33
 
34
- __commit_id__ = commit_id = 'gf29b2e1c3'
34
+ __commit_id__ = commit_id = None
checkpoint_engine/ps.py CHANGED
@@ -2,31 +2,32 @@ from __future__ import annotations
2
2
 
3
3
  import argparse
4
4
  import concurrent.futures
5
+ import ctypes
5
6
  import os
6
7
  import pickle
7
8
  import random
8
9
  import socket
9
- import subprocess
10
10
  import threading
11
11
  import time
12
- import uuid
13
12
  from collections import defaultdict
14
13
  from datetime import timedelta
15
- from functools import cached_property, lru_cache
16
- from typing import Callable, NamedTuple
14
+ from functools import lru_cache
15
+ from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
17
16
 
17
+ import httpx
18
18
  import numpy as np
19
- import requests
20
19
  import torch
21
20
  import torch.distributed as dist
22
21
  import zmq
23
22
  from loguru import logger
24
- from pydantic import BaseModel, ConfigDict
23
+ from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
25
24
  from safetensors.torch import safe_open
26
25
  from torch.multiprocessing.reductions import reduce_tensor
27
- from typing_extensions import TYPE_CHECKING
26
+
28
27
 
29
28
  if TYPE_CHECKING:
29
+ from collections.abc import Callable
30
+
30
31
  from typing_extensions import TypedDict
31
32
 
32
33
  class FileMeta(TypedDict):
@@ -37,16 +38,59 @@ if TYPE_CHECKING:
37
38
  tp_concat_dim: int
38
39
 
39
40
 
40
- class ParameterMeta(BaseModel):
41
- # now all classes are changed to pydantic BaseModel
42
- # it will directly report validation errors for unknown types
43
- # like torch.dtype, torch.Size, so we need this configuration
44
- # see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment
45
- model_config = ConfigDict(arbitrary_types_allowed=True)
41
+ def _dt_validate(value: Any) -> torch.dtype:
42
+ if isinstance(value, str):
43
+ if not value.startswith("torch."):
44
+ raise ValueError(f"dtype {value} should start with torch.")
45
+ try:
46
+ value = getattr(torch, value.split(".")[1])
47
+ except AttributeError as e:
48
+ raise ValueError(f"unknown dtype: {value}") from e
49
+ if not isinstance(value, torch.dtype):
50
+ raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
51
+ return value
52
+
53
+
54
+ _TorchDtype = Annotated[
55
+ torch.dtype,
56
+ PlainValidator(_dt_validate),
57
+ PlainSerializer(lambda x: str(x), return_type=str),
58
+ WithJsonSchema({"type": "string"}, mode="serialization"),
59
+ ]
60
+
61
+
62
+ def _size_validate(value: Any) -> torch.Size:
63
+ if isinstance(value, list | tuple):
64
+ return torch.Size(value)
65
+ if not isinstance(value, torch.Size):
66
+ raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
67
+ return value
68
+
69
+
70
+ _TorchSize = Annotated[
71
+ torch.Size,
72
+ PlainValidator(_size_validate),
73
+ PlainSerializer(lambda x: tuple(x), return_type=tuple),
74
+ WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
75
+ ]
76
+
46
77
 
78
+ def _tensor_validate(value: Any) -> torch.Tensor:
79
+ if isinstance(value, torch.Tensor):
80
+ return value
81
+ raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
82
+
83
+
84
+ _TorchTensor = Annotated[
85
+ torch.Tensor,
86
+ PlainValidator(_tensor_validate),
87
+ ]
88
+
89
+
90
+ class ParameterMeta(BaseModel):
47
91
  name: str
48
- dtype: torch.dtype
49
- shape: torch.Size
92
+ dtype: _TorchDtype
93
+ shape: _TorchSize
50
94
 
51
95
 
52
96
  class BucketRange(NamedTuple):
@@ -68,9 +112,7 @@ class MemoryBufferMetas(BaseModel):
68
112
 
69
113
 
70
114
  class MemoryBuffer(BaseModel):
71
- model_config = ConfigDict(arbitrary_types_allowed=True)
72
-
73
- buffer: torch.Tensor
115
+ buffer: _TorchTensor
74
116
  size: int
75
117
  metas: list[ParameterMeta]
76
118
 
@@ -82,7 +124,7 @@ class MemoryBufferMetaList(BaseModel):
82
124
 
83
125
  class DataToGather(MemoryBufferMetaList):
84
126
  host_ip: str
85
- zmq_socket_path: tuple[str, str]
127
+ device_uuid: str
86
128
 
87
129
 
88
130
  # 256 bytes alignment when flatten torch tensors to uint8 buffer
@@ -93,7 +135,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
93
135
  return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
94
136
 
95
137
 
96
- def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
138
+ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
97
139
  ret = []
98
140
  for meta in metas:
99
141
  size = _align_size(meta.dtype, meta.shape)
@@ -110,10 +152,10 @@ def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
110
152
 
111
153
 
112
154
  def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
113
- def _safetensors_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
155
+ def _safetensors_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
114
156
  ret = {}
115
157
  with safe_open(fn, framework="pt") as f:
116
- for name in f.keys():
158
+ for name in f.keys(): # noqa: SIM118
117
159
  weight = f.get_tensor(name)
118
160
  meta = {
119
161
  "key": name,
@@ -126,10 +168,10 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
126
168
  return ret
127
169
 
128
170
  # deprecated, will be removed in the future
129
- def _fast_np_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
171
+ def _fast_np_load(fn: str) -> dict[str, tuple[FileMeta, torch.Tensor]]:
130
172
  """load *.np file and return memmap and related tensor meta"""
131
173
 
132
- def parse_npy_header(fin):
174
+ def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
133
175
  start = fin.tell()
134
176
  major, minor = np.lib.format.read_magic(fin)
135
177
  if major == 1 and minor == 0:
@@ -137,7 +179,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
137
179
  elif major == 2 and minor == 0:
138
180
  read_header_fn = np.lib.format.read_array_header_2_0
139
181
  else:
140
- raise ValueError(f"unknown version {major}.{minor} when parsing npy header from {fn}")
182
+ raise ValueError(
183
+ f"unknown version {major}.{minor} when parsing npy header from {fn}"
184
+ )
141
185
  shape, is_fortran, dtype = read_header_fn(fin)
142
186
  return {
143
187
  "shape": shape,
@@ -193,7 +237,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
193
237
  return tp_rank, ret
194
238
 
195
239
 
196
- def _concat_tp_weights(tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int) -> torch.Tensor:
240
+ def _concat_tp_weights(
241
+ tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
242
+ ) -> torch.Tensor:
197
243
  """Concat tp weights with meta info.
198
244
  If meta.concat_dim is -1, it means these are shared tp weights, so just use the first weights.
199
245
  Else we will cat weights in concat_dim.
@@ -206,39 +252,54 @@ def _concat_tp_weights(tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_si
206
252
  return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
207
253
 
208
254
 
209
- def _get_physical_gpu_id(rank: int) -> str:
210
- result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True)
211
- if result.returncode != 0:
212
- raise ValueError(result.stdout)
213
- lines = result.stdout.strip().split("\n")
214
- for line in lines:
215
- if f"GPU {rank}" in line:
216
- uuid = line.split("UUID: ")[1].strip(")")
217
- return uuid
218
- raise ValueError(f"not found gpu{rank} uuid")
255
+ def _get_physical_gpu_id(device_index: int | None = None) -> str:
256
+ try:
257
+ return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
258
+ except AssertionError as e:
259
+ raise ValueError(f"fail to get physical gpu id {device_index}") from e
219
260
 
220
261
 
221
262
  @lru_cache(maxsize=1)
222
- def _get_ip():
263
+ def _get_ip() -> str:
223
264
  try:
224
265
  # try to get ip from network interface
225
266
  with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
226
267
  s.connect(("8.8.8.8", 80))
227
268
  return s.getsockname()[0]
228
- except:
269
+ except Exception as e: # noqa: BLE001
229
270
  # fallback to get ip from hostname
230
- logger.warning("fail to get ip from network interface, fallback to get ip from hostname")
271
+ logger.warning(
272
+ f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
273
+ )
231
274
  return socket.gethostbyname(socket.gethostname())
232
275
 
233
276
 
277
+ def _ibv_get_device_list() -> list[str]:
278
+ lib = ctypes.CDLL("libibverbs.so.1")
279
+ lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
280
+ lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **
281
+
282
+ lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
283
+ lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
284
+ lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *
285
+
286
+ num = ctypes.c_int()
287
+ dev_array = lib.ibv_get_device_list(ctypes.byref(num))
288
+ if not dev_array or num.value <= 0:
289
+ return []
290
+
291
+ devices = []
292
+ for i in range(num.value):
293
+ dev_ptr = dev_array[i] # struct ibv_device *
294
+ name = lib.ibv_get_device_name(dev_ptr) # const char *
295
+ devices.append(name.decode())
296
+ lib.ibv_free_device_list(dev_array)
297
+ return devices
298
+
299
+
234
300
  def _get_rdma_devices() -> list[str]:
235
301
  """
236
- use script like below to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
237
- ```bash
238
- pushd /sys/class/infiniband/ > /dev/null;
239
- for i in mlx5_*; do cat "$i"/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo "$i" ; done;
240
- popd > /dev/null;
241
- ```
302
+ use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
242
303
  """
243
304
  devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
244
305
  if devices_str:
@@ -246,41 +307,27 @@ def _get_rdma_devices() -> list[str]:
246
307
  # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
247
308
  hca = os.getenv("NCCL_IB_HCA", None)
248
309
  if hca:
249
- l = hca.split(",")
250
- if len(l) > 1:
310
+ hca_list = hca.split(",")
311
+ if len(hca_list) > 1:
251
312
  # if NCCL_IB_HCA has multiple values, just return
252
- return l
313
+ return hca_list
253
314
  else:
254
- hca = l[0]
255
- basepath = "/sys/class/infiniband/"
256
- port_path = "ports/1/gid_attrs/types"
257
- devices = []
258
- for device in sorted(os.listdir(basepath)):
259
- if hca is not None and hca not in device:
260
- continue
261
- path = os.path.join(basepath, device, port_path)
262
- if not os.path.exists(path) or not os.path.isdir(path):
263
- continue
264
- for port in os.listdir(path):
265
- try:
266
- content = open(os.path.join(path, port)).read()
267
- if "v" in content:
268
- print(f"found rdma device {device} in port {port}: {content.strip()}")
269
- devices.append(device)
270
- break
271
- except Exception:
272
- pass
273
- return devices
315
+ hca = hca_list[0]
316
+ return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
274
317
 
275
318
 
276
- def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]):
319
+ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
277
320
  """
278
321
  implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
279
322
  """
280
323
  if not devices:
281
324
  raise RuntimeError("no rdma devices found")
282
- assert len(devices) <= gpu_count, f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
283
- assert gpu_count % len(devices) == 0, f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
325
+ assert len(devices) <= gpu_count, (
326
+ f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
327
+ )
328
+ assert gpu_count % len(devices) == 0, (
329
+ f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
330
+ )
284
331
  return devices[local_rank // (gpu_count // len(devices))]
285
332
 
286
333
 
@@ -305,8 +352,12 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
305
352
  size=1,
306
353
  )
307
354
  if parameter_name not in parameter_metas:
308
- assert isinstance(meta["dtype"], torch.dtype), f"meta {meta} dtype should be torch.dtype"
309
- assert isinstance(meta["shape"], torch.Size), f"meta {meta} shape should be torch.Size"
355
+ assert isinstance(meta["dtype"], torch.dtype), (
356
+ f"meta {meta} dtype should be torch.dtype"
357
+ )
358
+ assert isinstance(meta["shape"], torch.Size), (
359
+ f"meta {meta} shape should be torch.Size"
360
+ )
310
361
  parameter_metas[parameter_name] = ParameterMeta(
311
362
  name=parameter_name,
312
363
  shape=meta["shape"],
@@ -319,7 +370,9 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
319
370
  if tp_meta.concat_dim != -1:
320
371
  shape = list(parameter_metas[name].shape)
321
372
  shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
322
- parameter_metas[name] = ParameterMeta(name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype)
373
+ parameter_metas[name] = ParameterMeta(
374
+ name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
375
+ )
323
376
  weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
324
377
  # TODO: here concat is serial, which may be slow
325
378
  # but since tp storage is not used in the future
@@ -338,17 +391,19 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
338
391
 
339
392
  def _register_checkpoint(
340
393
  *,
341
- files: list[str] = [],
342
- named_tensors: dict[str, torch.Tensor] = {},
394
+ files: list[str],
395
+ named_tensors: dict[str, torch.Tensor],
343
396
  rank: int | None = None,
344
397
  ) -> list[MemoryBuffer]:
345
- logger.info(f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors")
398
+ logger.info(
399
+ f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
400
+ )
346
401
  if not files and not named_tensors:
347
402
  return []
348
403
  parameters = _load_checkpoint(files)
349
404
  if named_tensors:
350
405
  parameters.update(named_tensors)
351
- bucket_size = max(4 << 30, max(map(lambda x: _align_size(x.dtype, x.shape), parameters.values())))
406
+ bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
352
407
 
353
408
  class MemoryBucket(BaseModel):
354
409
  size: int
@@ -363,7 +418,10 @@ def _register_checkpoint(
363
418
  buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
364
419
  buckets[-1].size += size
365
420
 
366
- memory_buffers = [MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) for bucket in buckets]
421
+ memory_buffers = [
422
+ MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
423
+ for bucket in buckets
424
+ ]
367
425
 
368
426
  def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
369
427
  buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
@@ -373,7 +431,10 @@ def _register_checkpoint(
373
431
  buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
374
432
 
375
433
  with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
376
- futures = [executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(buckets)]
434
+ futures = [
435
+ executor.submit(register_pin_memory, idx, bucket.size)
436
+ for idx, bucket in enumerate(buckets)
437
+ ]
377
438
  new_futures = []
378
439
  for future in concurrent.futures.as_completed(futures):
379
440
  idx, buffer = future.result()
@@ -400,8 +461,26 @@ def _register_checkpoint(
400
461
  return memory_buffers
401
462
 
402
463
 
403
- def request_inference_to_update(url: str, socket_paths: dict[str, str], timeout: float = 300.0):
404
- resp = requests.post(
464
+ def request_inference_to_update(
465
+ url: str,
466
+ socket_paths: dict[str, str],
467
+ timeout: float = 300.0,
468
+ uds: str | None = None,
469
+ ):
470
+ """Send an inference update request to inference server via HTTP or Unix socket.
471
+
472
+ Args:
473
+ url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
474
+ socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
475
+ timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
476
+ uds (str, optional): Path to a Unix domain socket. If provided, the request
477
+ will be sent via the Unix socket instead of HTTP. Defaults to None.
478
+
479
+ Raises:
480
+ httpx.HTTPStatusError: If the response contains an HTTP error status.
481
+ httpx.RequestError: If there was an issue while making the request.
482
+ """
483
+ resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
405
484
  url,
406
485
  json={
407
486
  "method": "update_weights_from_ipc",
@@ -413,7 +492,9 @@ def request_inference_to_update(url: str, socket_paths: dict[str, str], timeout:
413
492
  resp.raise_for_status()
414
493
 
415
494
 
416
- def _gen_h2d_buckets(global_metas: dict[int, MemoryBufferMetaList], bucket_size: int) -> list[tuple[int, H2DBucket]]:
495
+ def _gen_h2d_buckets(
496
+ global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
497
+ ) -> list[tuple[int, H2DBucket]]:
417
498
  buckets: list[tuple[int, H2DBucket]] = []
418
499
 
419
500
  for owner_rank, items in global_metas.items():
@@ -424,14 +505,18 @@ def _gen_h2d_buckets(global_metas: dict[int, MemoryBufferMetaList], bucket_size:
424
505
  s = _align_size(meta.dtype, meta.shape)
425
506
  if buckets[-1][1].size + s > bucket_size:
426
507
  if offset - start_offset > 0:
427
- buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
508
+ buckets[-1][1].ranges.append(
509
+ BucketRange(idx, start_offset, offset - start_offset)
510
+ )
428
511
  start_offset = offset
429
512
  buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
430
513
  offset += s
431
514
  buckets[-1][1].size += s
432
515
  buckets[-1][1].items.append(meta)
433
516
  buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
434
- assert buckets[-1][1].size > 0, f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
517
+ assert buckets[-1][1].size > 0, (
518
+ f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
519
+ )
435
520
  return buckets
436
521
 
437
522
 
@@ -470,7 +555,9 @@ class P2PStore:
470
555
  raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
471
556
  self.port = self.engine.get_rpc_port()
472
557
  self.named_tensors: dict[str, torch.Tensor] = {}
473
- logger.info(f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}")
558
+ logger.info(
559
+ f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
560
+ )
474
561
 
475
562
  @property
476
563
  def addr(self) -> str:
@@ -492,16 +579,24 @@ class P2PStore:
492
579
  num_unregistered = 0
493
580
  for i, name in enumerate(names):
494
581
  del self.named_tensors[name]
495
- logger.info(f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}")
582
+ logger.info(
583
+ f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
584
+ )
496
585
  num_unregistered += 1
497
586
  return num_unregistered
498
587
 
499
- def batch_transfer_sync_read(self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]):
500
- assert self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
588
+ def batch_transfer_sync_read(
589
+ self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
590
+ ):
591
+ assert (
592
+ self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
593
+ )
501
594
 
502
595
 
503
596
  class ParameterServer:
504
- def __init__(self, *, auto_pg: bool = False):
597
+ def __init__(
598
+ self, *, rank: int | None = None, world_size: int | None = None, auto_pg: bool = False
599
+ ):
505
600
  """
506
601
  Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
507
602
 
@@ -509,20 +604,19 @@ class ParameterServer:
509
604
  auto_pg: Whether to automatically initialize the process group.
510
605
  Notice that if auto_pg is True, will destroy the process group after update.
511
606
  """
512
- self._rank = int(os.environ.get("RANK", None))
513
- self._world_size = int(os.environ.get("WORLD_SIZE", None))
514
- self._master_addr = os.getenv("MASTER_ADDR")
607
+ self._rank = rank or int(os.environ.get("RANK", None))
608
+ self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
515
609
  self._gpu_count = torch.cuda.device_count()
516
610
  self._local_rank = self._rank % self._gpu_count
517
611
  self._auto_pg = auto_pg
518
612
  self._all_hosts = []
519
- self._global_socket_paths: list[tuple[str, str]] = []
613
+ self._global_device_uuids: list[str] = []
520
614
 
521
615
  assert self._rank is not None and self._rank >= 0, self._rank
522
616
  assert self._world_size and self._world_size > 0, self._world_size
523
617
 
524
- self._device_uuid = _get_physical_gpu_id(self._local_rank)
525
618
  self._zmq_ctx = zmq.Context()
619
+ self._zmq_addr_counter = 0
526
620
 
527
621
  self._memory_pool: dict[str, list[MemoryBuffer]] = {}
528
622
  # dict key is owner_rank, value is a bucket metas list in owner_rank
@@ -533,19 +627,27 @@ class ParameterServer:
533
627
  logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
534
628
  self._p2p_store = None
535
629
 
536
- torch.cuda.set_device(self._local_rank)
630
+ device_index = self._local_rank
631
+ torch.cuda.set_device(device_index)
632
+ self._device_uuid = _get_physical_gpu_id(device_index)
537
633
 
538
- def _logger_rank0(self, msg):
634
+ def _logger_rank0(self, msg: str):
539
635
  if self._local_rank == 0:
540
636
  logger.info(msg)
541
637
 
542
- def get_metas(self):
638
+ def get_metas(self) -> dict[int, MemoryBufferMetaList]:
543
639
  return self._current_global_parameter_metas
544
640
 
545
641
  def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
546
642
  self._current_global_parameter_metas = metas
547
643
 
548
- def register_checkpoint(self, checkpoint_name: str, *, files: list[str] = [], named_tensors: dict[str, torch.Tensor] = {}):
644
+ def register_checkpoint(
645
+ self,
646
+ checkpoint_name: str,
647
+ *,
648
+ files: list[str] | None = None,
649
+ named_tensors: dict[str, torch.Tensor] | None = None,
650
+ ) -> None:
549
651
  """
550
652
  Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
551
653
 
@@ -555,12 +657,18 @@ class ParameterServer:
555
657
  named_tensors: The named tensors to register.
556
658
  """
557
659
  try:
558
- assert checkpoint_name not in self._memory_pool, f"checkpoint {checkpoint_name} already registered"
559
- self._memory_pool[checkpoint_name] = _register_checkpoint(files=files, named_tensors=named_tensors, rank=self._rank)
660
+ assert checkpoint_name not in self._memory_pool, (
661
+ f"checkpoint {checkpoint_name} already registered"
662
+ )
663
+ self._memory_pool[checkpoint_name] = _register_checkpoint(
664
+ files=files or [], named_tensors=named_tensors or {}, rank=self._rank
665
+ )
560
666
  if self._p2p_store is not None:
561
667
  self._register_parameters_to_p2p_store(checkpoint_name)
562
668
  except Exception:
563
- logger.exception(f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}")
669
+ logger.exception(
670
+ f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
671
+ )
564
672
  if self._p2p_store is not None:
565
673
  self._unregister_parameters_from_p2p_store(checkpoint_name)
566
674
  self.unregister_checkpoint(checkpoint_name)
@@ -583,10 +691,6 @@ class ParameterServer:
583
691
  # this works by using torch>=2.5.0
584
692
  torch._C._host_emptyCache()
585
693
 
586
- @cached_property
587
- def _zmq_socket_path(self) -> str:
588
- return f"ipc://@checkpoint-engine-{uuid.uuid4()}.sock"
589
-
590
694
  def gather_metas(self, checkpoint_name: str):
591
695
  """
592
696
  Gather the parameter metas from all ranks. This will gather memory_buffer, and other metadatas.
@@ -598,19 +702,17 @@ class ParameterServer:
598
702
  assert dist.is_initialized(), "process group is not initialized"
599
703
  metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore
600
704
  metas = DataToGather(
601
- memory_buffer_metas_list=list(
602
- map(
603
- lambda x: MemoryBufferMetas(
604
- metas=x.metas,
605
- ptr=x.buffer.data_ptr(),
606
- size=x.size,
607
- ),
608
- self._memory_pool.get(checkpoint_name, []),
609
- ),
610
- ),
705
+ memory_buffer_metas_list=[
706
+ MemoryBufferMetas(
707
+ metas=x.metas,
708
+ ptr=x.buffer.data_ptr(),
709
+ size=x.size,
710
+ )
711
+ for x in self._memory_pool.get(checkpoint_name, [])
712
+ ],
611
713
  p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
612
714
  host_ip=_get_ip(),
613
- zmq_socket_path=(self._device_uuid, self._zmq_socket_path),
715
+ device_uuid=self._device_uuid,
614
716
  )
615
717
 
616
718
  dist.all_gather_object(metas_lst, metas)
@@ -618,23 +720,31 @@ class ParameterServer:
618
720
  self._current_global_parameter_metas = {}
619
721
  num_parameters = 0
620
722
  all_hosts: list[str] = []
621
- global_socket_paths: list[tuple[str, str]] = []
723
+ global_device_uuids: list[str] = []
622
724
  for i, metas_buckets in enumerate(metas_lst):
623
725
  assert metas_buckets is not None, f"metas_buckets {i} should not be None"
624
726
  if i % self._gpu_count == 0 and not self._all_hosts:
625
727
  all_hosts.append(metas_buckets.host_ip)
626
- if not self._global_socket_paths:
627
- global_socket_paths.append(metas_buckets.zmq_socket_path)
728
+ if not self._global_device_uuids:
729
+ global_device_uuids.append(metas_buckets.device_uuid)
628
730
  if metas_buckets.memory_buffer_metas_list:
629
731
  self._current_global_parameter_metas[i] = metas_buckets
630
- num_parameters += sum(map(lambda x: len(x.metas), metas_buckets.memory_buffer_metas_list))
732
+ num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
631
733
  if not self._all_hosts:
632
734
  self._all_hosts = all_hosts
633
- if not self._global_socket_paths:
634
- self._global_socket_paths = global_socket_paths
635
- logger.info(f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}")
735
+ if not self._global_device_uuids:
736
+ self._global_device_uuids = global_device_uuids
737
+ logger.info(
738
+ f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
739
+ )
636
740
 
637
- def init_process_group(self, *, master_port: int | None = None, timeout: timedelta = timedelta(minutes=10)):
741
+ def init_process_group(
742
+ self,
743
+ *,
744
+ master_addr: str | None = None,
745
+ master_port: int | None = None,
746
+ timeout: timedelta = timedelta(minutes=10),
747
+ ):
638
748
  """
639
749
  Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
640
750
 
@@ -642,10 +752,22 @@ class ParameterServer:
642
752
  master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
643
753
  timeout: The timeout of the process group.
644
754
  """
755
+ master_addr = master_addr or os.getenv("MASTER_ADDR")
756
+ assert master_addr, "master_addr is required"
645
757
  store = dist.TCPStore(
646
- self._master_addr, _get_master_port(master_port), self._world_size, timeout=timeout, is_master=self._rank == 0
758
+ master_addr,
759
+ _get_master_port(master_port),
760
+ self._world_size,
761
+ timeout=timeout,
762
+ is_master=self._rank == 0,
763
+ )
764
+ dist.init_process_group(
765
+ backend="nccl",
766
+ world_size=self._world_size,
767
+ rank=self._rank,
768
+ timeout=timeout,
769
+ store=store,
647
770
  )
648
- dist.init_process_group(backend="nccl", world_size=self._world_size, rank=self._rank, timeout=timeout, store=store)
649
771
  logger.info(f"[rank{self._rank}] init process group successfully.")
650
772
 
651
773
  def update(
@@ -653,8 +775,8 @@ class ParameterServer:
653
775
  checkpoint_name: str,
654
776
  req_func: Callable[[list[tuple[str, str]]], None],
655
777
  *,
656
- ranks: list[int] = [],
657
- ):
778
+ ranks: list[int] | None = None,
779
+ ) -> None:
658
780
  """
659
781
  Update the checkpoint to inference engine. This function should be called after gather_metas.
660
782
 
@@ -667,6 +789,7 @@ class ParameterServer:
667
789
  which is useful in disaggregated architecture.
668
790
  """
669
791
  try:
792
+ # if ranks is None or [], use a full broadcast to update all ranks
670
793
  if not ranks:
671
794
  if self._auto_pg and not dist.is_initialized():
672
795
  self.init_process_group()
@@ -692,21 +815,39 @@ class ParameterServer:
692
815
  f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
693
816
  )
694
817
  except Exception as e:
695
- logger.exception(f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}")
696
- raise e
818
+ logger.exception(
819
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
820
+ )
821
+ raise
822
+
823
+ def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
824
+ def zmq_handle(device_uuid: str) -> str:
825
+ return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"
697
826
 
698
- def _get_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
699
- GiB_bytes = 1 << 30
827
+ socket_paths = [(uid, zmq_handle(uid)) for uid in self._global_device_uuids]
828
+ socket = self._zmq_ctx.socket(zmq.REQ)
829
+ socket.bind(zmq_handle(self._device_uuid))
830
+ self._zmq_addr_counter += 1
831
+ return socket, socket_paths
832
+
833
+ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
834
+ GiB = 1 << 30 # noqa: N806
700
835
  # auto detect bucket size
701
- free_bytes_tensor = torch.tensor(
702
- int(float(torch.cuda.mem_get_info()[0]) * 0.9),
836
+ tensor = torch.tensor(
837
+ [
838
+ # 90% of current cuda free memory bytes
839
+ int(float(torch.cuda.mem_get_info()[0]) * 0.9),
840
+ # we use negative value to reuse allreduce min operation
841
+ # for getting the max value of zmq_addr_counter in all ranks
842
+ -self._zmq_addr_counter,
843
+ ],
703
844
  dtype=torch.int64,
704
845
  device="cuda",
705
846
  )
706
- dist.all_reduce(free_bytes_tensor, op=dist.ReduceOp.MIN)
707
- free_bytes = free_bytes_tensor.item()
847
+ dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
848
+ tensor = tensor.cpu()
849
+ free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
708
850
  max_tensor_bytes = 0
709
- max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB_bytes
710
851
  for items in self._current_global_parameter_metas.values():
711
852
  for metas_list in items.memory_buffer_metas_list:
712
853
  for meta in metas_list.metas:
@@ -729,18 +870,27 @@ class ParameterServer:
729
870
  f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
730
871
  )
731
872
  disable_h2d_buffer = True
873
+ max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
732
874
  bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
733
- logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB_bytes:.2f} GiB")
875
+ logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
734
876
  return bucket_size, disable_h2d_buffer
735
877
 
736
- def _copy_to_buffer(self, checkpoint_name: str, bucket: H2DBucket, buffer: torch.Tensor, owner_rank: int | None = None):
878
+ def _copy_to_buffer(
879
+ self,
880
+ checkpoint_name: str,
881
+ bucket: H2DBucket,
882
+ buffer: torch.Tensor,
883
+ owner_rank: int | None = None,
884
+ ):
737
885
  offset = 0
738
886
  if owner_rank is not None:
739
887
  buf_ptrs, remote_ptrs, lens = [], [], []
740
888
  ptr_base = buffer.data_ptr()
741
889
  target_addr, ptrs = self._get_addr_ptrs(owner_rank)
742
890
  for b in bucket.ranges:
743
- assert offset + b.size <= bucket.size, f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
891
+ assert offset + b.size <= bucket.size, (
892
+ f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
893
+ )
744
894
  if owner_rank is not None:
745
895
  buf_ptrs.append(ptr_base + offset)
746
896
  remote_ptrs.append(ptrs[b.idx][0] + b.offset)
@@ -758,7 +908,11 @@ class ParameterServer:
758
908
  torch.cuda.synchronize()
759
909
 
760
910
  def init_process_group_for_ranks(
761
- self, ranks: list[int], *, master_port: int | None = None, timeout: timedelta = timedelta(minutes=10)
911
+ self,
912
+ ranks: list[int],
913
+ *,
914
+ master_port: int | None = None,
915
+ timeout: timedelta = timedelta(minutes=10),
762
916
  ):
763
917
  """
764
918
  Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
@@ -787,8 +941,12 @@ class ParameterServer:
787
941
  # and will not participate in this update. Since they have registered memory addresses
788
942
  # to p2p_store at the beginning, update ranks can directly get the memory addresses
789
943
  # from other nodes and put the weights into the buffer.
790
- store = dist.TCPStore(master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout)
791
- dist.init_process_group(backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store)
944
+ store = dist.TCPStore(
945
+ master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
946
+ )
947
+ dist.init_process_group(
948
+ backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
949
+ )
792
950
 
793
951
  def _update_per_bucket_p2p(
794
952
  self,
@@ -800,7 +958,9 @@ class ParameterServer:
800
958
  assert ranks, "ranks should be set"
801
959
  if len(self._current_global_parameter_metas) == 0:
802
960
  raise ValueError("parameter metas is empty")
803
- assert dist.is_initialized(), "process group is not initialized when update model per bucket p2p"
961
+ assert dist.is_initialized(), (
962
+ "process group is not initialized when update model per bucket p2p"
963
+ )
804
964
 
805
965
  need_update = self._rank in ranks
806
966
  logger.info(
@@ -814,26 +974,24 @@ class ParameterServer:
814
974
  # first execute a barrier to avoid subsequent cuda oom
815
975
  dist.barrier()
816
976
 
817
- bucket_size, _ = self._get_bucket_size(disable_h2d_buffer=True)
977
+ bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
818
978
  buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
819
- IPC_BUFFER_NAME = "__ipc_buffer___"
820
- self._p2p_store.register_named_tensors({IPC_BUFFER_NAME: buffer})
979
+ ipc_buffer_name = "__ipc_buffer___"
980
+ self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
821
981
  logger.info(
822
982
  f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
823
983
  )
824
984
  handle = reduce_tensor(buffer)
825
985
 
826
- gidx = 0
827
986
  buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
987
+ socket, socket_paths = self._bind_zmq_socket()
828
988
  req_thread = threading.Thread(
829
989
  target=req_func,
830
- args=(self._global_socket_paths,),
990
+ args=(socket_paths,),
831
991
  )
832
992
  req_thread.start()
833
- socket = self._zmq_ctx.socket(zmq.REQ)
834
- socket.bind(self._zmq_socket_path)
835
993
  socket.send_pyobj(handle)
836
- for owner_rank, bucket in buckets:
994
+ for gidx, (owner_rank, bucket) in enumerate(buckets):
837
995
  self._logger_rank0(
838
996
  f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
839
997
  )
@@ -845,7 +1003,6 @@ class ParameterServer:
845
1003
  socket.recv()
846
1004
  dist.barrier()
847
1005
  socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
848
- gidx += 1
849
1006
 
850
1007
  socket.recv()
851
1008
  socket.send_pyobj(None)
@@ -853,7 +1010,7 @@ class ParameterServer:
853
1010
  req_thread.join()
854
1011
  dist.barrier()
855
1012
  socket.close()
856
- self._p2p_store.unregister_named_tensors([IPC_BUFFER_NAME])
1013
+ self._p2p_store.unregister_named_tensors([ipc_buffer_name])
857
1014
  torch.cuda.empty_cache()
858
1015
 
859
1016
  def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
@@ -877,7 +1034,9 @@ class ParameterServer:
877
1034
  pool = self._memory_pool[checkpoint_name]
878
1035
  if len(pool) == 0:
879
1036
  return 0
880
- return self._p2p_store.unregister_named_tensors([f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)])
1037
+ return self._p2p_store.unregister_named_tensors(
1038
+ [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
1039
+ )
881
1040
 
882
1041
  def _update_per_bucket(
883
1042
  self,
@@ -891,11 +1050,13 @@ class ParameterServer:
891
1050
 
892
1051
  logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
893
1052
 
894
- bucket_size, disable_h2d_buffer = self._get_bucket_size()
1053
+ bucket_size, disable_h2d_buffer = self._detect_bucket_size()
895
1054
  buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
896
1055
 
897
1056
  h2d_buffer: torch.Tensor | None = (
898
- None if disable_h2d_buffer else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
1057
+ None
1058
+ if disable_h2d_buffer
1059
+ else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
899
1060
  )
900
1061
 
901
1062
  owner_rank_buckets: list[H2DBucket] = []
@@ -914,13 +1075,12 @@ class ParameterServer:
914
1075
  if len(buckets_by_owner_rank[owner_rank]) > max_len:
915
1076
  max_len = len(buckets_by_owner_rank[owner_rank])
916
1077
 
1078
+ socket, socket_paths = self._bind_zmq_socket()
917
1079
  req_thread = threading.Thread(
918
1080
  target=req_func,
919
- args=(self._global_socket_paths,),
1081
+ args=(socket_paths,),
920
1082
  )
921
1083
  req_thread.start()
922
- socket = self._zmq_ctx.socket(zmq.REQ)
923
- socket.bind(self._zmq_socket_path)
924
1084
  socket.send_pyobj(handle)
925
1085
 
926
1086
  gidx = 0
@@ -932,7 +1092,10 @@ class ParameterServer:
932
1092
  if i >= len(_buckets):
933
1093
  continue
934
1094
  bucket = _buckets[i]
935
- alloc, reserved = torch.cuda.memory_allocated() / 1024 / 1024, torch.cuda.memory_reserved() / 1024 / 1024
1095
+ alloc, reserved = (
1096
+ torch.cuda.memory_allocated() / 1024 / 1024,
1097
+ torch.cuda.memory_reserved() / 1024 / 1024,
1098
+ )
936
1099
  self._logger_rank0(
937
1100
  f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
938
1101
  f"Current CUDA allocated {alloc:.2f} MB, "
@@ -960,7 +1123,7 @@ class ParameterServer:
960
1123
  torch.cuda.empty_cache()
961
1124
 
962
1125
 
963
- def _init_api(ps: ParameterServer):
1126
+ def _init_api(ps: ParameterServer) -> Any:
964
1127
  import fastapi
965
1128
  from fastapi import Request
966
1129
  from fastapi.responses import JSONResponse, Response
@@ -976,32 +1139,32 @@ def _init_api(ps: ParameterServer):
976
1139
  inference_group_ranks: list[int] = []
977
1140
  timeout: float = 300.0
978
1141
 
979
- def wrap_exception(func):
1142
+ def wrap_exception(func: Callable[[], None]) -> Response:
980
1143
  try:
981
1144
  func()
982
- except Exception as e:
1145
+ except Exception as e: # noqa: BLE001
983
1146
  logger.exception(f"wrap exception {func} failed")
984
1147
  return JSONResponse(content=str(e), status_code=500)
985
1148
  return Response(status_code=200)
986
1149
 
987
1150
  @app.post("/v1/checkpoints/{checkpoint_name}/files")
988
- async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request):
1151
+ async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
989
1152
  return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
990
1153
 
991
1154
  @app.delete("/v1/checkpoints/{checkpoint_name}")
992
- async def unregister_checkpoint(checkpoint_name: str):
1155
+ async def unregister_checkpoint(checkpoint_name: str) -> Response:
993
1156
  return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
994
1157
 
995
1158
  @app.get("/v1/healthz")
996
- async def healthz():
1159
+ async def healthz() -> Response:
997
1160
  return Response(status_code=200)
998
1161
 
999
1162
  @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
1000
- async def gather_metas(checkpoint_name: str):
1163
+ async def gather_metas(checkpoint_name: str) -> Response:
1001
1164
  return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
1002
1165
 
1003
1166
  @app.post("/v1/checkpoints/{checkpoint_name}/update")
1004
- async def update(checkpoint_name: str, req: UpdateRequest):
1167
+ async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
1005
1168
  def update_func(socket_paths: list[tuple[str, str]]):
1006
1169
  if req.update_url is None:
1007
1170
  return
@@ -1018,11 +1181,13 @@ def _init_api(ps: ParameterServer):
1018
1181
  def run_from_cli():
1019
1182
  import uvicorn
1020
1183
 
1021
- parser = argparse.ArgumentParser(description="Paramter Server")
1184
+ parser = argparse.ArgumentParser(description="Parameter Server")
1022
1185
  parser.add_argument("--uds", type=str)
1023
1186
 
1024
1187
  args = parser.parse_args()
1025
- logger.info(f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}")
1188
+ logger.info(
1189
+ f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
1190
+ )
1026
1191
 
1027
1192
  assert args.uds and len(args.uds) > 0, args.uds
1028
1193
  ps = ParameterServer(auto_pg=True)
@@ -1,11 +1,12 @@
1
1
  import gc
2
- from typing import Callable, Optional, TypedDict
2
+ from collections.abc import Callable
3
+ from typing import TypedDict
3
4
 
4
5
  import torch
5
6
  import zmq
6
7
 
7
8
 
8
- def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: Optional[int] = None) -> torch.Tensor:
9
+ def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
9
10
  func, args = handle
10
11
  list_args = list(args)
11
12
  if device_id is not None:
@@ -24,12 +25,14 @@ class FlattenedTensorMetadata(TypedDict):
24
25
  offset: int
25
26
 
26
27
 
27
- def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> list[tuple[str, torch.Tensor]]:
28
+ def _extract_weights(
29
+ payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
30
+ ) -> list[tuple[str, torch.Tensor]]:
28
31
  assert buffer is not None
29
32
  weights: list[tuple[str, torch.Tensor]] = []
30
33
  for item in payload:
31
34
  shape = item["shape"]
32
- if isinstance(shape, (list, tuple)):
35
+ if isinstance(shape, list | tuple):
33
36
  shape = torch.Size(shape)
34
37
  assert isinstance(shape, torch.Size)
35
38
  dtype, offset = item["dtype"], item["offset"]
@@ -45,11 +48,11 @@ def update_weights_from_ipc(
45
48
  device_id: int,
46
49
  *,
47
50
  run: Callable[[list[tuple[str, torch.Tensor]]], None],
48
- post_hook: Callable[[], None] = None,
51
+ post_hook: Callable[[], None] | None = None,
49
52
  ):
50
53
  socket = zmq_ctx.socket(zmq.REP)
51
54
  socket.connect(zmq_handle)
52
- buffer: Optional[torch.Tensor] = None
55
+ buffer: torch.Tensor | None = None
53
56
  while True:
54
57
  payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
55
58
  if payload is None:
@@ -100,5 +103,7 @@ class VllmColocateWorkerExtension:
100
103
  zmq_handles[device_uuid],
101
104
  device_id=self.device.index,
102
105
  run=self.model_runner.model.load_weights,
103
- post_hook=lambda: process_weights_after_loading(self.model_runner.model, self.model_config, self.device),
106
+ post_hook=lambda: process_weights_after_loading(
107
+ self.model_runner.model, self.model_config, self.device
108
+ ),
104
109
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.1
3
+ Version: 0.1.2
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -11,14 +11,15 @@ Description-Content-Type: text/markdown
11
11
  License-File: LICENCE
12
12
  Requires-Dist: torch>=2.5.0
13
13
  Requires-Dist: fastapi
14
- Requires-Dist: pydantic
14
+ Requires-Dist: pydantic>=2.0.0
15
15
  Requires-Dist: safetensors
16
16
  Requires-Dist: pyzmq
17
17
  Requires-Dist: uvicorn
18
18
  Requires-Dist: loguru
19
19
  Requires-Dist: numpy
20
+ Requires-Dist: httpx
20
21
  Provides-Extra: p2p
21
- Requires-Dist: mooncake-transfer-engine; extra == "p2p"
22
+ Requires-Dist: mooncake-transfer-engine>=0.3.5; extra == "p2p"
22
23
  Dynamic: license-file
23
24
 
24
25
  # Checkpoint Engine
@@ -41,7 +42,7 @@ The core weight update logic is in `ParameterServer` class, a service colocated
41
42
  - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
42
43
 
43
44
  ### Optimized Weight Broadcast
44
- In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
45
+ In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and needs to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
45
46
  We arrange the data transfer into 3 stages:
46
47
  1. H2D: moving weights to GPU memory. These weights may come from disk or the training engine.
47
48
  2. broadcast: broadcast among checkpoint engine workers; the data results in a CUDA IPC buffer shared with inference engine.
@@ -73,9 +74,9 @@ Pipelining naturally requires more GPU memory. When memory is not enough, checkp
73
74
  All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
74
75
 
75
76
  * FP8 test needs additional vLLM patches, see [FP8 quantization](#fp8-quantization).
76
- * Device Info: we tested various combination of devices and paralleism setups. For exmaple, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
77
+ * Device Info: we tested various combinations of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
77
78
  * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
78
- * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
79
+ * The P2P times were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
79
80
 
80
81
  ## Installation
81
82
 
@@ -88,7 +89,7 @@ pip install checkpoint-engine
88
89
  To use the flexible P2P implementation, install the `p2p` extra. Note that this will also install `mooncake-transfer-engine` to support RDMA transfer between different ranks.
89
90
 
90
91
  ```Bash
91
- pip install checkpoint-engine[p2p]
92
+ pip install 'checkpoint-engine[p2p]'
92
93
  ```
93
94
 
94
95
  If the `NCCL_IB_HCA` environment variable is set, checkpoint-engine will use it to automatically select network devices for different ranks. If it is not set, checkpoint-engine will read all RDMA devices and try to divide them among the ranks.
@@ -107,7 +108,7 @@ VLLM_USE_PRECOMPILED=1 uv pip install --editable .
107
108
  Install checkpoint-engine
108
109
 
109
110
  ```Bash
110
- uv pip install checkpoint-engine[p2p]
111
+ uv pip install 'checkpoint-engine[p2p]'
111
112
  ```
112
113
 
113
114
  We use `Qwen/Qwen3-235B-A22B-Instruct-2507` (BF16) as the test model
@@ -133,7 +134,7 @@ torchrun --nproc-per-node 8 examples/update.py --update-method all --checkpoint-
133
134
 
134
135
  ### Reuse weights from existing instances
135
136
 
136
- New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
137
+ New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
137
138
 
138
139
  First, start the existing instances with `--save-metas-file global_metas.pkl` to save global metas to a file and use `--sleep-time 300` to make sure they stay alive.
139
140
 
@@ -150,7 +151,7 @@ torchrun --nproc-per-node 8 examples/update.py --load-metas-file global_metas.pk
150
151
 
151
152
  ### FP8 quantization
152
153
 
153
- FP8 quantization currently do not natively work in vLLM when updating weights.
154
+ FP8 quantization currently does not natively work in vLLM when updating weights.
154
155
  We provide a simple patch in [`patches/vllm_fp8.patch`](./patches/vllm_fp8.patch) to handle the correct weight update.
155
156
  Note that this patch has only been tested with DeepSeek-V3.1 and Kimi-K2. Other models may run into compatibility issues.
156
157
 
@@ -0,0 +1,9 @@
1
+ checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
+ checkpoint_engine/_version.py,sha256=Ok5oAXdWgR9aghaFXTafTeDW6sYO3uVe6d2Nket57R4,704
3
+ checkpoint_engine/ps.py,sha256=ckM2vdLg3aeOKmM_vTcbIPKcT-r-E4s73yPaCESKdwg,48439
4
+ checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
5
+ checkpoint_engine-0.1.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
6
+ checkpoint_engine-0.1.2.dist-info/METADATA,sha256=9FUb4s1KMSzrDOOV-C18q3gqcVWD_qZ4t_DaLuai_M4,9322
7
+ checkpoint_engine-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
+ checkpoint_engine-0.1.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
9
+ checkpoint_engine-0.1.2.dist-info/RECORD,,
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
18
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
19
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
20
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE
21
+ SOFTWARE
@@ -1,9 +0,0 @@
1
- checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
- checkpoint_engine/_version.py,sha256=yfx_VE-4lpqM4jnWOSq-8rihMWIwMaX9CQ7tNEpA4T0,712
3
- checkpoint_engine/ps.py,sha256=9u2rLOj-oQrXsnpYhdXFjv7ak2-f4BRUXB6KYlG3ah0,44422
4
- checkpoint_engine/worker.py,sha256=OrSeknjtECnO88I-YMdfkZj70TIRhjvTEeZkNyZk21M,3695
5
- checkpoint_engine-0.1.1.dist-info/licenses/LICENCE,sha256=0jqA0jrA_i9VUqd7FTVoI1KnN1ZRENwG_tlMRCjC63k,1066
6
- checkpoint_engine-0.1.1.dist-info/METADATA,sha256=WDc5tg3RQiCthzbEzI4D3t3I0AQhGUPcDpJn7TMhYbI,9286
7
- checkpoint_engine-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
- checkpoint_engine-0.1.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
9
- checkpoint_engine-0.1.1.dist-info/RECORD,,