checkpoint-engine 0.2.1__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/PKG-INFO +9 -6
  2. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/README.md +8 -5
  3. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/_version.py +3 -3
  4. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/ps.py +259 -38
  5. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/worker.py +5 -2
  6. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/PKG-INFO +9 -6
  7. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/SOURCES.txt +1 -0
  8. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/examples/update.py +5 -2
  9. checkpoint_engine-0.2.3/tests/test_pin_memory.py +77 -0
  10. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/tests/test_update.py +96 -2
  11. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.github/workflows/cpu-tests.yml +0 -0
  12. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.github/workflows/pre-commit.yaml +0 -0
  13. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.github/workflows/python-publish.yml +0 -0
  14. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.gitignore +0 -0
  15. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/.pre-commit-config.yaml +0 -0
  16. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/LICENCE +0 -0
  17. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/__init__.py +0 -0
  18. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine/device_utils.py +0 -0
  19. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/dependency_links.txt +0 -0
  20. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/requires.txt +0 -0
  21. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/checkpoint_engine.egg-info/top_level.txt +0 -0
  22. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/docs/npu_start.md +0 -0
  23. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/figures/checkpoint-engine.png +0 -0
  24. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/figures/overlap-update-and-copy.png +0 -0
  25. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/figures/pipeline.png +0 -0
  26. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/patches/vllm_fp8.patch +0 -0
  27. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/pyproject.toml +0 -0
  28. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/setup.cfg +0 -0
  29. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/tests/test_assign_receiver_ranks.py +0 -0
  30. {checkpoint_engine-0.2.1 → checkpoint_engine-0.2.3}/tests/test_rdma_parser.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
99
99
  pip install 'checkpoint-engine[p2p]'
100
100
  ```
101
101
 
102
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
103
-
104
102
  ## Getting Started
105
103
 
106
- Prepare an H800 or H20 machine with 8 GPUs with latest vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights.
104
+ Prepare an H800 or H20 machine with 8 GPUs and vLLM installed. Be sure to include the [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in the main branch), since checkpoint-engine uses this endpoint to update weights. vLLM `v0.10.2` is fully tested and recommended.
107
105
 
108
106
  ```Bash
109
- cd /opt && git clone https://github.com/vllm-project/vllm && cd vllm
107
+ mkdir -p /opt/vLLM && cd /opt/vLLM
110
108
  uv venv --python 3.12 --seed
111
109
  source .venv/bin/activate
112
- VLLM_USE_PRECOMPILED=1 uv pip install --editable .
110
+ uv pip install vllm==0.10.2
113
111
  ```
114
112
 
115
113
  Install checkpoint-engine
@@ -180,6 +178,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
180
178
  pytest tests/ -m "not gpu"
181
179
  ```
182
180
 
181
+ ### Environment Variables
182
+ - `PS_MAX_BUCKET_SIZE_GB`: An integer setting the maximum bucket size (in GB) used by checkpoint-engine. Defaults to 8 GB if not set.
183
+ - `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA device names for P2P transfer. If not set, checkpoint-engine falls back to `NCCL_IB_HCA` to detect RDMA devices.
184
+ - `NCCL_IB_HCA`: Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If this is also not set, all RDMA devices will be used and divided evenly among the ranks.
185
+
183
186
  ## SGLang Integration
184
187
 
185
188
  Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
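For reference, a minimal sketch of how the environment variables introduced in this hunk might be combined; the device names below are placeholders rather than values taken from the package:

```Bash
# Placeholder values: adapt to your cluster's RDMA topology.
export PS_MAX_BUCKET_SIZE_GB=8                    # maximum bucket size in GB (8 is the default)
export PS_P2P_STORE_RDMA_DEVICES=mlx5_0,mlx5_1    # explicit RDMA devices for the P2P store
# export NCCL_IB_HCA=mlx5                         # fallback pattern if the variable above is unset
```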
@@ -75,17 +75,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
75
75
  pip install 'checkpoint-engine[p2p]'
76
76
  ```
77
77
 
78
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
79
-
80
78
  ## Getting Started
81
79
 
82
- Prepare an H800 or H20 machine with 8 GPUs with latest vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights.
80
+ Prepare an H800 or H20 machine with 8 GPUs and vLLM installed. Be sure to include the [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in the main branch), since checkpoint-engine uses this endpoint to update weights. vLLM `v0.10.2` is fully tested and recommended.
83
81
 
84
82
  ```Bash
85
- cd /opt && git clone https://github.com/vllm-project/vllm && cd vllm
83
+ mkdir -p /opt/vLLM && cd /opt/vLLM
86
84
  uv venv --python 3.12 --seed
87
85
  source .venv/bin/activate
88
- VLLM_USE_PRECOMPILED=1 uv pip install --editable .
86
+ uv pip install vllm==0.10.2
89
87
  ```
90
88
 
91
89
  Install checkpoint-engine
@@ -156,6 +154,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
156
154
  pytest tests/ -m "not gpu"
157
155
  ```
158
156
 
157
+ ### Environment Variables
158
+ - `PS_MAX_BUCKET_SIZE_GB`: An integer setting the maximum bucket size (in GB) used by checkpoint-engine. Defaults to 8 GB if not set.
159
+ - `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA device names for P2P transfer. If not set, checkpoint-engine falls back to `NCCL_IB_HCA` to detect RDMA devices.
160
+ - `NCCL_IB_HCA`: Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If this is also not set, all RDMA devices will be used and divided evenly among the ranks.
161
+
159
162
  ## SGLang Integration
160
163
 
161
164
  Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.1'
32
- __version_tuple__ = version_tuple = (0, 2, 1)
31
+ __version__ = version = '0.2.3'
32
+ __version_tuple__ = version_tuple = (0, 2, 3)
33
33
 
34
- __commit_id__ = commit_id = 'g279a908a9'
34
+ __commit_id__ = commit_id = 'g0a6244951'
@@ -1,6 +1,7 @@
1
1
  import argparse
2
2
  import concurrent.futures
3
3
  import ctypes
4
+ import json
4
5
  import os
5
6
  import pickle
6
7
  import random
@@ -18,7 +19,7 @@ import torch.distributed as dist
18
19
  import zmq
19
20
  from loguru import logger
20
21
  from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
21
- from safetensors.torch import safe_open
22
+ from safetensors.torch import _getdtype, safe_open
22
23
  from torch.multiprocessing.reductions import reduce_tensor
23
24
 
24
25
  from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
92
93
  name: str
93
94
  dtype: _TorchDtype
94
95
  shape: _TorchSize
96
+ aligned_size: int
95
97
 
96
98
 
97
99
  class BucketRange(NamedTuple):
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
140
142
  def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
141
143
  ret = []
142
144
  for meta in metas:
143
- size = _align_size(meta.dtype, meta.shape)
145
+ size = meta.aligned_size
144
146
  ret.append(
145
147
  {
146
148
  "name": meta.name,
@@ -422,6 +424,7 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
422
424
  name=parameter_name,
423
425
  shape=meta["shape"],
424
426
  dtype=meta["dtype"],
427
+ aligned_size=_align_size(meta["dtype"], meta["shape"]),
425
428
  )
426
429
  tp_meta = tp_metas[parameter_name]
427
430
  if tp_meta.concat_dim != -1:
@@ -431,7 +434,10 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
431
434
  shape = list(parameter_metas[name].shape)
432
435
  shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
433
436
  parameter_metas[name] = ParameterMeta(
434
- name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
437
+ name=name,
438
+ shape=torch.Size(shape),
439
+ dtype=parameter_metas[name].dtype,
440
+ aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
435
441
  )
436
442
  weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
437
443
  # TODO: here concat is serial, which may be slow
@@ -449,17 +455,85 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
449
455
  return parameters
450
456
 
451
457
 
452
- def _register_checkpoint(
453
- *,
458
+ def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
459
+ def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
460
+ """
461
+ For the safetensors format, see https://huggingface.co/docs/safetensors/en/index#format.
462
+ We load the safetensors file as bytes, then parse the header manually to get parameter metas.
463
+ The actual tensor data is in the remaining bytes and is naturally aligned.
464
+ We pin the remaining bytes as the buffer, making pinning faster.
465
+ """
466
+
467
+ def _pin(t: torch.Tensor):
468
+ """
469
+ Pin the memory of tensor in-place.
470
+ See: https://github.com/pytorch/pytorch/issues/32167
471
+ """
472
+ cudart = torch.cuda.cudart()
473
+ r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
474
+ assert r == 0, f"pin memory error, error code: {r}"
475
+
476
+ # TODO: should this only support /dev/shm? Files on regular disk also appear to work.
477
+ size = os.stat(file_path).st_size
478
+ flag_size = 8
479
+ t = torch.from_file(file_path, True, size, dtype=torch.uint8)
480
+ assert t.nbytes > flag_size, (
481
+ f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
482
+ )
483
+ start_pos = (
484
+ int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
485
+ + flag_size
486
+ )
487
+ header_tensor = t[flag_size:start_pos]
488
+ header = json.loads(header_tensor.numpy().tobytes())
489
+ if "__metadata__" in header:
490
+ header.pop("__metadata__")
491
+
492
+ metas: list[ParameterMeta] = []
493
+ offset = 0
494
+ try:
495
+ for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
496
+ start, end = meta["data_offsets"]
497
+ # safetensors format ensures offsets are aligned
498
+ assert offset == start, f"offset {offset} should be equal to start {start}"
499
+ metas.append(
500
+ ParameterMeta(
501
+ name=name,
502
+ dtype=_getdtype(meta["dtype"]),
503
+ shape=torch.Size(meta["shape"]),
504
+ aligned_size=end - start,
505
+ )
506
+ )
507
+ offset = end
508
+ except Exception as e:
509
+ logger.error(f"fail to parse safetensors header from {file_path}: {e}")
510
+ raise
511
+
512
+ buffer = t[start_pos:]
513
+ assert offset == buffer.nbytes, (
514
+ f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
515
+ )
516
+ # Remove the file after successfully loading. This will avoid doubling the memory usage.
517
+ # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
518
+ os.remove(file_path)
519
+ _pin(buffer)
520
+ logger.info(
521
+ f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
522
+ )
523
+ return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
524
+
525
+ memory_buffers: list[MemoryBuffer] = []
526
+ with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
527
+ memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
528
+ return memory_buffers
529
+
530
+
531
+ def _normal_pin_memory(
454
532
  files: list[str],
455
533
  named_tensors: dict[str, torch.Tensor],
456
534
  rank: int | None = None,
535
+ shared_pin_memory: list[MemoryBuffer] | None = None,
457
536
  ) -> list[MemoryBuffer]:
458
- logger.info(
459
- f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
460
- )
461
- if not files and not named_tensors:
462
- return []
463
537
  parameters = _load_checkpoint(files)
464
538
  if named_tensors:
465
539
  parameters.update(named_tensors)
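For readers unfamiliar with the layout that the new `_parse_and_pin_from_safetensors` path decodes, here is a standalone sketch of the header parsing, assuming only the published safetensors format (an 8-byte little-endian header length, a JSON header, then raw tensor bytes). `read_safetensors_header` is a hypothetical helper, not part of the package:

```python
import json
import struct


def read_safetensors_header(path: str) -> tuple[dict, int]:
    """Return the parsed JSON header and the byte offset where tensor data begins."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian header length
        header = json.loads(f.read(header_len))
        header.pop("__metadata__", None)  # optional file-level metadata, not a tensor entry
    data_start = 8 + header_len
    # Each remaining entry maps a tensor name to {"dtype": ..., "shape": [...], "data_offsets": [start, end]},
    # with offsets relative to data_start; the code above uses end - start as the aligned size.
    return header, data_start
```

The in-place path then registers the mmapped data region with `cudaHostRegister` instead of copying it into a freshly allocated pinned buffer, which is what makes the pinning step faster.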
@@ -469,13 +543,16 @@ def _register_checkpoint(
469
543
  size: int
470
544
  metas: list[ParameterMeta]
471
545
 
472
- buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
546
+ buckets: list[MemoryBucket] = []
547
+ buckets.append(MemoryBucket(size=0, metas=[]))
473
548
  for name, tensor in sorted(parameters.items()):
474
549
  size = _align_size(tensor.dtype, tensor.shape)
475
550
  if buckets[-1].size + size > bucket_size:
476
551
  assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
477
552
  buckets.append(MemoryBucket(size=0, metas=[]))
478
- buckets[-1].metas.append(ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype))
553
+ buckets[-1].metas.append(
554
+ ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
555
+ )
479
556
  buckets[-1].size += size
480
557
 
481
558
  memory_buffers = [
@@ -483,16 +560,34 @@ def _register_checkpoint(
483
560
  for bucket in buckets
484
561
  ]
485
562
 
486
- def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
487
- buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
488
- return idx, buffer
563
+ def register_pin_memory(
564
+ idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
565
+ ) -> tuple[int, torch.Tensor]:
566
+ if shared_pin_memory:
567
+ # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
568
+ # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
569
+ assert idx < len(shared_pin_memory), (
570
+ f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
571
+ )
572
+ assert shared_pin_memory[idx].size == size, (
573
+ f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
574
+ )
575
+ return idx, shared_pin_memory[idx].buffer
576
+ else:
577
+ buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
578
+ return idx, buffer
489
579
 
490
580
  def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
491
581
  buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
492
582
 
493
583
  with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
494
584
  futures = [
495
- executor.submit(register_pin_memory, idx, bucket.size)
585
+ executor.submit(
586
+ register_pin_memory,
587
+ idx,
588
+ bucket.size,
589
+ shared_pin_memory,
590
+ )
496
591
  for idx, bucket in enumerate(buckets)
497
592
  ]
498
593
  new_futures = []
@@ -518,6 +613,45 @@ def _register_checkpoint(
518
613
  offset += size
519
614
  for future in concurrent.futures.as_completed(new_futures):
520
615
  future.result()
616
+ return memory_buffers
617
+
618
+
619
+ def _register_checkpoint(
620
+ *,
621
+ files: list[str],
622
+ named_tensors: dict[str, torch.Tensor],
623
+ rank: int | None = None,
624
+ shared_pin_memory: list[MemoryBuffer] | None = None,
625
+ inplace_pin: bool = False,
626
+ ) -> list[MemoryBuffer]:
627
+ logger.info(
628
+ f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
629
+ )
630
+ if not files and not named_tensors:
631
+ return []
632
+ memory_buffers: list[MemoryBuffer] = []
633
+ if inplace_pin:
634
+ logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
635
+ files_to_inplace_pin = [
636
+ file
637
+ for file in files
638
+ if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
639
+ ]
640
+ files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
641
+ else:
642
+ files_to_normal_pin = files
643
+ files_to_inplace_pin = []
644
+ if files_to_normal_pin or named_tensors:
645
+ memory_buffers.extend(
646
+ _normal_pin_memory(
647
+ files=files_to_normal_pin,
648
+ named_tensors=named_tensors,
649
+ rank=rank,
650
+ shared_pin_memory=shared_pin_memory,
651
+ )
652
+ )
653
+ if files_to_inplace_pin:
654
+ memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
521
655
  return memory_buffers
522
656
 
523
657
 
@@ -566,7 +700,7 @@ def _gen_h2d_buckets(
566
700
  for idx, metas in enumerate(items.memory_buffer_metas_list):
567
701
  start_offset, offset = 0, 0
568
702
  for meta in metas.metas:
569
- s = _align_size(meta.dtype, meta.shape)
703
+ s = meta.aligned_size
570
704
  if buckets[-1][1].size + s > bucket_size:
571
705
  if offset - start_offset > 0:
572
706
  buckets[-1][1].ranges.append(
@@ -747,6 +881,8 @@ class P2PStore:
747
881
 
748
882
 
749
883
  class ParameterServer:
884
+ shared_memory_pool_name = "__shared_memory_pool__"
885
+
750
886
  def __init__(
751
887
  self,
752
888
  *,
@@ -790,7 +926,10 @@ class ParameterServer:
790
926
  self._zmq_ctx = zmq.Context()
791
927
  self._zmq_addr_counter = 0
792
928
 
929
+ # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
930
+ self._current_shared_memory_pool_user: str = ""
793
931
  self._memory_pool: dict[str, list[MemoryBuffer]] = {}
932
+ self._memory_pool[self.shared_memory_pool_name] = []
794
933
  # dict key is owner_rank, value is a bucket metas list in owner_rank
795
934
  self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
796
935
  # NPU transfer engine initialization requires prior set_device.
@@ -805,6 +944,17 @@ class ParameterServer:
805
944
  self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
806
945
  self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
807
946
 
947
+ def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
948
+ if checkpoint_name == self._current_shared_memory_pool_user:
949
+ assert self._memory_pool[self.shared_memory_pool_name], (
950
+ f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
951
+ )
952
+ return self._memory_pool[self.shared_memory_pool_name]
953
+ elif checkpoint_name in self._memory_pool:
954
+ return self._memory_pool[checkpoint_name]
955
+ else:
956
+ raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
957
+
808
958
  def _logger_rank0(self, msg: str):
809
959
  if self._local_rank == 0:
810
960
  logger.info(msg)
@@ -828,46 +978,103 @@ class ParameterServer:
828
978
  *,
829
979
  files: list[str] | None = None,
830
980
  named_tensors: dict[str, torch.Tensor] | None = None,
981
+ use_shared_memory_pool: bool = False,
982
+ use_inplace_pin_memory: bool = False,
831
983
  ) -> None:
832
984
  """
833
985
  Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
986
+ Warning: if `use_inplace_pin_memory` is True, .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
987
+ Please make sure to copy the files to disks if you need to keep them.
834
988
 
835
989
  Args:
836
990
  checkpoint_name: The name of the checkpoint.
837
991
  files: The safetensors files to register.
838
992
  named_tensors: The named tensors to register.
993
+ use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
994
+ Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
995
+ cannot accommodate checkpoints with different memory requirements.
996
+ To free the actual memory of the shared pool or to modify its shape,
997
+ please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
998
+ use_inplace_pin_memory: If True, allows in-place pinning of /dev/shm/ safetensors files. This option is ignored when ``use_shared_memory_pool`` is True.
999
+ Currently, this feature is experimental and may crash.
839
1000
  """
840
1001
  try:
841
- assert checkpoint_name not in self._memory_pool, (
842
- f"checkpoint {checkpoint_name} already registered"
843
- )
844
- self._memory_pool[checkpoint_name] = _register_checkpoint(
845
- files=files or [], named_tensors=named_tensors or {}, rank=self._rank
846
- )
847
- if self._p2p_store is not None:
848
- self._register_parameters_to_p2p_store(checkpoint_name)
1002
+ if use_shared_memory_pool:
1003
+ logger.info(
1004
+ f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
1005
+ )
1006
+ assert self._current_shared_memory_pool_user == "", (
1007
+ f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
1008
+ f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
1009
+ f"This registration may cause unexpected conflicts."
1010
+ )
1011
+ # Since we set the uninitialized shared memory pool to empty list,
1012
+ # we can check whether this is the first time to use shared memory pool
1013
+ _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
1014
+ self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
1015
+ files=files or [],
1016
+ named_tensors=named_tensors or {},
1017
+ rank=self._rank,
1018
+ shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
1019
+ )
1020
+ self._current_shared_memory_pool_user = checkpoint_name
1021
+ if self._p2p_store is not None and _is_first_time:
1022
+ self._register_parameters_to_p2p_store(checkpoint_name)
1023
+ else:
1024
+ assert checkpoint_name not in self._memory_pool, (
1025
+ f"checkpoint {checkpoint_name} already registered"
1026
+ )
1027
+ self._memory_pool[checkpoint_name] = _register_checkpoint(
1028
+ files=files or [],
1029
+ named_tensors=named_tensors or {},
1030
+ rank=self._rank,
1031
+ inplace_pin=use_inplace_pin_memory,
1032
+ )
1033
+ if self._p2p_store is not None:
1034
+ self._register_parameters_to_p2p_store(checkpoint_name)
849
1035
  except Exception:
850
1036
  logger.exception(
851
1037
  f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
852
1038
  )
853
- if self._p2p_store is not None:
1039
+ if self._p2p_store is not None and not use_shared_memory_pool:
854
1040
  self._unregister_parameters_from_p2p_store(checkpoint_name)
855
1041
  self.unregister_checkpoint(checkpoint_name)
856
1042
  raise
857
1043
 
858
- def unregister_checkpoint(self, checkpoint_name: str):
1044
+ def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
859
1045
  """
860
1046
  Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
861
1047
  from p2p store if p2p store is initialized.
1048
+ Args:
1049
+ checkpoint_name: The name of the checkpoint.
1050
+ force: This flag is intended for the shared memory pool user. If True, the memory of the shared memory pool itself will be freed.
1051
+ If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
862
1052
  """
863
- if checkpoint_name not in self._memory_pool:
1053
+ if (
1054
+ checkpoint_name not in self._memory_pool
1055
+ and checkpoint_name != self._current_shared_memory_pool_user
1056
+ ):
1057
+ logger.warning(
1058
+ f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
1059
+ )
1060
+ return
1061
+
1062
+ if checkpoint_name == self._current_shared_memory_pool_user and not force:
1063
+ self._current_shared_memory_pool_user = ""
864
1064
  return
1065
+
865
1066
  if self._p2p_store is not None:
866
1067
  num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
867
1068
  logger.info(
868
1069
  f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
869
1070
  )
870
- del self._memory_pool[checkpoint_name]
1071
+
1072
+ if checkpoint_name == self._current_shared_memory_pool_user:
1073
+ self._current_shared_memory_pool_user = ""
1074
+ del self._memory_pool[self.shared_memory_pool_name]
1075
+ self._memory_pool[self.shared_memory_pool_name] = []
1076
+ else:
1077
+ del self._memory_pool[checkpoint_name]
871
1078
  # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
872
1079
  # this works by using torch>=2.5.0
873
1080
  torch._C._host_emptyCache()
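A usage sketch of the shared-pool lifecycle that these changes introduce, following the single-rank setup used in `tests/test_pin_memory.py`; the checkpoint names and file paths are illustrative, and process-group setup and the actual weight-update calls are omitted:

```python
import os

from checkpoint_engine.ps import ParameterServer

# Single-rank setup as in tests/test_pin_memory.py; a real deployment would launch via torchrun.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
ps = ParameterServer()

# First use fixes the pool's bucket layout and allocates the pinned host memory.
ps.register_checkpoint(
    "ckpt-step-100", files=["/dev/shm/step-100.safetensors"], use_shared_memory_pool=True
)
# ... gather_metas / update ...

# Without force=True only the name is released; the pinned pool stays allocated
# for the next checkpoint with the same memory layout.
ps.unregister_checkpoint("ckpt-step-100")

ps.register_checkpoint(
    "ckpt-step-200", files=["/dev/shm/step-200.safetensors"], use_shared_memory_pool=True
)

# force=True frees the shared pool itself, e.g. before a checkpoint with a different shape.
ps.unregister_checkpoint("ckpt-step-200", force=True)
```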
@@ -882,6 +1089,10 @@ class ParameterServer:
882
1089
  self.init_process_group()
883
1090
  assert dist.is_initialized(), "process group is not initialized"
884
1091
  metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore
1092
+ try:
1093
+ memory_pool = self._get_memory_pool(checkpoint_name)
1094
+ except RuntimeError:
1095
+ memory_pool = []
885
1096
  metas = DataToGather(
886
1097
  memory_buffer_metas_list=[
887
1098
  MemoryBufferMetas(
@@ -889,7 +1100,7 @@ class ParameterServer:
889
1100
  ptr=x.buffer.data_ptr(),
890
1101
  size=x.size,
891
1102
  )
892
- for x in self._memory_pool.get(checkpoint_name, [])
1103
+ for x in memory_pool
893
1104
  ],
894
1105
  p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
895
1106
  host_ip=get_ip(),
@@ -1050,7 +1261,7 @@ class ParameterServer:
1050
1261
  for items in self._current_global_parameter_metas.values():
1051
1262
  for metas_list in items.memory_buffer_metas_list:
1052
1263
  for meta in metas_list.metas:
1053
- max_tensor_bytes = max(max_tensor_bytes, _align_size(meta.dtype, meta.shape))
1264
+ max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
1054
1265
  free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
1055
1266
  if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
1056
1267
  self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
@@ -1095,7 +1306,7 @@ class ParameterServer:
1095
1306
  remote_ptrs.append(ptrs[b.idx][0] + b.offset)
1096
1307
  lens.append(b.size)
1097
1308
  else:
1098
- pool = self._memory_pool[checkpoint_name][b.idx]
1309
+ pool = self._get_memory_pool(checkpoint_name)[b.idx]
1099
1310
  buffer[offset : offset + b.size].data.copy_(
1100
1311
  pool.buffer[b.offset : b.offset + b.size],
1101
1312
  non_blocking=True,
@@ -1158,22 +1369,32 @@ class ParameterServer:
1158
1369
 
1159
1370
  def _register_parameters_to_p2p_store(self, checkpoint_name: str):
1160
1371
  assert self._p2p_store is not None, "p2p store is not initialized"
1161
- pool = self._memory_pool[checkpoint_name]
1372
+ pool = self._get_memory_pool(checkpoint_name)
1162
1373
  if len(pool) == 0:
1163
1374
  return
1164
1375
  named_tensors, tensor_ptrs = {}, []
1376
+ register_name = (
1377
+ checkpoint_name
1378
+ if checkpoint_name != self._current_shared_memory_pool_user
1379
+ else self.shared_memory_pool_name
1380
+ )
1165
1381
  for idx, memory_buffer in enumerate(pool):
1166
- named_tensors[f"memory_pool_{checkpoint_name}_{idx}"] = memory_buffer.buffer
1382
+ named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
1167
1383
  tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
1168
1384
  self._p2p_store.register_named_tensors(named_tensors)
1169
1385
 
1170
1386
  def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
1171
1387
  assert self._p2p_store is not None, "p2p store is not initialized"
1172
- pool = self._memory_pool[checkpoint_name]
1388
+ pool = self._get_memory_pool(checkpoint_name)
1173
1389
  if len(pool) == 0:
1174
1390
  return 0
1391
+ unregister_name = (
1392
+ checkpoint_name
1393
+ if checkpoint_name != self._current_shared_memory_pool_user
1394
+ else self.shared_memory_pool_name
1395
+ )
1175
1396
  return self._p2p_store.unregister_named_tensors(
1176
- [f"memory_pool_{checkpoint_name}_{idx}" for idx, _ in enumerate(pool)]
1397
+ [f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
1177
1398
  )
1178
1399
 
1179
1400
  def _update_per_bucket(
@@ -1284,9 +1505,9 @@ class ParameterServer:
1284
1505
  dist.broadcast(buffer_b, src=brank)
1285
1506
  resp = socket.recv()
1286
1507
  if resp != b"":
1287
- exception_obj = pickle.loads(resp)
1508
+ msg = resp.decode("utf-8")
1288
1509
  logger.error(
1289
- f"[rank{self._rank}] receive error response '{type(exception_obj).__name__}: {exception_obj}' from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}"
1510
+ f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
1290
1511
  )
1291
1512
  ret_code.fill_(1)
1292
1513
  dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
@@ -1,4 +1,5 @@
1
1
  import gc
2
+ import traceback
2
3
  from collections.abc import Callable
3
4
  from typing import TypedDict
4
5
 
@@ -63,7 +64,8 @@ def update_weights_from_ipc(
63
64
  assert buffer.dtype == torch.uint8
64
65
  socket.send(b"")
65
66
  except Exception as e:
66
- socket.send_pyobj(e)
67
+ msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
68
+ socket.send_string(msg)
67
69
  socket.recv() # wait for ack
68
70
  raise
69
71
  try:
@@ -83,7 +85,8 @@ def update_weights_from_ipc(
83
85
  except Exception as e: # noqa: BLE001
84
86
  # Send exception back to Parameter Server.
85
87
  # Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
86
- socket.send_pyobj(e)
88
+ msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
89
+ socket.send_string(msg)
87
90
  elif isinstance(
88
91
  payload, Exception
89
92
  ): # error occurred, got force quit signal from Parameter Server
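The two `worker.py` hunks above switch the error channel from a pickled exception object (`send_pyobj`) to a formatted traceback string, and the parameter-server side (the `resp.decode("utf-8")` change earlier in this diff) now treats any non-empty reply as a failure. A minimal sketch of that convention, with hypothetical helper names and sockets:

```python
import traceback

import zmq


def report(socket: zmq.Socket, fn) -> None:
    """Worker side: run fn() and report the outcome; an empty frame means success."""
    try:
        fn()
        socket.send(b"")
    except Exception as e:  # noqa: BLE001
        msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        socket.send_string(msg)  # failure: human-readable traceback, no pickling


def check(socket: zmq.Socket) -> None:
    """Parameter-server side: any non-empty reply carries the remote traceback."""
    resp = socket.recv()
    if resp != b"":
        raise RuntimeError(f"remote worker failed:\n{resp.decode('utf-8')}")
```

Sending plain text instead of a pickled object also means the receiving side no longer has to unpickle an arbitrary payload from a peer.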
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: checkpoint-engine
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
5
5
  Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
6
6
  Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
99
99
  pip install 'checkpoint-engine[p2p]'
100
100
  ```
101
101
 
102
- If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
103
-
104
102
  ## Getting Started
105
103
 
106
- Prepare an H800 or H20 machine with 8 GPUs with latest vLLM. Be sure to include [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in main branch) since checkpoint-engine will use this endpoint to update weights.
104
+ Prepare an H800 or H20 machine with 8 GPUs and vLLM installed. Be sure to include the [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in the main branch), since checkpoint-engine uses this endpoint to update weights. vLLM `v0.10.2` is fully tested and recommended.
107
105
 
108
106
  ```Bash
109
- cd /opt && git clone https://github.com/vllm-project/vllm && cd vllm
107
+ mkdir -p /opt/vLLM && cd /opt/vLLM
110
108
  uv venv --python 3.12 --seed
111
109
  source .venv/bin/activate
112
- VLLM_USE_PRECOMPILED=1 uv pip install --editable .
110
+ uv pip install vllm==0.10.2
113
111
  ```
114
112
 
115
113
  Install checkpoint-engine
@@ -180,6 +178,11 @@ Other unit tests can also be done with pytest. Only test_update.py requires GPUs
180
178
  pytest tests/ -m "not gpu"
181
179
  ```
182
180
 
181
+ ### Environment Variables
182
+ - `PS_MAX_BUCKET_SIZE_GB`: An integer setting the maximum bucket size (in GB) used by checkpoint-engine. Defaults to 8 GB if not set.
183
+ - `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated RDMA device names for P2P transfer. If not set, checkpoint-engine falls back to `NCCL_IB_HCA` to detect RDMA devices.
184
+ - `NCCL_IB_HCA`: Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If this is also not set, all RDMA devices will be used and divided evenly among the ranks.
185
+
183
186
  ## SGLang Integration
184
187
 
185
188
  Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
@@ -23,5 +23,6 @@ figures/overlap-update-and-copy.png
23
23
  figures/pipeline.png
24
24
  patches/vllm_fp8.patch
25
25
  tests/test_assign_receiver_ranks.py
26
+ tests/test_pin_memory.py
26
27
  tests/test_rdma_parser.py
27
28
  tests/test_update.py
@@ -100,8 +100,9 @@ def update_weights(
100
100
  update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
101
101
  uds: str | None = None,
102
102
  ):
103
- ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
104
103
  ps.init_process_group()
104
+ dist.barrier()
105
+ ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
105
106
  check_vllm_ready(endpoint, inference_parallel_size, uds)
106
107
  dist.barrier()
107
108
  with timer("Gather metas"):
@@ -173,7 +174,9 @@ if __name__ == "__main__":
173
174
  args.uds,
174
175
  )
175
176
  else:
176
- if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")):
177
+ if os.path.exists(
178
+ os.path.join(args.checkpoint_path, "model.safetensors.index.json")
179
+ ) and not args.checkpoint_path.startswith("/dev/shm/"): # noqa: S108
177
180
  named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
178
181
  checkpoint_files = []
179
182
  else:
@@ -0,0 +1,77 @@
1
+ import os
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from checkpoint_engine.ps import ParameterServer
7
+
8
+
9
+ def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
10
+ """
11
+ Generate dummy checkpoint data
12
+ """
13
+ named_tensors = {
14
+ "layer1.weight": torch.randn(1024, 1024),
15
+ "layer1.bias": torch.randn(1024),
16
+ "layer2.weight": torch.randn(2048, 1024),
17
+ "layer2.bias": torch.randn(2048),
18
+ }
19
+ return named_tensors
20
+
21
+
22
+ @pytest.mark.gpu
23
+ def test_register_pin_memory():
24
+ os.environ["RANK"] = "0"
25
+ os.environ["WORLD_SIZE"] = "1"
26
+ ps = ParameterServer()
27
+ checkpoint1 = generate_dummy_checkpoint()
28
+ checkpoint_shared1 = generate_dummy_checkpoint()
29
+ checkpoint2 = generate_dummy_checkpoint()
30
+ checkpoint_shared2 = generate_dummy_checkpoint()
31
+ checkpoint_shared3 = generate_dummy_checkpoint()
32
+ checkpoint_shared3["layer3.weight"] = torch.randn(4096, 2048)
33
+ checkpoint_shared3["layer3.bias"] = torch.randn(4096)
34
+ ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
35
+ ps.unregister_checkpoint("test_checkpoint1")
36
+ assert "test_checkpoint1" not in ps._memory_pool
37
+ ps.register_checkpoint(
38
+ "test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True
39
+ )
40
+ ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2)
41
+ assert "test_checkpoint_shared1" not in ps._memory_pool
42
+ assert "__shared_memory_pool__" in ps._memory_pool
43
+ assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
44
+ assert "test_checkpoint2" in ps._memory_pool
45
+ try:
46
+ ps.register_checkpoint(
47
+ "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
48
+ ) # this will fail
49
+ except AssertionError:
50
+ print("Caught expected AssertionError when registering second shared memory pool user")
51
+ assert "test_checkpoint_shared2" not in ps._memory_pool
52
+ assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
53
+ ps.unregister_checkpoint("test_checkpoint_shared1")
54
+ assert ps._current_shared_memory_pool_user == ""
55
+ assert "__shared_memory_pool__" in ps._memory_pool
56
+ ps.register_checkpoint(
57
+ "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
58
+ )
59
+ assert "test_checkpoint_shared2" not in ps._memory_pool
60
+ assert "__shared_memory_pool__" in ps._memory_pool
61
+ assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2"
62
+ ps.unregister_checkpoint("test_checkpoint1") # this will trigger a warning
63
+ assert "test_checkpoint1" not in ps._memory_pool
64
+ ps.unregister_checkpoint("test_checkpoint2")
65
+ assert "test_checkpoint2" not in ps._memory_pool
66
+ ps.unregister_checkpoint("test_checkpoint_shared2", force=True)
67
+ assert ps._current_shared_memory_pool_user == ""
68
+ assert "__shared_memory_pool__" in ps._memory_pool
69
+ ps.register_checkpoint(
70
+ "test_checkpoint_shared3", named_tensors=checkpoint_shared3, use_shared_memory_pool=True
71
+ )
72
+ assert "test_checkpoint_shared3" not in ps._memory_pool
73
+ assert "__shared_memory_pool__" in ps._memory_pool
74
+ assert ps._current_shared_memory_pool_user == "test_checkpoint_shared3"
75
+ ps.unregister_checkpoint("test_checkpoint_shared3")
76
+ assert ps._current_shared_memory_pool_user == ""
77
+ assert "__shared_memory_pool__" in ps._memory_pool
@@ -82,7 +82,7 @@ def checker_proc_with_error(
82
82
  try:
83
83
  trigger_error(socket_paths)
84
84
  except RuntimeError as e:
85
- assert str(e) == "Failed to update weights due to remote errors"
85
+ assert str(e) == "Some workers failed to update weights"
86
86
 
87
87
 
88
88
  def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue):
@@ -96,7 +96,7 @@ def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Ten
96
96
  for name, weight in weights:
97
97
  if name not in named_tensors:
98
98
  continue
99
- assert (weight == named_tensors[name]).all()
99
+ assert (weight == named_tensors[name]).all(), f"Tensor {name} does not match!"
100
100
  names_to_check[name] = True
101
101
 
102
102
  def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str, str]]):
@@ -163,6 +163,67 @@ def run(
163
163
  assert proc.exitcode == 0
164
164
 
165
165
 
166
+ def run_with_files(
167
+ checker_func: callable,
168
+ ):
169
+ rank = int(os.getenv("RANK"))
170
+ ctx = get_context("spawn")
171
+ queue = ctx.Queue()
172
+ _device_uuid = _get_physical_gpu_id(device_manager, rank)
173
+ ps = ParameterServer(auto_pg=True)
174
+ _device_uuid = _get_physical_gpu_id(ps.device_manager, rank)
175
+ named_tensors = dict(gen_test_tensors(rank))
176
+
177
+ # Save the first half of the tensors to /dev/shm/ as a .safetensors file
178
+ # Save the middle third of the tensors to /tmp (disk) as a .safetensors file
179
+ # Keep the second half of the tensors in memory (the splits overlap, so some tensors are registered from both sources)
180
+ import safetensors.torch
181
+
182
+ files = []
183
+ dev_shm_dir = "/dev/shm/checkpoint_engine_tests" # noqa: S108
184
+ disk_dir = "/tmp/checkpoint_engine_tests" # noqa: S108
185
+ os.makedirs(dev_shm_dir, exist_ok=True)
186
+ os.makedirs(disk_dir, exist_ok=True)
187
+ tensors_items = list(named_tensors.items())
188
+ tensors_in_dev_shm = named_tensors
189
+ tensors_in_dev_shm = dict(tensors_items[: len(tensors_items) // 2])
190
+ tensors_in_disk = dict(tensors_items[len(tensors_items) // 3 : 2 * len(tensors_items) // 3])
191
+ tensors_in_memory = dict(tensors_items[1 * len(tensors_items) // 2 :])
192
+ disk_files = [
193
+ os.path.join(disk_dir, f"rank{_rank}_checkpoint.safetensors")
194
+ for _rank in range(get_world_size())
195
+ ]
196
+ safetensors.torch.save_file(tensors_in_disk, disk_files[rank])
197
+ time.sleep(1)
198
+ files.append(disk_files[rank])
199
+ dev_shm_files = [
200
+ os.path.join(dev_shm_dir, f"rank{rank}_checkpoint.safetensors")
201
+ for _ in range(get_world_size())
202
+ ]
203
+ safetensors.torch.save_file(tensors_in_dev_shm, dev_shm_files[rank])
204
+ time.sleep(1)
205
+ files.append(dev_shm_files[rank])
206
+
207
+ checkpoint_name = "test_with_files"
208
+ proc = ctx.Process(target=checker_func, args=(rank, _device_uuid, named_tensors, queue))
209
+ proc.start()
210
+ ps.register_checkpoint(checkpoint_name, named_tensors=tensors_in_memory, files=files)
211
+ ps.gather_metas(checkpoint_name)
212
+ ps.update(checkpoint_name, queue.put, ranks=[])
213
+ # sleep 3s to wait process group is destroyed
214
+ time.sleep(3)
215
+ ps.unregister_checkpoint(checkpoint_name)
216
+ queue.put(None)
217
+ proc.join()
218
+ if rank == 0:
219
+ import shutil
220
+
221
+ # This test must run with use_inplace_pin_memory=False; otherwise the files in /dev/shm/ would already have been deleted.
222
+ shutil.rmtree(dev_shm_dir)
223
+ shutil.rmtree(disk_dir)
224
+ assert proc.exitcode == 0
225
+
226
+
166
227
  @pytest.mark.gpu
167
228
  @pytest.mark.parametrize(
168
229
  "test_name,rank_list",
@@ -211,6 +272,37 @@ def test_update(test_name: str, rank_list: list[list[int]] | None):
211
272
  assert result.returncode == 0
212
273
 
213
274
 
275
+ @pytest.mark.gpu
276
+ def test_update_with_files(test_name: str = "test_with_files"):
277
+ world_size = device_manager.device_module.device_count()
278
+ assert world_size >= 2, "This test requires at least 2 GPUs."
279
+ master_addr = "localhost"
280
+ master_port = 25400
281
+ cmd = [
282
+ "torchrun",
283
+ "--nproc_per_node",
284
+ str(world_size),
285
+ "--master_addr",
286
+ master_addr,
287
+ "--master_port",
288
+ str(master_port),
289
+ __file__,
290
+ test_name,
291
+ "[]",
292
+ ]
293
+
294
+ result = subprocess.run( # noqa: S603
295
+ cmd,
296
+ capture_output=False,
297
+ text=True,
298
+ cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
299
+ shell=False,
300
+ check=False,
301
+ )
302
+
303
+ assert result.returncode == 0
304
+
305
+
214
306
  if __name__ == "__main__":
215
307
  run_with_pytest = "PYTEST_CURRENT_TEST" in os.environ
216
308
  if not run_with_pytest:
@@ -230,5 +322,7 @@ if __name__ == "__main__":
230
322
  expected_exception=RuntimeError,
231
323
  exception_msg="Failed to update weights due to remote errors",
232
324
  )
325
+ elif test_type == "test_with_files":
326
+ run_with_files(checker_proc)
233
327
  else:
234
328
  raise ValueError(f"Unknown TEST_TYPE: {test_type}")