checkpoint-engine 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checkpoint_engine/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.1.1'
32
- __version_tuple__ = version_tuple = (0, 1, 1)
31
+ __version__ = version = '0.1.3'
32
+ __version_tuple__ = version_tuple = (0, 1, 3)
33
33
 
34
- __commit_id__ = commit_id = 'gf29b2e1c3'
34
+ __commit_id__ = commit_id = None
checkpoint_engine/ps.py CHANGED
@@ -1,30 +1,28 @@
1
- from __future__ import annotations
2
-
3
1
  import argparse
4
2
  import concurrent.futures
3
+ import ctypes
5
4
  import os
6
5
  import pickle
7
6
  import random
8
7
  import socket
9
- import subprocess
10
8
  import threading
11
9
  import time
12
- import uuid
13
10
  from collections import defaultdict
11
+ from collections.abc import Callable
14
12
  from datetime import timedelta
15
- from functools import cached_property, lru_cache
16
- from typing import Callable, NamedTuple
13
+ from functools import lru_cache
14
+ from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
17
15
 
16
+ import httpx
18
17
  import numpy as np
19
- import requests
20
18
  import torch
21
19
  import torch.distributed as dist
22
20
  import zmq
23
21
  from loguru import logger
24
- from pydantic import BaseModel, ConfigDict
22
+ from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
25
23
  from safetensors.torch import safe_open
26
24
  from torch.multiprocessing.reductions import reduce_tensor
27
- from typing_extensions import TYPE_CHECKING
25
+
28
26
 
29
27
  if TYPE_CHECKING:
30
28
  from typing_extensions import TypedDict
@@ -37,16 +35,59 @@ if TYPE_CHECKING:
37
35
  tp_concat_dim: int
38
36
 
39
37
 
40
- class ParameterMeta(BaseModel):
41
- # now all classes are changed to pydantic BaseModel
42
- # it will directly report validation errors for unknown types
43
- # like torch.dtype, torch.Size, so we need this configuration
44
- # see https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.validate_assignment
45
- model_config = ConfigDict(arbitrary_types_allowed=True)
38
+ def _dt_validate(value: Any) -> torch.dtype:
39
+ if isinstance(value, str):
40
+ if not value.startswith("torch."):
41
+ raise ValueError(f"dtype {value} should start with torch.")
42
+ try:
43
+ value = getattr(torch, value.split(".")[1])
44
+ except AttributeError as e:
45
+ raise ValueError(f"unknown dtype: {value}") from e
46
+ if not isinstance(value, torch.dtype):
47
+ raise TypeError(f"dtype {value} should be torch.dtype, got {type(value)}")
48
+ return value
49
+
50
+
51
+ _TorchDtype = Annotated[
52
+ torch.dtype,
53
+ PlainValidator(_dt_validate),
54
+ PlainSerializer(lambda x: str(x), return_type=str),
55
+ WithJsonSchema({"type": "string"}, mode="serialization"),
56
+ ]
57
+
58
+
59
+ def _size_validate(value: Any) -> torch.Size:
60
+ if isinstance(value, list | tuple):
61
+ return torch.Size(value)
62
+ if not isinstance(value, torch.Size):
63
+ raise TypeError(f"size {value} should be torch.Size, got {type(value)}")
64
+ return value
65
+
66
+
67
+ _TorchSize = Annotated[
68
+ torch.Size,
69
+ PlainValidator(_size_validate),
70
+ PlainSerializer(lambda x: tuple(x), return_type=tuple),
71
+ WithJsonSchema({"type": "array", "items": {"type": "integer"}}, mode="serialization"),
72
+ ]
73
+
46
74
 
75
+ def _tensor_validate(value: Any) -> torch.Tensor:
76
+ if isinstance(value, torch.Tensor):
77
+ return value
78
+ raise TypeError(f"tensor {value} should be torch.Tensor, got {type(value)}")
79
+
80
+
81
+ _TorchTensor = Annotated[
82
+ torch.Tensor,
83
+ PlainValidator(_tensor_validate),
84
+ ]
85
+
86
+
87
+ class ParameterMeta(BaseModel):
47
88
  name: str
48
- dtype: torch.dtype
49
- shape: torch.Size
89
+ dtype: _TorchDtype
90
+ shape: _TorchSize
50
91
 
51
92
 
52
93
  class BucketRange(NamedTuple):
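
The hunk above replaces the `arbitrary_types_allowed` escape hatch with `Annotated` torch types that carry explicit validators and serializers. A minimal usage sketch of the resulting `ParameterMeta`, assuming pydantic>=2 is installed (the field values below are made up):

```python
# Hypothetical usage of the new ParameterMeta model; values are illustrative only.
import torch

from checkpoint_engine.ps import ParameterMeta

meta = ParameterMeta(name="w", dtype="torch.bfloat16", shape=[4096, 128])
assert meta.dtype == torch.bfloat16          # validated back into a torch.dtype
assert meta.shape == torch.Size([4096, 128])  # validated back into a torch.Size
# serialization goes back to JSON-friendly types
print(meta.model_dump())  # {'name': 'w', 'dtype': 'torch.bfloat16', 'shape': (4096, 128)}
```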
@@ -68,9 +109,7 @@ class MemoryBufferMetas(BaseModel):
68
109
 
69
110
 
70
111
  class MemoryBuffer(BaseModel):
71
- model_config = ConfigDict(arbitrary_types_allowed=True)
72
-
73
- buffer: torch.Tensor
112
+ buffer: _TorchTensor
74
113
  size: int
75
114
  metas: list[ParameterMeta]
76
115
 
@@ -82,7 +121,7 @@ class MemoryBufferMetaList(BaseModel):
82
121
 
83
122
  class DataToGather(MemoryBufferMetaList):
84
123
  host_ip: str
85
- zmq_socket_path: tuple[str, str]
124
+ device_uuid: str
86
125
 
87
126
 
88
127
  # 256 bytes alignment when flatten torch tensors to uint8 buffer
@@ -93,7 +132,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
93
132
  return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
94
133
 
95
134
 
96
- def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
135
+ def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
97
136
  ret = []
98
137
  for meta in metas:
99
138
  size = _align_size(meta.dtype, meta.shape)
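
For context, `_align_size` (shown above) rounds every flattened tensor up to a multiple of 256 bytes; a small worked example of that arithmetic:

```python
# Worked example of the 256-byte alignment performed by _align_size above.
# A bfloat16 tensor of shape (1000,) occupies 2 * 1000 = 2000 bytes,
# which rounds up to the next multiple of 256, i.e. 2048.
_ALIGN_SIZE = 256
nbytes = 2 * 1000
assert (nbytes + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE == 2048
```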
@@ -109,11 +148,11 @@ def _to_named_tensor(metas: list[ParameterMeta], offset=0) -> list[dict]:
109
148
  return ret
110
149
 
111
150
 
112
- def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta, torch.Tensor]]]:
113
- def _safetensors_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
151
+ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
152
+ def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
114
153
  ret = {}
115
154
  with safe_open(fn, framework="pt") as f:
116
- for name in f.keys():
155
+ for name in f.keys(): # noqa: SIM118
117
156
  weight = f.get_tensor(name)
118
157
  meta = {
119
158
  "key": name,
@@ -126,10 +165,10 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
126
165
  return ret
127
166
 
128
167
  # deprecated, will be removed in the future
129
- def _fast_np_load(fn) -> dict[str, tuple[FileMeta, torch.Tensor]]:
168
+ def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
130
169
  """load *.np file and return memmap and related tensor meta"""
131
170
 
132
- def parse_npy_header(fin):
171
+ def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
133
172
  start = fin.tell()
134
173
  major, minor = np.lib.format.read_magic(fin)
135
174
  if major == 1 and minor == 0:
@@ -137,7 +176,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
137
176
  elif major == 2 and minor == 0:
138
177
  read_header_fn = np.lib.format.read_array_header_2_0
139
178
  else:
140
- raise ValueError(f"unknown version {major}.{minor} when parsing npy header from {fn}")
179
+ raise ValueError(
180
+ f"unknown version {major}.{minor} when parsing npy header from {fn}"
181
+ )
141
182
  shape, is_fortran, dtype = read_header_fn(fin)
142
183
  return {
143
184
  "shape": shape,
@@ -193,7 +234,9 @@ def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple[FileMeta
193
234
  return tp_rank, ret
194
235
 
195
236
 
196
- def _concat_tp_weights(tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int) -> torch.Tensor:
237
+ def _concat_tp_weights(
238
+ tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
239
+ ) -> torch.Tensor:
197
240
  """Concat tp weights with meta info.
198
241
  If meta.concat_dim is -1, meas this is shared tp weights, just use the first weights.
199
242
  Else we will cat weights in concat_dim.
@@ -206,39 +249,54 @@ def _concat_tp_weights(tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_si
206
249
  return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
207
250
 
208
251
 
209
- def _get_physical_gpu_id(rank: int) -> str:
210
- result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True)
211
- if result.returncode != 0:
212
- raise ValueError(result.stdout)
213
- lines = result.stdout.strip().split("\n")
214
- for line in lines:
215
- if f"GPU {rank}" in line:
216
- uuid = line.split("UUID: ")[1].strip(")")
217
- return uuid
218
- raise ValueError(f"not found gpu{rank} uuid")
252
+ def _get_physical_gpu_id(device_index: int | None = None) -> str:
253
+ try:
254
+ return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
255
+ except AssertionError as e:
256
+ raise ValueError(f"fail to get physical gpu id {device_index}") from e
219
257
 
220
258
 
221
259
  @lru_cache(maxsize=1)
222
- def _get_ip():
260
+ def _get_ip() -> str:
223
261
  try:
224
262
  # try to get ip from network interface
225
263
  with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
226
264
  s.connect(("8.8.8.8", 80))
227
265
  return s.getsockname()[0]
228
- except:
266
+ except Exception as e: # noqa: BLE001
229
267
  # fallback to get ip from hostname
230
- logger.warning("fail to get ip from network interface, fallback to get ip from hostname")
268
+ logger.warning(
269
+ f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
270
+ )
231
271
  return socket.gethostbyname(socket.gethostname())
232
272
 
233
273
 
274
+ def _ibv_get_device_list() -> list[str]:
275
+ lib = ctypes.CDLL("libibverbs.so.1")
276
+ lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
277
+ lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **
278
+
279
+ lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
280
+ lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
281
+ lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *
282
+
283
+ num = ctypes.c_int()
284
+ dev_array = lib.ibv_get_device_list(ctypes.byref(num))
285
+ if not dev_array or num.value <= 0:
286
+ return []
287
+
288
+ devices = []
289
+ for i in range(num.value):
290
+ dev_ptr = dev_array[i] # struct ibv_device *
291
+ name = lib.ibv_get_device_name(dev_ptr) # const char *
292
+ devices.append(name.decode())
293
+ lib.ibv_free_device_list(dev_array)
294
+ return devices
295
+
296
+
234
297
  def _get_rdma_devices() -> list[str]:
235
298
  """
236
- use script like below to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
237
- ```bash
238
- pushd /sys/class/infiniband/ > /dev/null;
239
- for i in mlx5_*; do cat "$i"/ports/1/gid_attrs/types/* 2>/dev/null | grep v >/dev/null && echo "$i" ; done;
240
- popd > /dev/null;
241
- ```
299
+ use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
242
300
  """
243
301
  devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
244
302
  if devices_str:
@@ -246,41 +304,27 @@ def _get_rdma_devices() -> list[str]:
246
304
  # if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
247
305
  hca = os.getenv("NCCL_IB_HCA", None)
248
306
  if hca:
249
- l = hca.split(",")
250
- if len(l) > 1:
307
+ hca_list = hca.split(",")
308
+ if len(hca_list) > 1:
251
309
  # if NCCL_IB_HCA has multiple values, just return
252
- return l
310
+ return hca_list
253
311
  else:
254
- hca = l[0]
255
- basepath = "/sys/class/infiniband/"
256
- port_path = "ports/1/gid_attrs/types"
257
- devices = []
258
- for device in sorted(os.listdir(basepath)):
259
- if hca is not None and hca not in device:
260
- continue
261
- path = os.path.join(basepath, device, port_path)
262
- if not os.path.exists(path) or not os.path.isdir(path):
263
- continue
264
- for port in os.listdir(path):
265
- try:
266
- content = open(os.path.join(path, port)).read()
267
- if "v" in content:
268
- print(f"found rdma device {device} in port {port}: {content.strip()}")
269
- devices.append(device)
270
- break
271
- except Exception:
272
- pass
273
- return devices
312
+ hca = hca_list[0]
313
+ return [device for device in sorted(_ibv_get_device_list()) if hca is None or hca in device]
274
314
 
275
315
 
276
- def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]):
316
+ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
277
317
  """
278
318
  implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
279
319
  """
280
320
  if not devices:
281
321
  raise RuntimeError("no rdma devices found")
282
- assert len(devices) <= gpu_count, f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
283
- assert gpu_count % len(devices) == 0, f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
322
+ assert len(devices) <= gpu_count, (
323
+ f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
324
+ )
325
+ assert gpu_count % len(devices) == 0, (
326
+ f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
327
+ )
284
328
  return devices[local_rank // (gpu_count // len(devices))]
285
329
 
286
330
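
To make the allocation rule in the `_get_my_rdma_device` docstring concrete, a tiny self-contained check of the formula above (device names are just examples):

```python
# Illustration of the allocation rule: with two HCAs on an 8-GPU node,
# local ranks 0-3 share mlx5_0 and ranks 4-7 share mlx5_1.
devices = ["mlx5_0", "mlx5_1"]
gpu_count = 8
mapping = {r: devices[r // (gpu_count // len(devices))] for r in range(gpu_count)}
assert mapping[0] == mapping[3] == "mlx5_0"
assert mapping[4] == mapping[7] == "mlx5_1"
```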
 
@@ -305,8 +349,12 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
305
349
  size=1,
306
350
  )
307
351
  if parameter_name not in parameter_metas:
308
- assert isinstance(meta["dtype"], torch.dtype), f"meta {meta} dtype should be torch.dtype"
309
- assert isinstance(meta["shape"], torch.Size), f"meta {meta} shape should be torch.Size"
352
+ assert isinstance(meta["dtype"], torch.dtype), (
353
+ f"meta {meta} dtype should be torch.dtype"
354
+ )
355
+ assert isinstance(meta["shape"], torch.Size), (
356
+ f"meta {meta} shape should be torch.Size"
357
+ )
310
358
  parameter_metas[parameter_name] = ParameterMeta(
311
359
  name=parameter_name,
312
360
  shape=meta["shape"],
@@ -319,7 +367,9 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
319
367
  if tp_meta.concat_dim != -1:
320
368
  shape = list(parameter_metas[name].shape)
321
369
  shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
322
- parameter_metas[name] = ParameterMeta(name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype)
370
+ parameter_metas[name] = ParameterMeta(
371
+ name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
372
+ )
323
373
  weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
324
374
  # TODO: here concat is serial, which may be slow
325
375
  # but since tp storage is not used in the future
@@ -338,17 +388,19 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
338
388
 
339
389
  def _register_checkpoint(
340
390
  *,
341
- files: list[str] = [],
342
- named_tensors: dict[str, torch.Tensor] = {},
391
+ files: list[str],
392
+ named_tensors: dict[str, torch.Tensor],
343
393
  rank: int | None = None,
344
394
  ) -> list[MemoryBuffer]:
345
- logger.info(f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors")
395
+ logger.info(
396
+ f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
397
+ )
346
398
  if not files and not named_tensors:
347
399
  return []
348
400
  parameters = _load_checkpoint(files)
349
401
  if named_tensors:
350
402
  parameters.update(named_tensors)
351
- bucket_size = max(4 << 30, max(map(lambda x: _align_size(x.dtype, x.shape), parameters.values())))
403
+ bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
352
404
 
353
405
  class MemoryBucket(BaseModel):
354
406
  size: int
@@ -363,7 +415,10 @@ def _register_checkpoint(
363
415
  buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
364
416
  buckets[-1].size += size
365
417
 
366
- memory_buffers = [MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas) for bucket in buckets]
418
+ memory_buffers = [
419
+ MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
420
+ for bucket in buckets
421
+ ]
367
422
 
368
423
  def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
369
424
  buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
@@ -373,7 +428,10 @@ def _register_checkpoint(
373
428
  buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
374
429
 
375
430
  with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
376
- futures = [executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(buckets)]
431
+ futures = [
432
+ executor.submit(register_pin_memory, idx, bucket.size)
433
+ for idx, bucket in enumerate(buckets)
434
+ ]
377
435
  new_futures = []
378
436
  for future in concurrent.futures.as_completed(futures):
379
437
  idx, buffer = future.result()
@@ -400,8 +458,26 @@ def _register_checkpoint(
400
458
  return memory_buffers
401
459
 
402
460
 
403
- def request_inference_to_update(url: str, socket_paths: dict[str, str], timeout: float = 300.0):
404
- resp = requests.post(
461
+ def request_inference_to_update(
462
+ url: str,
463
+ socket_paths: dict[str, str],
464
+ timeout: float = 300.0,
465
+ uds: str | None = None,
466
+ ):
467
+ """Send an inference update request to inference server via HTTP or Unix socket.
468
+
469
+ Args:
470
+ url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
471
+ socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
472
+ timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
473
+ uds (str, optional): Path to a Unix domain socket. If provided, the request
474
+ will be sent via the Unix socket instead of HTTP. Defaults to None.
475
+
476
+ Raises:
477
+ httpx.HTTPStatusError: If the response contains an HTTP error status.
478
+ httpx.RequestError: If there was an issue while making the request.
479
+ """
480
+ resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
405
481
  url,
406
482
  json={
407
483
  "method": "update_weights_from_ipc",
@@ -413,7 +489,9 @@ def request_inference_to_update(url: str, socket_paths: dict[str, str], timeout:
413
489
  resp.raise_for_status()
414
490
 
415
491
 
416
- def _gen_h2d_buckets(global_metas: dict[int, MemoryBufferMetaList], bucket_size: int) -> list[tuple[int, H2DBucket]]:
492
+ def _gen_h2d_buckets(
493
+ global_metas: dict[int, MemoryBufferMetaList], bucket_size: int
494
+ ) -> list[tuple[int, H2DBucket]]:
417
495
  buckets: list[tuple[int, H2DBucket]] = []
418
496
 
419
497
  for owner_rank, items in global_metas.items():
@@ -424,14 +502,18 @@ def _gen_h2d_buckets(global_metas: dict[int, MemoryBufferMetaList], bucket_size:
424
502
  s = _align_size(meta.dtype, meta.shape)
425
503
  if buckets[-1][1].size + s > bucket_size:
426
504
  if offset - start_offset > 0:
427
- buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
505
+ buckets[-1][1].ranges.append(
506
+ BucketRange(idx, start_offset, offset - start_offset)
507
+ )
428
508
  start_offset = offset
429
509
  buckets.append((owner_rank, H2DBucket(size=0, ranges=[], items=[])))
430
510
  offset += s
431
511
  buckets[-1][1].size += s
432
512
  buckets[-1][1].items.append(meta)
433
513
  buckets[-1][1].ranges.append(BucketRange(idx, start_offset, offset - start_offset))
434
- assert buckets[-1][1].size > 0, f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
514
+ assert buckets[-1][1].size > 0, (
515
+ f"buckets[-1][1].size {buckets[-1][1].size} should be greater than 0"
516
+ )
435
517
  return buckets
436
518
 
437
519
 
@@ -470,7 +552,9 @@ class P2PStore:
470
552
  raise RuntimeError(f"[rank{self.rank}] fail to initialize transfer engine")
471
553
  self.port = self.engine.get_rpc_port()
472
554
  self.named_tensors: dict[str, torch.Tensor] = {}
473
- logger.info(f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}")
555
+ logger.info(
556
+ f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {device}"
557
+ )
474
558
 
475
559
  @property
476
560
  def addr(self) -> str:
@@ -492,37 +576,60 @@ class P2PStore:
492
576
  num_unregistered = 0
493
577
  for i, name in enumerate(names):
494
578
  del self.named_tensors[name]
495
- logger.info(f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}")
579
+ logger.info(
580
+ f"[rank{self.rank}] p2p store unregister tensor {name} with addr {hex(buffer_addresses[i])}"
581
+ )
496
582
  num_unregistered += 1
497
583
  return num_unregistered
498
584
 
499
- def batch_transfer_sync_read(self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]):
500
- assert self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
585
+ def batch_transfer_sync_read(
586
+ self, target_hostname: str, buf_ptrs: list[int], remote_ptrs: list[int], lens: list[int]
587
+ ):
588
+ assert (
589
+ self.engine.batch_transfer_sync_read(target_hostname, buf_ptrs, remote_ptrs, lens) == 0
590
+ )
501
591
 
502
592
 
503
593
  class ParameterServer:
504
- def __init__(self, *, auto_pg: bool = False):
594
+ def __init__(
595
+ self,
596
+ *,
597
+ rank: int | None = None,
598
+ world_size: int | None = None,
599
+ auto_pg: bool = False,
600
+ gpu_count: int | None = None,
601
+ mem_fraction: float | None = None,
602
+ ):
505
603
  """
506
604
  Initialize the parameter server. env RANK, WORLD_SIZE and MASTER_ADDR must be set.
507
605
 
508
606
  Args:
509
607
  auto_pg: Whether to automatically initialize the process group.
510
608
  Notice that if auto_pg is True, will destroy the process group after update.
609
+ mem_fraction: The proportion (as a fraction) of the current free CUDA memory for allocation.
511
610
  """
512
- self._rank = int(os.environ.get("RANK", None))
513
- self._world_size = int(os.environ.get("WORLD_SIZE", None))
514
- self._master_addr = os.getenv("MASTER_ADDR")
515
- self._gpu_count = torch.cuda.device_count()
611
+ self._rank = rank or int(os.environ.get("RANK", None))
612
+ self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
613
+ self._gpu_count = gpu_count or torch.cuda.device_count()
516
614
  self._local_rank = self._rank % self._gpu_count
517
615
  self._auto_pg = auto_pg
518
616
  self._all_hosts = []
519
- self._global_socket_paths: list[tuple[str, str]] = []
617
+ self._global_device_uuids: list[str] = []
618
+ self._mem_fraction = mem_fraction or 0.9
520
619
 
521
620
  assert self._rank is not None and self._rank >= 0, self._rank
522
621
  assert self._world_size and self._world_size > 0, self._world_size
622
+ assert (
623
+ self._gpu_count is not None
624
+ and self._gpu_count > 0
625
+ and self._gpu_count <= torch.cuda.device_count()
626
+ ), self._gpu_count
627
+ assert (
628
+ self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
629
+ ), self._mem_fraction
523
630
 
524
- self._device_uuid = _get_physical_gpu_id(self._local_rank)
525
631
  self._zmq_ctx = zmq.Context()
632
+ self._zmq_addr_counter = 0
526
633
 
527
634
  self._memory_pool: dict[str, list[MemoryBuffer]] = {}
528
635
  # dict key is owner_rank, value is a bucket metas list in owner_rank
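
A short sketch of driving the widened constructor above programmatically instead of purely through `RANK`/`WORLD_SIZE` environment variables; all values are illustrative:

```python
# Hedged sketch of the new keyword-only constructor arguments; requires a CUDA host.
# MASTER_ADDR (env or init_process_group(master_addr=...)) is still needed before
# any collective update.
from checkpoint_engine.ps import ParameterServer

ps = ParameterServer(
    rank=0,
    world_size=8,
    auto_pg=True,      # build/tear down the process group around each update
    gpu_count=8,       # GPUs per node; local rank is derived from rank % gpu_count
    mem_fraction=0.8,  # fraction of currently free CUDA memory used when sizing buckets
)
```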
@@ -533,19 +640,27 @@ class ParameterServer:
533
640
  logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
534
641
  self._p2p_store = None
535
642
 
536
- torch.cuda.set_device(self._local_rank)
643
+ device_index = self._local_rank
644
+ torch.cuda.set_device(device_index)
645
+ self._device_uuid = _get_physical_gpu_id(device_index)
537
646
 
538
- def _logger_rank0(self, msg):
647
+ def _logger_rank0(self, msg: str):
539
648
  if self._local_rank == 0:
540
649
  logger.info(msg)
541
650
 
542
- def get_metas(self):
651
+ def get_metas(self) -> dict[int, MemoryBufferMetaList]:
543
652
  return self._current_global_parameter_metas
544
653
 
545
654
  def load_metas(self, metas: dict[int, MemoryBufferMetaList]):
546
655
  self._current_global_parameter_metas = metas
547
656
 
548
- def register_checkpoint(self, checkpoint_name: str, *, files: list[str] = [], named_tensors: dict[str, torch.Tensor] = {}):
657
+ def register_checkpoint(
658
+ self,
659
+ checkpoint_name: str,
660
+ *,
661
+ files: list[str] | None = None,
662
+ named_tensors: dict[str, torch.Tensor] | None = None,
663
+ ) -> None:
549
664
  """
550
665
  Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
551
666
 
@@ -555,12 +670,18 @@ class ParameterServer:
555
670
  named_tensors: The named tensors to register.
556
671
  """
557
672
  try:
558
- assert checkpoint_name not in self._memory_pool, f"checkpoint {checkpoint_name} already registered"
559
- self._memory_pool[checkpoint_name] = _register_checkpoint(files=files, named_tensors=named_tensors, rank=self._rank)
673
+ assert checkpoint_name not in self._memory_pool, (
674
+ f"checkpoint {checkpoint_name} already registered"
675
+ )
676
+ self._memory_pool[checkpoint_name] = _register_checkpoint(
677
+ files=files or [], named_tensors=named_tensors or {}, rank=self._rank
678
+ )
560
679
  if self._p2p_store is not None:
561
680
  self._register_parameters_to_p2p_store(checkpoint_name)
562
681
  except Exception:
563
- logger.exception(f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}")
682
+ logger.exception(
683
+ f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
684
+ )
564
685
  if self._p2p_store is not None:
565
686
  self._unregister_parameters_from_p2p_store(checkpoint_name)
566
687
  self.unregister_checkpoint(checkpoint_name)
@@ -583,10 +704,6 @@ class ParameterServer:
583
704
  # this works by using torch>=2.5.0
584
705
  torch._C._host_emptyCache()
585
706
 
586
- @cached_property
587
- def _zmq_socket_path(self) -> str:
588
- return f"ipc://@checkpoint-engine-{uuid.uuid4()}.sock"
589
-
590
707
  def gather_metas(self, checkpoint_name: str):
591
708
  """
592
709
  Gather the parameter metas from all ranks. This will gather memory_buffer, and other metadatas.
@@ -598,19 +715,17 @@ class ParameterServer:
598
715
  assert dist.is_initialized(), "process group is not initialized"
599
716
  metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore
600
717
  metas = DataToGather(
601
- memory_buffer_metas_list=list(
602
- map(
603
- lambda x: MemoryBufferMetas(
604
- metas=x.metas,
605
- ptr=x.buffer.data_ptr(),
606
- size=x.size,
607
- ),
608
- self._memory_pool.get(checkpoint_name, []),
609
- ),
610
- ),
718
+ memory_buffer_metas_list=[
719
+ MemoryBufferMetas(
720
+ metas=x.metas,
721
+ ptr=x.buffer.data_ptr(),
722
+ size=x.size,
723
+ )
724
+ for x in self._memory_pool.get(checkpoint_name, [])
725
+ ],
611
726
  p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
612
727
  host_ip=_get_ip(),
613
- zmq_socket_path=(self._device_uuid, self._zmq_socket_path),
728
+ device_uuid=self._device_uuid,
614
729
  )
615
730
 
616
731
  dist.all_gather_object(metas_lst, metas)
@@ -618,23 +733,31 @@ class ParameterServer:
618
733
  self._current_global_parameter_metas = {}
619
734
  num_parameters = 0
620
735
  all_hosts: list[str] = []
621
- global_socket_paths: list[tuple[str, str]] = []
736
+ global_device_uuids: list[str] = []
622
737
  for i, metas_buckets in enumerate(metas_lst):
623
738
  assert metas_buckets is not None, f"metas_buckets {i} should not be None"
624
739
  if i % self._gpu_count == 0 and not self._all_hosts:
625
740
  all_hosts.append(metas_buckets.host_ip)
626
- if not self._global_socket_paths:
627
- global_socket_paths.append(metas_buckets.zmq_socket_path)
741
+ if not self._global_device_uuids:
742
+ global_device_uuids.append(metas_buckets.device_uuid)
628
743
  if metas_buckets.memory_buffer_metas_list:
629
744
  self._current_global_parameter_metas[i] = metas_buckets
630
- num_parameters += sum(map(lambda x: len(x.metas), metas_buckets.memory_buffer_metas_list))
745
+ num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
631
746
  if not self._all_hosts:
632
747
  self._all_hosts = all_hosts
633
- if not self._global_socket_paths:
634
- self._global_socket_paths = global_socket_paths
635
- logger.info(f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}")
748
+ if not self._global_device_uuids:
749
+ self._global_device_uuids = global_device_uuids
750
+ logger.info(
751
+ f"[rank{self._rank}] gather parameter metas finished, num_parameters: {num_parameters}"
752
+ )
636
753
 
637
- def init_process_group(self, *, master_port: int | None = None, timeout: timedelta = timedelta(minutes=10)):
754
+ def init_process_group(
755
+ self,
756
+ *,
757
+ master_addr: str | None = None,
758
+ master_port: int | None = None,
759
+ timeout: timedelta = timedelta(minutes=10),
760
+ ):
638
761
  """
639
762
  Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
640
763
 
@@ -642,10 +765,22 @@ class ParameterServer:
642
765
  master_port: The specified port of the master node. If not set, will use _get_master_port to get the port.
643
766
  timeout: The timeout of the process group.
644
767
  """
768
+ master_addr = master_addr or os.getenv("MASTER_ADDR")
769
+ assert master_addr, "master_addr is required"
645
770
  store = dist.TCPStore(
646
- self._master_addr, _get_master_port(master_port), self._world_size, timeout=timeout, is_master=self._rank == 0
771
+ master_addr,
772
+ _get_master_port(master_port),
773
+ self._world_size,
774
+ timeout=timeout,
775
+ is_master=self._rank == 0,
776
+ )
777
+ dist.init_process_group(
778
+ backend="nccl",
779
+ world_size=self._world_size,
780
+ rank=self._rank,
781
+ timeout=timeout,
782
+ store=store,
647
783
  )
648
- dist.init_process_group(backend="nccl", world_size=self._world_size, rank=self._rank, timeout=timeout, store=store)
649
784
  logger.info(f"[rank{self._rank}] init process group successfully.")
650
785
 
651
786
  def update(
@@ -653,8 +788,8 @@ class ParameterServer:
653
788
  checkpoint_name: str,
654
789
  req_func: Callable[[list[tuple[str, str]]], None],
655
790
  *,
656
- ranks: list[int] = [],
657
- ):
791
+ ranks: list[int] | None = None,
792
+ ) -> None:
658
793
  """
659
794
  Update the checkpoint to inference engine. This function should be called after gather_metas.
660
795
 
@@ -667,18 +802,21 @@ class ParameterServer:
667
802
  which is useful in disaggregated architecture.
668
803
  """
669
804
  try:
805
+ # if both ranks is None or [], it will use fully broadcast to update to all ranks
670
806
  if not ranks:
671
807
  if self._auto_pg and not dist.is_initialized():
672
808
  self.init_process_group()
673
809
  self._update_per_bucket(checkpoint_name, req_func)
674
810
  else:
675
- if self._rank not in ranks:
811
+ if not self._auto_pg and self._rank not in ranks:
676
812
  return
677
813
  if self._auto_pg:
678
814
  if dist.is_initialized():
679
815
  dist.destroy_process_group()
680
816
  # HACK: wait 2s to ensure destroy is finished
681
817
  time.sleep(2)
818
+ if self._rank not in ranks:
819
+ return
682
820
  self.init_process_group_for_ranks(ranks)
683
821
  self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
684
822
  if self._auto_pg:
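
Putting the pieces together, a hedged sketch of exercising both update paths above from Python; names, paths and rank lists are illustrative, and `ps` is assumed to be an already constructed `ParameterServer`:

```python
# req_func receives the per-device (uuid, zmq handle) pairs and forwards them
# to the inference engine; the URL and file path are placeholders.
def req_func(socket_paths: list[tuple[str, str]]) -> None:
    request_inference_to_update("http://localhost:19730/inference", dict(socket_paths))

ps.register_checkpoint("ckpt-demo", files=["/path/to/model.safetensors"])
ps.gather_metas("ckpt-demo")
ps.update("ckpt-demo", req_func)                         # broadcast to every rank
ps.update("ckpt-demo", req_func, ranks=list(range(16)))  # p2p update of two new nodes
```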
@@ -692,21 +830,39 @@ class ParameterServer:
692
830
  f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
693
831
  )
694
832
  except Exception as e:
695
- logger.exception(f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}")
696
- raise e
833
+ logger.exception(
834
+ f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
835
+ )
836
+ raise
837
+
838
+ def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
839
+ def zmq_handle(device_uuid: str) -> str:
840
+ return f"ipc://@checkpoint-engine-{device_uuid}-{self._zmq_addr_counter}.sock"
697
841
 
698
- def _get_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
699
- GiB_bytes = 1 << 30
842
+ socket_paths = [(uid, zmq_handle(uid)) for uid in self._global_device_uuids]
843
+ socket = self._zmq_ctx.socket(zmq.REQ)
844
+ socket.bind(zmq_handle(self._device_uuid))
845
+ self._zmq_addr_counter += 1
846
+ return socket, socket_paths
847
+
848
+ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, bool]:
849
+ GiB = 1 << 30 # noqa: N806
700
850
  # auto detect bucket size
701
- free_bytes_tensor = torch.tensor(
702
- int(float(torch.cuda.mem_get_info()[0]) * 0.9),
851
+ tensor = torch.tensor(
852
+ [
853
+ # proportion of current cuda free memory bytes
854
+ int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
855
+ # we use negative value to reuse allreduce min operation
856
+ # for getting the max value of zmq_addr_counter in all ranks
857
+ -self._zmq_addr_counter,
858
+ ],
703
859
  dtype=torch.int64,
704
860
  device="cuda",
705
861
  )
706
- dist.all_reduce(free_bytes_tensor, op=dist.ReduceOp.MIN)
707
- free_bytes = free_bytes_tensor.item()
862
+ dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
863
+ tensor = tensor.cpu()
864
+ free_bytes, self._zmq_addr_counter = tensor[0].item(), -tensor[1].item()
708
865
  max_tensor_bytes = 0
709
- max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB_bytes
710
866
  for items in self._current_global_parameter_metas.values():
711
867
  for metas_list in items.memory_buffer_metas_list:
712
868
  for meta in metas_list.metas:
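
The sign trick in `_detect_bucket_size` above is worth spelling out; a minimal sketch, assuming an already initialized process group whose backend accepts CPU tensors (e.g. gloo), whereas the real code reduces a CUDA tensor:

```python
# One all_reduce(MIN) yields both min(free_bytes) and max(counter) by negating
# the value whose maximum is wanted. The function name is hypothetical.
import torch
import torch.distributed as dist

def detect_min_free_and_max_counter(free_bytes: int, counter: int) -> tuple[int, int]:
    t = torch.tensor([free_bytes, -counter], dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MIN)  # elementwise min across ranks
    return int(t[0].item()), -int(t[1].item())
```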
@@ -729,18 +885,27 @@ class ParameterServer:
729
885
  f"max_tensor_bytes {max_tensor_bytes} should be less than free_bytes {free_bytes}"
730
886
  )
731
887
  disable_h2d_buffer = True
888
+ max_bytes = int(os.getenv("PS_MAX_BUCKET_SIZE_GB", 8)) * GiB
732
889
  bucket_size = min(max(max_bytes, max_tensor_bytes), free_bytes)
733
- logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB_bytes:.2f} GiB")
890
+ logger.info(f"[rank{self._rank}] auto detect bucket size {bucket_size / GiB:.2f} GiB")
734
891
  return bucket_size, disable_h2d_buffer
735
892
 
736
- def _copy_to_buffer(self, checkpoint_name: str, bucket: H2DBucket, buffer: torch.Tensor, owner_rank: int | None = None):
893
+ def _copy_to_buffer(
894
+ self,
895
+ checkpoint_name: str,
896
+ bucket: H2DBucket,
897
+ buffer: torch.Tensor,
898
+ owner_rank: int | None = None,
899
+ ):
737
900
  offset = 0
738
901
  if owner_rank is not None:
739
902
  buf_ptrs, remote_ptrs, lens = [], [], []
740
903
  ptr_base = buffer.data_ptr()
741
904
  target_addr, ptrs = self._get_addr_ptrs(owner_rank)
742
905
  for b in bucket.ranges:
743
- assert offset + b.size <= bucket.size, f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
906
+ assert offset + b.size <= bucket.size, (
907
+ f"offset {offset} + size {b.size} > bucket_size {bucket.size}"
908
+ )
744
909
  if owner_rank is not None:
745
910
  buf_ptrs.append(ptr_base + offset)
746
911
  remote_ptrs.append(ptrs[b.idx][0] + b.offset)
@@ -758,7 +923,11 @@ class ParameterServer:
758
923
  torch.cuda.synchronize()
759
924
 
760
925
  def init_process_group_for_ranks(
761
- self, ranks: list[int], *, master_port: int | None = None, timeout: timedelta = timedelta(minutes=10)
926
+ self,
927
+ ranks: list[int],
928
+ *,
929
+ master_port: int | None = None,
930
+ timeout: timedelta = timedelta(minutes=10),
762
931
  ):
763
932
  """
764
933
  Initialize the process group for the ranks. This global group can be easily destroyed by calling dist.destroy_process_group.
@@ -787,8 +956,12 @@ class ParameterServer:
787
956
  # and will not participate in this update. Since they have registered memory addresses
788
957
  # to p2p_store at the beginning, update ranks can directly get the memory addresses
789
958
  # from other nodes and put the weights into the buffer.
790
- store = dist.TCPStore(master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout)
791
- dist.init_process_group(backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store)
959
+ store = dist.TCPStore(
960
+ master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
961
+ )
962
+ dist.init_process_group(
963
+ backend="nccl", world_size=len(ranks), rank=rank, timeout=timeout, store=store
964
+ )
792
965
 
793
966
  def _update_per_bucket_p2p(
794
967
  self,
@@ -800,7 +973,9 @@ class ParameterServer:
800
973
  assert ranks, "ranks should be set"
801
974
  if len(self._current_global_parameter_metas) == 0:
802
975
  raise ValueError("parameter metas is empty")
803
- assert dist.is_initialized(), "process group is not initialized when update model per bucket p2p"
976
+ assert dist.is_initialized(), (
977
+ "process group is not initialized when update model per bucket p2p"
978
+ )
804
979
 
805
980
  need_update = self._rank in ranks
806
981
  logger.info(
@@ -814,26 +989,24 @@ class ParameterServer:
814
989
  # first execute a barrier to avoid subsequent cuda oom
815
990
  dist.barrier()
816
991
 
817
- bucket_size, _ = self._get_bucket_size(disable_h2d_buffer=True)
992
+ bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True)
818
993
  buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
819
- IPC_BUFFER_NAME = "__ipc_buffer___"
820
- self._p2p_store.register_named_tensors({IPC_BUFFER_NAME: buffer})
994
+ ipc_buffer_name = "__ipc_buffer___"
995
+ self._p2p_store.register_named_tensors({ipc_buffer_name: buffer})
821
996
  logger.info(
822
997
  f"[rank{self._rank}] register buffer, shape={buffer.shape}, dtype={buffer.dtype}, data_ptr={buffer.data_ptr()}, nbytes={buffer.nbytes}"
823
998
  )
824
999
  handle = reduce_tensor(buffer)
825
1000
 
826
- gidx = 0
827
1001
  buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
1002
+ socket, socket_paths = self._bind_zmq_socket()
828
1003
  req_thread = threading.Thread(
829
1004
  target=req_func,
830
- args=(self._global_socket_paths,),
1005
+ args=(socket_paths,),
831
1006
  )
832
1007
  req_thread.start()
833
- socket = self._zmq_ctx.socket(zmq.REQ)
834
- socket.bind(self._zmq_socket_path)
835
1008
  socket.send_pyobj(handle)
836
- for owner_rank, bucket in buckets:
1009
+ for gidx, (owner_rank, bucket) in enumerate(buckets):
837
1010
  self._logger_rank0(
838
1011
  f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
839
1012
  )
@@ -845,7 +1018,6 @@ class ParameterServer:
845
1018
  socket.recv()
846
1019
  dist.barrier()
847
1020
  socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
848
- gidx += 1
849
1021
 
850
1022
  socket.recv()
851
1023
  socket.send_pyobj(None)
@@ -853,7 +1025,7 @@ class ParameterServer:
853
1025
  req_thread.join()
854
1026
  dist.barrier()
855
1027
  socket.close()
856
- self._p2p_store.unregister_named_tensors([IPC_BUFFER_NAME])
1028
+ self._p2p_store.unregister_named_tensors([ipc_buffer_name])
857
1029
  torch.cuda.empty_cache()
858
1030
 
859
1031
  def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
@@ -877,7 +1049,9 @@ class ParameterServer:
877
1049
  pool = self._memory_pool[checkpoint_name]
878
1050
  if len(pool) == 0:
879
1051
  return 0
880
- return self._p2p_store.unregister_named_tensors([f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)])
1052
+ return self._p2p_store.unregister_named_tensors(
1053
+ [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
1054
+ )
881
1055
 
882
1056
  def _update_per_bucket(
883
1057
  self,
@@ -891,11 +1065,13 @@ class ParameterServer:
891
1065
 
892
1066
  logger.info(f"[rank{self._rank}] update checkpoint {checkpoint_name}")
893
1067
 
894
- bucket_size, disable_h2d_buffer = self._get_bucket_size()
1068
+ bucket_size, disable_h2d_buffer = self._detect_bucket_size()
895
1069
  buckets = _gen_h2d_buckets(self._current_global_parameter_metas, bucket_size)
896
1070
 
897
1071
  h2d_buffer: torch.Tensor | None = (
898
- None if disable_h2d_buffer else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
1072
+ None
1073
+ if disable_h2d_buffer
1074
+ else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
899
1075
  )
900
1076
 
901
1077
  owner_rank_buckets: list[H2DBucket] = []
@@ -914,13 +1090,12 @@ class ParameterServer:
914
1090
  if len(buckets_by_owner_rank[owner_rank]) > max_len:
915
1091
  max_len = len(buckets_by_owner_rank[owner_rank])
916
1092
 
1093
+ socket, socket_paths = self._bind_zmq_socket()
917
1094
  req_thread = threading.Thread(
918
1095
  target=req_func,
919
- args=(self._global_socket_paths,),
1096
+ args=(socket_paths,),
920
1097
  )
921
1098
  req_thread.start()
922
- socket = self._zmq_ctx.socket(zmq.REQ)
923
- socket.bind(self._zmq_socket_path)
924
1099
  socket.send_pyobj(handle)
925
1100
 
926
1101
  gidx = 0
@@ -932,7 +1107,10 @@ class ParameterServer:
932
1107
  if i >= len(_buckets):
933
1108
  continue
934
1109
  bucket = _buckets[i]
935
- alloc, reserved = torch.cuda.memory_allocated() / 1024 / 1024, torch.cuda.memory_reserved() / 1024 / 1024
1110
+ alloc, reserved = (
1111
+ torch.cuda.memory_allocated() / 1024 / 1024,
1112
+ torch.cuda.memory_reserved() / 1024 / 1024,
1113
+ )
936
1114
  self._logger_rank0(
937
1115
  f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
938
1116
  f"Current CUDA allocated {alloc:.2f} MB, "
@@ -960,7 +1138,7 @@ class ParameterServer:
960
1138
  torch.cuda.empty_cache()
961
1139
 
962
1140
 
963
- def _init_api(ps: ParameterServer):
1141
+ def _init_api(ps: ParameterServer) -> Any:
964
1142
  import fastapi
965
1143
  from fastapi import Request
966
1144
  from fastapi.responses import JSONResponse, Response
@@ -976,32 +1154,32 @@ def _init_api(ps: ParameterServer):
976
1154
  inference_group_ranks: list[int] = []
977
1155
  timeout: float = 300.0
978
1156
 
979
- def wrap_exception(func):
1157
+ def wrap_exception(func: Callable[[], None]) -> Response:
980
1158
  try:
981
1159
  func()
982
- except Exception as e:
1160
+ except Exception as e: # noqa: BLE001
983
1161
  logger.exception(f"wrap exception {func} failed")
984
1162
  return JSONResponse(content=str(e), status_code=500)
985
1163
  return Response(status_code=200)
986
1164
 
987
1165
  @app.post("/v1/checkpoints/{checkpoint_name}/files")
988
- async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request):
1166
+ async def register_files(checkpoint_name: str, req: RegisterRequest, raw: Request) -> Response:
989
1167
  return wrap_exception(lambda: ps.register_checkpoint(checkpoint_name, files=req.files))
990
1168
 
991
1169
  @app.delete("/v1/checkpoints/{checkpoint_name}")
992
- async def unregister_checkpoint(checkpoint_name: str):
1170
+ async def unregister_checkpoint(checkpoint_name: str) -> Response:
993
1171
  return wrap_exception(lambda: ps.unregister_checkpoint(checkpoint_name))
994
1172
 
995
1173
  @app.get("/v1/healthz")
996
- async def healthz():
1174
+ async def healthz() -> Response:
997
1175
  return Response(status_code=200)
998
1176
 
999
1177
  @app.post("/v1/checkpoints/{checkpoint_name}/gather-metas")
1000
- async def gather_metas(checkpoint_name: str):
1178
+ async def gather_metas(checkpoint_name: str) -> Response:
1001
1179
  return wrap_exception(lambda: ps.gather_metas(checkpoint_name))
1002
1180
 
1003
1181
  @app.post("/v1/checkpoints/{checkpoint_name}/update")
1004
- async def update(checkpoint_name: str, req: UpdateRequest):
1182
+ async def update(checkpoint_name: str, req: UpdateRequest) -> Response:
1005
1183
  def update_func(socket_paths: list[tuple[str, str]]):
1006
1184
  if req.update_url is None:
1007
1185
  return
@@ -1018,11 +1196,13 @@ def _init_api(ps: ParameterServer):
1018
1196
  def run_from_cli():
1019
1197
  import uvicorn
1020
1198
 
1021
- parser = argparse.ArgumentParser(description="Paramter Server")
1199
+ parser = argparse.ArgumentParser(description="Parameter Server")
1022
1200
  parser.add_argument("--uds", type=str)
1023
1201
 
1024
1202
  args = parser.parse_args()
1025
- logger.info(f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}")
1203
+ logger.info(
1204
+ f"Parameter Server {args=}, master addr: {os.getenv('MASTER_ADDR')}, master port {os.getenv('MASTER_PORT')}"
1205
+ )
1026
1206
 
1027
1207
  assert args.uds and len(args.uds) > 0, args.uds
1028
1208
  ps = ParameterServer(auto_pg=True)
checkpoint_engine/worker.py CHANGED
@@ -1,11 +1,12 @@
1
1
  import gc
2
- from typing import Callable, Optional, TypedDict
2
+ from collections.abc import Callable
3
+ from typing import TypedDict
3
4
 
4
5
  import torch
5
6
  import zmq
6
7
 
7
8
 
8
- def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: Optional[int] = None) -> torch.Tensor:
9
+ def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
9
10
  func, args = handle
10
11
  list_args = list(args)
11
12
  if device_id is not None:
@@ -24,12 +25,14 @@ class FlattenedTensorMetadata(TypedDict):
24
25
  offset: int
25
26
 
26
27
 
27
- def _extract_weights(payload: list[FlattenedTensorMetadata], buffer: torch.Tensor) -> list[tuple[str, torch.Tensor]]:
28
+ def _extract_weights(
29
+ payload: list[FlattenedTensorMetadata], buffer: torch.Tensor
30
+ ) -> list[tuple[str, torch.Tensor]]:
28
31
  assert buffer is not None
29
32
  weights: list[tuple[str, torch.Tensor]] = []
30
33
  for item in payload:
31
34
  shape = item["shape"]
32
- if isinstance(shape, (list, tuple)):
35
+ if isinstance(shape, list | tuple):
33
36
  shape = torch.Size(shape)
34
37
  assert isinstance(shape, torch.Size)
35
38
  dtype, offset = item["dtype"], item["offset"]
@@ -45,11 +48,11 @@ def update_weights_from_ipc(
45
48
  device_id: int,
46
49
  *,
47
50
  run: Callable[[list[tuple[str, torch.Tensor]]], None],
48
- post_hook: Callable[[], None] = None,
51
+ post_hook: Callable[[], None] | None = None,
49
52
  ):
50
53
  socket = zmq_ctx.socket(zmq.REP)
51
54
  socket.connect(zmq_handle)
52
- buffer: Optional[torch.Tensor] = None
55
+ buffer: torch.Tensor | None = None
53
56
  while True:
54
57
  payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
55
58
  if payload is None:
@@ -100,5 +103,7 @@ class VllmColocateWorkerExtension:
100
103
  zmq_handles[device_uuid],
101
104
  device_id=self.device.index,
102
105
  run=self.model_runner.model.load_weights,
103
- post_hook=lambda: process_weights_after_loading(self.model_runner.model, self.model_config, self.device),
106
+ post_hook=lambda: process_weights_after_loading(
107
+ self.model_runner.model, self.model_config, self.device
108
+ ),
104
109
  )
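
A hedged sketch of the receiving side, paraphrasing `VllmColocateWorkerExtension` above; the positional order of `update_weights_from_ipc`'s first two arguments is assumed from the call shown, and `load_weights` stands in for the engine-specific loader:

```python
# Assumed argument order: (zmq_ctx, zmq_handle, device_id, *, run, post_hook).
from collections.abc import Callable

import torch
import zmq

from checkpoint_engine.worker import update_weights_from_ipc

def receive_weights(
    zmq_handle: str,  # e.g. "ipc://@checkpoint-engine-GPU-...-0.sock"
    device_index: int,
    load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
    zmq_ctx = zmq.Context()
    update_weights_from_ipc(
        zmq_ctx,
        zmq_handle,
        device_id=device_index,
        run=load_weights,  # called with a list of (name, tensor) pairs per bucket
        post_hook=None,    # optional post-processing, may stay None
    )
```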
checkpoint_engine-0.1.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -11,14 +11,15 @@ Description-Content-Type: text/markdown
11
11
  License-File: LICENCE
12
12
  Requires-Dist: torch>=2.5.0
13
13
  Requires-Dist: fastapi
14
- Requires-Dist: pydantic
14
+ Requires-Dist: pydantic>=2.0.0
15
15
  Requires-Dist: safetensors
16
16
  Requires-Dist: pyzmq
17
17
  Requires-Dist: uvicorn
18
18
  Requires-Dist: loguru
19
19
  Requires-Dist: numpy
20
+ Requires-Dist: httpx
20
21
  Provides-Extra: p2p
21
- Requires-Dist: mooncake-transfer-engine; extra == "p2p"
22
+ Requires-Dist: mooncake-transfer-engine>=0.3.5; extra == "p2p"
22
23
  Dynamic: license-file
23
24
 
24
25
  # Checkpoint Engine
@@ -41,7 +42,7 @@ The core weight update logic is in `ParameterServer` class, a service colocated
41
42
  - **P2P**: Used when new inference instances are dynamically added (due to restarts or dynamic availability) while the existing instances are already serving requests. Under this scenario, to avoid affecting the workloads on existing instances, we use the [`mooncake-transfer-engine`](https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#use-python-package) to P2P send weights from CPUs in existing instances to GPUs in new instances. See `_update_per_bucket_p2p`.
42
43
 
43
44
  ### Optimized Weight Broadcast
44
- In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
45
+ In the *Broadcast* implementation, the checkpoint-engine holds references to sharded weights in CPU memory, and need to efficiently broadcast them to a cluster of inference instances, often under a different sharding pattern.
45
46
  We arrange the data transfer into 3 stages:
46
47
  1. H2D: moving weights to GPU memory. These weights may come from disk or the training engine.
47
48
  2. broadcast: broadcast among checkpoint engine workers; the data results in a CUDA IPC buffer shared with inference engine.
@@ -73,9 +74,9 @@ Pipelining naturally requires more GPU memory. When memory is not enough, checkp
73
74
  All results above are tested by [`examples/update.py`](./examples/update.py) and use [vLLM v0.10.2rc1](https://github.com/vllm-project/vllm/tree/v0.10.2rc1) as inference engine. Some notes:
74
75
 
75
76
  * FP8 test needs additional vLLM patches, see [FP8 quantization](#fp8-quantization).
76
- * Device Info: we tested various combination of devices and paralleism setups. For exmaple, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
77
+ * Device Info: we tested various combination of devices and parallelism setups. For example, a 256-GPU TP16 setup means that we deploy 16 vLLM instances, each with 16-way tensor parallelism.
77
78
  * Since update duration is related to IPC bucket size, we provide the bucket size in the table.
78
- * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
79
+ * The P2P time were tested for updating no more than two nodes (16 GPUs) (`ParameterServer.update(ranks=range(0, 16))`) out of the entire cluster.
79
80
 
80
81
  ## Installation
81
82
 
@@ -88,7 +89,7 @@ pip install checkpoint-engine
88
89
  Use the flexible P2P implementation, notice this will install `mooncake-transfer-engine` to support RDMA transfer between different ranks.
89
90
 
90
91
  ```Bash
91
- pip install checkpoint-engine[p2p]
92
+ pip install 'checkpoint-engine[p2p]'
92
93
  ```
93
94
 
94
95
  If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. If not set, it will read all RDMA devices and try to divide them into each rank.
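
A hedged Python illustration of that selection rule, mirroring `_get_rdma_devices` in `checkpoint_engine/ps.py`; `select_devices` is a hypothetical helper, and the `PS_P2P_STORE_RDMA_DEVICES` override is omitted:

```python
def select_devices(nccl_ib_hca: str | None, all_devices: list[str]) -> list[str]:
    if nccl_ib_hca:
        hca_list = nccl_ib_hca.split(",")
        if len(hca_list) > 1:
            return hca_list  # multiple HCAs: use them as-is
        return [d for d in sorted(all_devices) if hca_list[0] in d]  # substring filter
    return sorted(all_devices)  # no hint: use every RDMA device found

assert select_devices("mlx5_0,mlx5_1", []) == ["mlx5_0", "mlx5_1"]
assert select_devices("mlx5", ["mlx5_1", "mlx5_0", "eth0"]) == ["mlx5_0", "mlx5_1"]
```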
@@ -107,7 +108,7 @@ VLLM_USE_PRECOMPILED=1 uv pip install --editable .
107
108
  Install checkpoint-engine
108
109
 
109
110
  ```Bash
110
- uv pip install checkpoint-engine[p2p]
111
+ uv pip install 'checkpoint-engine[p2p]'
111
112
  ```
112
113
 
113
114
  We use `Qwen/Qwen3-235B-A22B-Instruct-2507` (BF16) as the test model
@@ -133,7 +134,7 @@ torchrun --nproc-per-node 8 examples/update.py --update-method all --checkpoint-
133
134
 
134
135
  ### Reuse weights from existing instances
135
136
 
136
- New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
137
+ New checkpoint-engine instances can join existing instances and reuse their weights. This is simple to achieve.
137
138
 
138
139
  First, start the existing instances with `--save-metas-file global_metas.pkl` to save global metas to a file and use `--sleep-time 300` to make sure they stay alive.
139
140
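
The flag handling itself lives in `examples/update.py` and is not part of this diff, but the underlying mechanism is presumably the `get_metas`/`load_metas` pair on `ParameterServer`; a hedged sketch:

```python
# `ps` is an already constructed ParameterServer; the filename follows the README.
import pickle

# on an existing instance, after ps.gather_metas(checkpoint_name)
with open("global_metas.pkl", "wb") as f:
    pickle.dump(ps.get_metas(), f)

# on a newly joined instance, before calling ps.update(...)
with open("global_metas.pkl", "rb") as f:
    ps.load_metas(pickle.load(f))
```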
 
@@ -150,7 +151,7 @@ torchrun --nproc-per-node 8 examples/update.py --load-metas-file global_metas.pk
150
151
 
151
152
  ### FP8 quantization
152
153
 
153
- FP8 quantization currently do not natively work in vLLM when updating weights.
154
+ FP8 quantization currently do not natively work in vLLM when updating weights.
154
155
  We provide a simple patch in [`patches/vllm_fp8.patch`](./patches/vllm_fp8.patch) to handle the correct weight update.
155
156
  Notice this patch is only tested in DeepSeek-V3.1 and Kimi-K2. Other models may meet some compatible issues.
156
157
 
checkpoint_engine-0.1.3.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
1
+ checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
+ checkpoint_engine/_version.py,sha256=q5nF98G8SoVeJqaknL0xdyxtv0egsqb0fK06_84Izu8,704
3
+ checkpoint_engine/ps.py,sha256=9dXRXi0QDPoRYrgGKAYvdmDFBXejgusjR0ltbii5_B0,49134
4
+ checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
5
+ checkpoint_engine-0.1.3.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
6
+ checkpoint_engine-0.1.3.dist-info/METADATA,sha256=y96dMjEOKWaO_PA0h5BX8G3Ku7Tt1jCU09uIf8iYgic,9322
7
+ checkpoint_engine-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
+ checkpoint_engine-0.1.3.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
9
+ checkpoint_engine-0.1.3.dist-info/RECORD,,
checkpoint_engine-0.1.3.dist-info/licenses/LICENCE CHANGED
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
18
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
19
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
20
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE
21
+ SOFTWARE
checkpoint_engine-0.1.1.dist-info/RECORD REMOVED
@@ -1,9 +0,0 @@
1
- checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
2
- checkpoint_engine/_version.py,sha256=yfx_VE-4lpqM4jnWOSq-8rihMWIwMaX9CQ7tNEpA4T0,712
3
- checkpoint_engine/ps.py,sha256=9u2rLOj-oQrXsnpYhdXFjv7ak2-f4BRUXB6KYlG3ah0,44422
4
- checkpoint_engine/worker.py,sha256=OrSeknjtECnO88I-YMdfkZj70TIRhjvTEeZkNyZk21M,3695
5
- checkpoint_engine-0.1.1.dist-info/licenses/LICENCE,sha256=0jqA0jrA_i9VUqd7FTVoI1KnN1ZRENwG_tlMRCjC63k,1066
6
- checkpoint_engine-0.1.1.dist-info/METADATA,sha256=WDc5tg3RQiCthzbEzI4D3t3I0AQhGUPcDpJn7TMhYbI,9286
7
- checkpoint_engine-0.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
8
- checkpoint_engine-0.1.1.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
9
- checkpoint_engine-0.1.1.dist-info/RECORD,,