checkpoint-engine 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/device_utils.py +86 -0
- checkpoint_engine/ps.py +383 -139
- checkpoint_engine/worker.py +89 -30
- {checkpoint_engine-0.2.0.dist-info → checkpoint_engine-0.2.2.dist-info}/METADATA +62 -9
- checkpoint_engine-0.2.2.dist-info/RECORD +10 -0
- checkpoint_engine-0.2.0.dist-info/RECORD +0 -9
- {checkpoint_engine-0.2.0.dist-info → checkpoint_engine-0.2.2.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.2.0.dist-info → checkpoint_engine-0.2.2.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.2.0.dist-info → checkpoint_engine-0.2.2.dist-info}/top_level.txt +0 -0
checkpoint_engine/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.2.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 2,
|
|
31
|
+
__version__ = version = '0.2.2'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 2, 2)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import socket
|
|
4
|
+
import subprocess
|
|
5
|
+
from functools import lru_cache
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from loguru import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@lru_cache(maxsize=1)
|
|
12
|
+
def get_ip() -> str:
|
|
13
|
+
try:
|
|
14
|
+
# try to get ip from network interface
|
|
15
|
+
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
|
16
|
+
s.connect(("8.8.8.8", 80))
|
|
17
|
+
return s.getsockname()[0]
|
|
18
|
+
except Exception as e: # noqa: BLE001
|
|
19
|
+
# fallback to get ip from hostname
|
|
20
|
+
logger.warning(
|
|
21
|
+
f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
|
|
22
|
+
)
|
|
23
|
+
return socket.gethostbyname(socket.gethostname())
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def npu_generate_uuid() -> str:
|
|
27
|
+
str_pid = str(os.getpid())
|
|
28
|
+
npu_num = 8
|
|
29
|
+
try:
|
|
30
|
+
for npu_id in range(npu_num):
|
|
31
|
+
cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
|
|
32
|
+
result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603
|
|
33
|
+
str_result = str(result.stdout)
|
|
34
|
+
if str_pid in str_result:
|
|
35
|
+
# In A3 server, one NPU has two chips.
|
|
36
|
+
match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
|
|
37
|
+
chip_count = int(match_chip_count.group(1))
|
|
38
|
+
search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
|
|
39
|
+
match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
|
|
40
|
+
chip_id = int(match_chip_id.group(1))
|
|
41
|
+
return f"{get_ip()}-{npu_id * chip_count + chip_id}"
|
|
42
|
+
raise ValueError("The current process is not running on the npu device")
|
|
43
|
+
except subprocess.CalledProcessError as e:
|
|
44
|
+
raise ValueError("The current process is not running on the npu device") from e
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class DeviceManager:
|
|
48
|
+
def __init__(self):
|
|
49
|
+
self.device_type = self._detect_device_type()
|
|
50
|
+
self._setup_device_module()
|
|
51
|
+
|
|
52
|
+
def _is_torch_npu_available(self) -> bool:
|
|
53
|
+
try:
|
|
54
|
+
if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
|
|
55
|
+
return torch.npu.is_available()
|
|
56
|
+
else:
|
|
57
|
+
return False
|
|
58
|
+
except ImportError:
|
|
59
|
+
return False
|
|
60
|
+
|
|
61
|
+
def _detect_device_type(self) -> str:
|
|
62
|
+
if self._is_torch_npu_available():
|
|
63
|
+
return "npu"
|
|
64
|
+
elif torch.cuda.is_available():
|
|
65
|
+
return "cuda"
|
|
66
|
+
else:
|
|
67
|
+
raise TypeError("The current device type is not supported")
|
|
68
|
+
|
|
69
|
+
def _setup_device_module(self):
|
|
70
|
+
if self.device_type == "npu":
|
|
71
|
+
import torch_npu
|
|
72
|
+
|
|
73
|
+
self.device_module = torch_npu.npu
|
|
74
|
+
elif self.device_type == "cuda":
|
|
75
|
+
self.device_module = torch.cuda
|
|
76
|
+
else:
|
|
77
|
+
raise TypeError("The current device type is not supported")
|
|
78
|
+
|
|
79
|
+
@property
|
|
80
|
+
def backend(self) -> str:
|
|
81
|
+
if self.device_type == "npu":
|
|
82
|
+
return "hccl"
|
|
83
|
+
elif self.device_type == "cuda":
|
|
84
|
+
return "nccl"
|
|
85
|
+
else:
|
|
86
|
+
raise TypeError("The current device type is not supported")
|
checkpoint_engine/ps.py
CHANGED
|
@@ -1,16 +1,15 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import concurrent.futures
|
|
3
3
|
import ctypes
|
|
4
|
+
import json
|
|
4
5
|
import os
|
|
5
6
|
import pickle
|
|
6
7
|
import random
|
|
7
|
-
import socket
|
|
8
8
|
import threading
|
|
9
9
|
import time
|
|
10
10
|
from collections import defaultdict
|
|
11
11
|
from collections.abc import Callable
|
|
12
12
|
from datetime import timedelta
|
|
13
|
-
from functools import lru_cache
|
|
14
13
|
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
|
|
15
14
|
|
|
16
15
|
import httpx
|
|
@@ -20,9 +19,11 @@ import torch.distributed as dist
|
|
|
20
19
|
import zmq
|
|
21
20
|
from loguru import logger
|
|
22
21
|
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
|
|
23
|
-
from safetensors.torch import safe_open
|
|
22
|
+
from safetensors.torch import _getdtype, safe_open
|
|
24
23
|
from torch.multiprocessing.reductions import reduce_tensor
|
|
25
24
|
|
|
25
|
+
from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid
|
|
26
|
+
|
|
26
27
|
|
|
27
28
|
if TYPE_CHECKING:
|
|
28
29
|
from typing import TypeVar
|
|
@@ -92,6 +93,7 @@ class ParameterMeta(BaseModel):
|
|
|
92
93
|
name: str
|
|
93
94
|
dtype: _TorchDtype
|
|
94
95
|
shape: _TorchSize
|
|
96
|
+
aligned_size: int
|
|
95
97
|
|
|
96
98
|
|
|
97
99
|
class BucketRange(NamedTuple):
|
|
@@ -140,7 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
|
|
|
140
142
|
def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
|
|
141
143
|
ret = []
|
|
142
144
|
for meta in metas:
|
|
143
|
-
size =
|
|
145
|
+
size = meta.aligned_size
|
|
144
146
|
ret.append(
|
|
145
147
|
{
|
|
146
148
|
"name": meta.name,
|
|
@@ -254,28 +256,16 @@ def _concat_tp_weights(
|
|
|
254
256
|
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)
|
|
255
257
|
|
|
256
258
|
|
|
257
|
-
def _get_physical_gpu_id(device_index: int | None = None) -> str:
|
|
259
|
+
def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
|
|
258
260
|
try:
|
|
259
|
-
|
|
261
|
+
if device_manager.device_type == "npu":
|
|
262
|
+
return f"NPU-{npu_generate_uuid()}"
|
|
263
|
+
else:
|
|
264
|
+
return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
|
|
260
265
|
except AssertionError as e:
|
|
261
266
|
raise ValueError(f"fail to get physical gpu id {device_index}") from e
|
|
262
267
|
|
|
263
268
|
|
|
264
|
-
@lru_cache(maxsize=1)
|
|
265
|
-
def _get_ip() -> str:
|
|
266
|
-
try:
|
|
267
|
-
# try to get ip from network interface
|
|
268
|
-
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
|
269
|
-
s.connect(("8.8.8.8", 80))
|
|
270
|
-
return s.getsockname()[0]
|
|
271
|
-
except Exception as e: # noqa: BLE001
|
|
272
|
-
# fallback to get ip from hostname
|
|
273
|
-
logger.warning(
|
|
274
|
-
f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
|
|
275
|
-
)
|
|
276
|
-
return socket.gethostbyname(socket.gethostname())
|
|
277
|
-
|
|
278
|
-
|
|
279
269
|
def _ibv_get_device_list() -> list[str]:
|
|
280
270
|
lib = ctypes.CDLL("libibverbs.so.1")
|
|
281
271
|
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
|
|
@@ -317,13 +307,21 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
|
|
|
317
307
|
"""
|
|
318
308
|
if not devices:
|
|
319
309
|
raise RuntimeError("no rdma devices found")
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
310
|
+
try:
|
|
311
|
+
assert len(devices) <= gpu_count, (
|
|
312
|
+
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
|
|
313
|
+
)
|
|
314
|
+
assert gpu_count % len(devices) == 0, (
|
|
315
|
+
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
|
|
316
|
+
)
|
|
317
|
+
return devices[local_rank // (gpu_count // len(devices))]
|
|
318
|
+
except AssertionError:
|
|
319
|
+
logger.error(
|
|
320
|
+
"Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
|
|
321
|
+
"The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
|
|
322
|
+
"The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
|
|
323
|
+
)
|
|
324
|
+
raise
|
|
327
325
|
|
|
328
326
|
|
|
329
327
|
def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
|
|
@@ -426,6 +424,7 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
|
|
|
426
424
|
name=parameter_name,
|
|
427
425
|
shape=meta["shape"],
|
|
428
426
|
dtype=meta["dtype"],
|
|
427
|
+
aligned_size=_align_size(meta["dtype"], meta["shape"]),
|
|
429
428
|
)
|
|
430
429
|
tp_meta = tp_metas[parameter_name]
|
|
431
430
|
if tp_meta.concat_dim != -1:
|
|
@@ -435,7 +434,10 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
|
|
|
435
434
|
shape = list(parameter_metas[name].shape)
|
|
436
435
|
shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
|
|
437
436
|
parameter_metas[name] = ParameterMeta(
|
|
438
|
-
name=name,
|
|
437
|
+
name=name,
|
|
438
|
+
shape=torch.Size(shape),
|
|
439
|
+
dtype=parameter_metas[name].dtype,
|
|
440
|
+
aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
|
|
439
441
|
)
|
|
440
442
|
weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
|
|
441
443
|
# TODO: here concat is serial, which may be slow
|
|
@@ -453,17 +455,85 @@ def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
|
|
|
453
455
|
return parameters
|
|
454
456
|
|
|
455
457
|
|
|
456
|
-
def
|
|
457
|
-
|
|
458
|
+
def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
|
|
459
|
+
def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
|
|
460
|
+
"""
|
|
461
|
+
safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
|
|
462
|
+
We load the safetensors file as bytes, then parse the header manually to get parameter metas.
|
|
463
|
+
The actual tensor data is in the remaining bytes and is naturally aligned.
|
|
464
|
+
We pin the remaining bytes as the buffer, making pinning faster.
|
|
465
|
+
"""
|
|
466
|
+
|
|
467
|
+
def _pin(t: torch.Tensor):
|
|
468
|
+
"""
|
|
469
|
+
Pin the memory of tensor in-place.
|
|
470
|
+
See: https://github.com/pytorch/pytorch/issues/32167
|
|
471
|
+
"""
|
|
472
|
+
cudart = torch.cuda.cudart()
|
|
473
|
+
r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
|
|
474
|
+
assert r == 0, f"pin memory error, error code: {r}"
|
|
475
|
+
|
|
476
|
+
# TODO: should only support /dev/shm? but we found files in disk also work?
|
|
477
|
+
size = os.stat(file_path).st_size
|
|
478
|
+
flag_size = 8
|
|
479
|
+
t = torch.from_file(file_path, True, size, dtype=torch.uint8)
|
|
480
|
+
assert t.nbytes > flag_size, (
|
|
481
|
+
f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
|
|
482
|
+
)
|
|
483
|
+
start_pos = (
|
|
484
|
+
int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
|
|
485
|
+
+ flag_size
|
|
486
|
+
)
|
|
487
|
+
header_tensor = t[flag_size:start_pos]
|
|
488
|
+
header = json.loads(header_tensor.numpy().tobytes())
|
|
489
|
+
if "__metadata__" in header:
|
|
490
|
+
header.pop("__metadata__")
|
|
491
|
+
|
|
492
|
+
metas: list[ParameterMeta] = []
|
|
493
|
+
offset = 0
|
|
494
|
+
try:
|
|
495
|
+
for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
|
|
496
|
+
start, end = meta["data_offsets"]
|
|
497
|
+
# safetensors format ensures offsets are aligned
|
|
498
|
+
assert offset == start, f"offset {offset} should be equal to start {start}"
|
|
499
|
+
metas.append(
|
|
500
|
+
ParameterMeta(
|
|
501
|
+
name=name,
|
|
502
|
+
dtype=_getdtype(meta["dtype"]),
|
|
503
|
+
shape=torch.Size(meta["shape"]),
|
|
504
|
+
aligned_size=end - start,
|
|
505
|
+
)
|
|
506
|
+
)
|
|
507
|
+
offset = end
|
|
508
|
+
except Exception as e:
|
|
509
|
+
logger.error(f"fail to parse safetensors header from {file_path}: {e}")
|
|
510
|
+
raise
|
|
511
|
+
|
|
512
|
+
buffer = t[start_pos:]
|
|
513
|
+
assert offset == buffer.nbytes, (
|
|
514
|
+
f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
|
|
515
|
+
)
|
|
516
|
+
# Remove the file after successfully loading. This will avoid doubling the memory usage.
|
|
517
|
+
# We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
|
|
518
|
+
os.remove(file_path)
|
|
519
|
+
_pin(buffer)
|
|
520
|
+
logger.info(
|
|
521
|
+
f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
|
|
522
|
+
)
|
|
523
|
+
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
|
|
524
|
+
|
|
525
|
+
memory_buffers: list[MemoryBuffer] = []
|
|
526
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
|
|
527
|
+
memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
|
|
528
|
+
return memory_buffers
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
def _normal_pin_memory(
|
|
458
532
|
files: list[str],
|
|
459
533
|
named_tensors: dict[str, torch.Tensor],
|
|
460
534
|
rank: int | None = None,
|
|
535
|
+
shared_pin_memory: list[MemoryBuffer] | None = None,
|
|
461
536
|
) -> list[MemoryBuffer]:
|
|
462
|
-
logger.info(
|
|
463
|
-
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
|
|
464
|
-
)
|
|
465
|
-
if not files and not named_tensors:
|
|
466
|
-
return []
|
|
467
537
|
parameters = _load_checkpoint(files)
|
|
468
538
|
if named_tensors:
|
|
469
539
|
parameters.update(named_tensors)
|
|
@@ -473,13 +543,16 @@ def _register_checkpoint(
|
|
|
473
543
|
size: int
|
|
474
544
|
metas: list[ParameterMeta]
|
|
475
545
|
|
|
476
|
-
buckets: list[MemoryBucket] = [
|
|
546
|
+
buckets: list[MemoryBucket] = []
|
|
547
|
+
buckets.append(MemoryBucket(size=0, metas=[]))
|
|
477
548
|
for name, tensor in sorted(parameters.items()):
|
|
478
549
|
size = _align_size(tensor.dtype, tensor.shape)
|
|
479
550
|
if buckets[-1].size + size > bucket_size:
|
|
480
551
|
assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
|
|
481
552
|
buckets.append(MemoryBucket(size=0, metas=[]))
|
|
482
|
-
buckets[-1].metas.append(
|
|
553
|
+
buckets[-1].metas.append(
|
|
554
|
+
ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
|
|
555
|
+
)
|
|
483
556
|
buckets[-1].size += size
|
|
484
557
|
|
|
485
558
|
memory_buffers = [
|
|
@@ -487,16 +560,34 @@ def _register_checkpoint(
|
|
|
487
560
|
for bucket in buckets
|
|
488
561
|
]
|
|
489
562
|
|
|
490
|
-
def register_pin_memory(
|
|
491
|
-
|
|
492
|
-
|
|
563
|
+
def register_pin_memory(
|
|
564
|
+
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
|
|
565
|
+
) -> tuple[int, torch.Tensor]:
|
|
566
|
+
if shared_pin_memory:
|
|
567
|
+
# If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
|
|
568
|
+
# Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
|
|
569
|
+
assert idx < len(shared_pin_memory), (
|
|
570
|
+
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
|
|
571
|
+
)
|
|
572
|
+
assert shared_pin_memory[idx].size == size, (
|
|
573
|
+
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
|
|
574
|
+
)
|
|
575
|
+
return idx, shared_pin_memory[idx].buffer
|
|
576
|
+
else:
|
|
577
|
+
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
|
|
578
|
+
return idx, buffer
|
|
493
579
|
|
|
494
580
|
def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
|
|
495
581
|
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
|
|
496
582
|
|
|
497
583
|
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
|
|
498
584
|
futures = [
|
|
499
|
-
executor.submit(
|
|
585
|
+
executor.submit(
|
|
586
|
+
register_pin_memory,
|
|
587
|
+
idx,
|
|
588
|
+
bucket.size,
|
|
589
|
+
shared_pin_memory,
|
|
590
|
+
)
|
|
500
591
|
for idx, bucket in enumerate(buckets)
|
|
501
592
|
]
|
|
502
593
|
new_futures = []
|
|
@@ -522,6 +613,39 @@ def _register_checkpoint(
|
|
|
522
613
|
offset += size
|
|
523
614
|
for future in concurrent.futures.as_completed(new_futures):
|
|
524
615
|
future.result()
|
|
616
|
+
return memory_buffers
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def _register_checkpoint(
|
|
620
|
+
*,
|
|
621
|
+
files: list[str],
|
|
622
|
+
named_tensors: dict[str, torch.Tensor],
|
|
623
|
+
rank: int | None = None,
|
|
624
|
+
shared_pin_memory: list[MemoryBuffer] | None = None,
|
|
625
|
+
) -> list[MemoryBuffer]:
|
|
626
|
+
logger.info(
|
|
627
|
+
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
|
|
628
|
+
)
|
|
629
|
+
if not files and not named_tensors:
|
|
630
|
+
return []
|
|
631
|
+
memory_buffers: list[MemoryBuffer] = []
|
|
632
|
+
files_to_inplace_pin = [
|
|
633
|
+
file
|
|
634
|
+
for file in files
|
|
635
|
+
if file.startswith("/dev/shm/") and file.endswith(".safetensors") # noqa: S108
|
|
636
|
+
]
|
|
637
|
+
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
|
|
638
|
+
if files_to_normal_pin or named_tensors:
|
|
639
|
+
memory_buffers.extend(
|
|
640
|
+
_normal_pin_memory(
|
|
641
|
+
files=files_to_normal_pin,
|
|
642
|
+
named_tensors=named_tensors,
|
|
643
|
+
rank=rank,
|
|
644
|
+
shared_pin_memory=shared_pin_memory,
|
|
645
|
+
)
|
|
646
|
+
)
|
|
647
|
+
if files_to_inplace_pin:
|
|
648
|
+
memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
|
|
525
649
|
return memory_buffers
|
|
526
650
|
|
|
527
651
|
|
|
@@ -570,7 +694,7 @@ def _gen_h2d_buckets(
|
|
|
570
694
|
for idx, metas in enumerate(items.memory_buffer_metas_list):
|
|
571
695
|
start_offset, offset = 0, 0
|
|
572
696
|
for meta in metas.metas:
|
|
573
|
-
s =
|
|
697
|
+
s = meta.aligned_size
|
|
574
698
|
if buckets[-1][1].size + s > bucket_size:
|
|
575
699
|
if offset - start_offset > 0:
|
|
576
700
|
buckets[-1][1].ranges.append(
|
|
@@ -677,20 +801,29 @@ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, i
|
|
|
677
801
|
|
|
678
802
|
|
|
679
803
|
class P2PStore:
|
|
680
|
-
def __init__(self):
|
|
804
|
+
def __init__(self, device_manager: DeviceManager):
|
|
681
805
|
from mooncake.engine import TransferEngine
|
|
682
806
|
|
|
683
807
|
self.rank = int(os.getenv("RANK"))
|
|
684
|
-
gpu_count =
|
|
808
|
+
gpu_count = device_manager.device_module.device_count()
|
|
685
809
|
local_rank = self.rank % gpu_count
|
|
686
|
-
|
|
687
|
-
|
|
810
|
+
device_type = device_manager.device_type
|
|
811
|
+
if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
|
|
812
|
+
self.device = ""
|
|
813
|
+
else:
|
|
814
|
+
self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
|
|
815
|
+
self.ip = get_ip()
|
|
688
816
|
|
|
689
817
|
# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
|
|
690
818
|
retry_count = 8
|
|
691
819
|
for i in range(retry_count):
|
|
692
820
|
self.engine = TransferEngine()
|
|
693
|
-
ret = self.engine.initialize(
|
|
821
|
+
ret = self.engine.initialize(
|
|
822
|
+
self.ip,
|
|
823
|
+
"P2PHANDSHAKE",
|
|
824
|
+
"ascend_direct" if device_type == "npu" else "rdma",
|
|
825
|
+
self.device,
|
|
826
|
+
)
|
|
694
827
|
if ret == 0:
|
|
695
828
|
break
|
|
696
829
|
# sleep 0.5 ~ 2.0s, to avoid port conflicts when two processes retry at the same time
|
|
@@ -742,6 +875,8 @@ class P2PStore:
|
|
|
742
875
|
|
|
743
876
|
|
|
744
877
|
class ParameterServer:
|
|
878
|
+
shared_memory_pool_name = "__shared_memory_pool__"
|
|
879
|
+
|
|
745
880
|
def __init__(
|
|
746
881
|
self,
|
|
747
882
|
*,
|
|
@@ -757,11 +892,12 @@ class ParameterServer:
|
|
|
757
892
|
Args:
|
|
758
893
|
auto_pg: Whether to automatically initialize the process group.
|
|
759
894
|
Notice that if auto_pg is True, will destroy the process group after update.
|
|
760
|
-
mem_fraction: The proportion (as a fraction) of the current free
|
|
895
|
+
mem_fraction: The proportion (as a fraction) of the current free device memory for allocation.
|
|
761
896
|
"""
|
|
762
897
|
self._rank = rank or int(os.environ.get("RANK", None))
|
|
763
898
|
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
|
|
764
|
-
self.
|
|
899
|
+
self.device_manager = DeviceManager()
|
|
900
|
+
self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
|
|
765
901
|
self._local_rank = self._rank % self._gpu_count
|
|
766
902
|
self._auto_pg = auto_pg
|
|
767
903
|
self._all_hosts = []
|
|
@@ -775,7 +911,7 @@ class ParameterServer:
|
|
|
775
911
|
assert (
|
|
776
912
|
self._gpu_count is not None
|
|
777
913
|
and self._gpu_count > 0
|
|
778
|
-
and self._gpu_count <=
|
|
914
|
+
and self._gpu_count <= self.device_manager.device_module.device_count()
|
|
779
915
|
), self._gpu_count
|
|
780
916
|
assert (
|
|
781
917
|
self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
|
|
@@ -784,20 +920,35 @@ class ParameterServer:
|
|
|
784
920
|
self._zmq_ctx = zmq.Context()
|
|
785
921
|
self._zmq_addr_counter = 0
|
|
786
922
|
|
|
923
|
+
# stores the name of the checkpoint currently using the shared memory pool, or empty string if none
|
|
924
|
+
self._current_shared_memory_pool_user: str = ""
|
|
787
925
|
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
|
|
926
|
+
self._memory_pool[self.shared_memory_pool_name] = []
|
|
788
927
|
# dict key is owner_rank, value is a bucket metas list in owner_rank
|
|
789
928
|
self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
|
|
929
|
+
# NPU transfer engine initialization requires prior set_device.
|
|
930
|
+
device_index = self._local_rank
|
|
931
|
+
self.device_manager.device_module.set_device(device_index)
|
|
790
932
|
try:
|
|
791
|
-
self._p2p_store = P2PStore()
|
|
933
|
+
self._p2p_store = P2PStore(self.device_manager)
|
|
792
934
|
except ImportError as e:
|
|
793
935
|
logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
|
|
794
936
|
self._p2p_store = None
|
|
795
937
|
|
|
796
|
-
|
|
797
|
-
torch.cuda.set_device(device_index)
|
|
798
|
-
self._device_uuid = _get_physical_gpu_id(device_index)
|
|
938
|
+
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
|
|
799
939
|
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
|
|
800
940
|
|
|
941
|
+
def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
|
|
942
|
+
if checkpoint_name == self._current_shared_memory_pool_user:
|
|
943
|
+
assert self._memory_pool[self.shared_memory_pool_name], (
|
|
944
|
+
f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
|
|
945
|
+
)
|
|
946
|
+
return self._memory_pool[self.shared_memory_pool_name]
|
|
947
|
+
elif checkpoint_name in self._memory_pool:
|
|
948
|
+
return self._memory_pool[checkpoint_name]
|
|
949
|
+
else:
|
|
950
|
+
raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
|
|
951
|
+
|
|
801
952
|
def _logger_rank0(self, msg: str):
|
|
802
953
|
if self._local_rank == 0:
|
|
803
954
|
logger.info(msg)
|
|
@@ -821,46 +972,97 @@ class ParameterServer:
|
|
|
821
972
|
*,
|
|
822
973
|
files: list[str] | None = None,
|
|
823
974
|
named_tensors: dict[str, torch.Tensor] | None = None,
|
|
975
|
+
use_shared_memory_pool: bool = False,
|
|
824
976
|
) -> None:
|
|
825
977
|
"""
|
|
826
978
|
Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
|
|
979
|
+
Warning: .safetensors files in /dev/shm/ will be pinned in-place, and the files will be REMOVED after pinning.
|
|
980
|
+
Please make sure to copy the files to disks if you need to keep them.
|
|
827
981
|
|
|
828
982
|
Args:
|
|
829
983
|
checkpoint_name: The name of the checkpoint.
|
|
830
984
|
files: The safetensors files to register.
|
|
831
985
|
named_tensors: The named tensors to register.
|
|
986
|
+
use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
|
|
987
|
+
Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
|
|
988
|
+
cannot accommodate checkpoints with different memory requirements.
|
|
989
|
+
To free the actual memory of the shared pool or to modify its shape,
|
|
990
|
+
please unregister the current user of the shared memory pool using `unregister_checkpoint` with `force=True`.
|
|
832
991
|
"""
|
|
833
992
|
try:
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
993
|
+
if use_shared_memory_pool:
|
|
994
|
+
logger.info(
|
|
995
|
+
f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
|
|
996
|
+
)
|
|
997
|
+
assert self._current_shared_memory_pool_user == "", (
|
|
998
|
+
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
|
|
999
|
+
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
|
|
1000
|
+
f"This registration may cause unexpected conflicts."
|
|
1001
|
+
)
|
|
1002
|
+
# Since we set the uninitialized shared memory pool to empty list,
|
|
1003
|
+
# we can check whether this is the first time to use shared memory pool
|
|
1004
|
+
_is_first_time = not self._memory_pool[self.shared_memory_pool_name]
|
|
1005
|
+
self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
|
|
1006
|
+
files=files or [],
|
|
1007
|
+
named_tensors=named_tensors or {},
|
|
1008
|
+
rank=self._rank,
|
|
1009
|
+
shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
|
|
1010
|
+
)
|
|
1011
|
+
self._current_shared_memory_pool_user = checkpoint_name
|
|
1012
|
+
if self._p2p_store is not None and _is_first_time:
|
|
1013
|
+
self._register_parameters_to_p2p_store(checkpoint_name)
|
|
1014
|
+
else:
|
|
1015
|
+
assert checkpoint_name not in self._memory_pool, (
|
|
1016
|
+
f"checkpoint {checkpoint_name} already registered"
|
|
1017
|
+
)
|
|
1018
|
+
self._memory_pool[checkpoint_name] = _register_checkpoint(
|
|
1019
|
+
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
|
|
1020
|
+
)
|
|
1021
|
+
if self._p2p_store is not None:
|
|
1022
|
+
self._register_parameters_to_p2p_store(checkpoint_name)
|
|
842
1023
|
except Exception:
|
|
843
1024
|
logger.exception(
|
|
844
1025
|
f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
|
|
845
1026
|
)
|
|
846
|
-
if self._p2p_store is not None:
|
|
1027
|
+
if self._p2p_store is not None and not use_shared_memory_pool:
|
|
847
1028
|
self._unregister_parameters_from_p2p_store(checkpoint_name)
|
|
848
1029
|
self.unregister_checkpoint(checkpoint_name)
|
|
849
1030
|
raise
|
|
850
1031
|
|
|
851
|
-
def unregister_checkpoint(self, checkpoint_name: str):
|
|
1032
|
+
def unregister_checkpoint(self, checkpoint_name: str, force: bool = False) -> None:
|
|
852
1033
|
"""
|
|
853
1034
|
Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
|
|
854
1035
|
from p2p store if p2p store is initialized.
|
|
1036
|
+
Args:
|
|
1037
|
+
checkpoint_name: The name of the checkpoint.
|
|
1038
|
+
force: This flag is designed for shared memory pool user. If True, the memory for shared memory pool itself will be freed.
|
|
1039
|
+
If False, only the checkpoint name will be unregistered, and the shared memory pool will be kept for future use.
|
|
855
1040
|
"""
|
|
856
|
-
if
|
|
1041
|
+
if (
|
|
1042
|
+
checkpoint_name not in self._memory_pool
|
|
1043
|
+
and checkpoint_name != self._current_shared_memory_pool_user
|
|
1044
|
+
):
|
|
1045
|
+
logger.warning(
|
|
1046
|
+
f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
|
|
1047
|
+
)
|
|
857
1048
|
return
|
|
1049
|
+
|
|
1050
|
+
if checkpoint_name == self._current_shared_memory_pool_user and not force:
|
|
1051
|
+
self._current_shared_memory_pool_user = ""
|
|
1052
|
+
return
|
|
1053
|
+
|
|
858
1054
|
if self._p2p_store is not None:
|
|
859
1055
|
num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
|
|
860
1056
|
logger.info(
|
|
861
1057
|
f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
|
|
862
1058
|
)
|
|
863
|
-
|
|
1059
|
+
|
|
1060
|
+
if checkpoint_name == self._current_shared_memory_pool_user:
|
|
1061
|
+
self._current_shared_memory_pool_user = ""
|
|
1062
|
+
del self._memory_pool[self.shared_memory_pool_name]
|
|
1063
|
+
self._memory_pool[self.shared_memory_pool_name] = []
|
|
1064
|
+
else:
|
|
1065
|
+
del self._memory_pool[checkpoint_name]
|
|
864
1066
|
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
|
|
865
1067
|
# this works by using torch>=2.5.0
|
|
866
1068
|
torch._C._host_emptyCache()
|
|
@@ -875,6 +1077,10 @@ class ParameterServer:
|
|
|
875
1077
|
self.init_process_group()
|
|
876
1078
|
assert dist.is_initialized(), "process group is not initialized"
|
|
877
1079
|
metas_lst: list[DataToGather | None] = [None for _ in range(self._world_size)] # type: ignore
|
|
1080
|
+
try:
|
|
1081
|
+
memory_pool = self._get_memory_pool(checkpoint_name)
|
|
1082
|
+
except RuntimeError:
|
|
1083
|
+
memory_pool = []
|
|
878
1084
|
metas = DataToGather(
|
|
879
1085
|
memory_buffer_metas_list=[
|
|
880
1086
|
MemoryBufferMetas(
|
|
@@ -882,16 +1088,18 @@ class ParameterServer:
|
|
|
882
1088
|
ptr=x.buffer.data_ptr(),
|
|
883
1089
|
size=x.size,
|
|
884
1090
|
)
|
|
885
|
-
for x in
|
|
1091
|
+
for x in memory_pool
|
|
886
1092
|
],
|
|
887
1093
|
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
|
|
888
|
-
host_ip=
|
|
1094
|
+
host_ip=get_ip(),
|
|
889
1095
|
device_uuid=self._device_uuid,
|
|
890
1096
|
rdma_device=self._rdma_device or "",
|
|
891
1097
|
)
|
|
892
1098
|
|
|
893
1099
|
dist.all_gather_object(metas_lst, metas)
|
|
894
1100
|
|
|
1101
|
+
self._current_global_parameter_metas = {}
|
|
1102
|
+
|
|
895
1103
|
num_parameters = 0
|
|
896
1104
|
all_hosts: list[str] = []
|
|
897
1105
|
global_device_uuids: list[str] = []
|
|
@@ -948,7 +1156,7 @@ class ParameterServer:
|
|
|
948
1156
|
is_master=self._rank == 0,
|
|
949
1157
|
)
|
|
950
1158
|
dist.init_process_group(
|
|
951
|
-
backend=
|
|
1159
|
+
backend=self.device_manager.backend,
|
|
952
1160
|
world_size=self._world_size,
|
|
953
1161
|
rank=self._rank,
|
|
954
1162
|
timeout=timeout,
|
|
@@ -991,21 +1199,22 @@ class ParameterServer:
|
|
|
991
1199
|
if self._rank not in ranks:
|
|
992
1200
|
return
|
|
993
1201
|
self._update_per_bucket(checkpoint_name, req_func, ranks)
|
|
994
|
-
if self._auto_pg:
|
|
995
|
-
dist.destroy_process_group()
|
|
996
|
-
|
|
997
|
-
torch.cuda.empty_cache()
|
|
998
1202
|
|
|
999
|
-
logger.info(
|
|
1000
|
-
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
|
|
1001
|
-
f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
|
|
1002
|
-
f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
|
|
1003
|
-
)
|
|
1004
1203
|
except Exception as e:
|
|
1005
1204
|
logger.exception(
|
|
1006
1205
|
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
|
|
1007
1206
|
)
|
|
1008
1207
|
raise
|
|
1208
|
+
finally:
|
|
1209
|
+
if self._auto_pg and (not ranks or self._rank in ranks):
|
|
1210
|
+
dist.destroy_process_group()
|
|
1211
|
+
|
|
1212
|
+
self.device_manager.device_module.empty_cache()
|
|
1213
|
+
logger.info(
|
|
1214
|
+
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
|
|
1215
|
+
f"Current device allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
|
|
1216
|
+
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
|
|
1217
|
+
)
|
|
1009
1218
|
|
|
1010
1219
|
def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
|
|
1011
1220
|
def zmq_handle(device_uuid: str) -> str:
|
|
@@ -1022,14 +1231,16 @@ class ParameterServer:
|
|
|
1022
1231
|
# auto detect bucket size
|
|
1023
1232
|
tensor = torch.tensor(
|
|
1024
1233
|
[
|
|
1025
|
-
# proportion of current
|
|
1026
|
-
int(
|
|
1234
|
+
# proportion of current device free memory bytes
|
|
1235
|
+
int(
|
|
1236
|
+
float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
|
|
1237
|
+
),
|
|
1027
1238
|
# we use negative value to reuse allreduce min operation
|
|
1028
1239
|
# for getting the max value of zmq_addr_counter in all ranks
|
|
1029
1240
|
-self._zmq_addr_counter,
|
|
1030
1241
|
],
|
|
1031
1242
|
dtype=torch.int64,
|
|
1032
|
-
device=
|
|
1243
|
+
device=self.device_manager.device_type,
|
|
1033
1244
|
)
|
|
1034
1245
|
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
|
|
1035
1246
|
tensor = tensor.cpu()
|
|
@@ -1038,7 +1249,7 @@ class ParameterServer:
|
|
|
1038
1249
|
for items in self._current_global_parameter_metas.values():
|
|
1039
1250
|
for metas_list in items.memory_buffer_metas_list:
|
|
1040
1251
|
for meta in metas_list.metas:
|
|
1041
|
-
max_tensor_bytes = max(max_tensor_bytes,
|
|
1252
|
+
max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
|
|
1042
1253
|
free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
|
|
1043
1254
|
if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
|
|
1044
1255
|
self._logger_rank0(f"[rank{self._rank}] use h2d buffer")
|
|
@@ -1083,7 +1294,7 @@ class ParameterServer:
|
|
|
1083
1294
|
remote_ptrs.append(ptrs[b.idx][0] + b.offset)
|
|
1084
1295
|
lens.append(b.size)
|
|
1085
1296
|
else:
|
|
1086
|
-
pool = self.
|
|
1297
|
+
pool = self._get_memory_pool(checkpoint_name)[b.idx]
|
|
1087
1298
|
buffer[offset : offset + b.size].data.copy_(
|
|
1088
1299
|
pool.buffer[b.offset : b.offset + b.size],
|
|
1089
1300
|
non_blocking=True,
|
|
@@ -1092,7 +1303,7 @@ class ParameterServer:
|
|
|
1092
1303
|
assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
|
|
1093
1304
|
if owner_rank is not None:
|
|
1094
1305
|
self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
|
|
1095
|
-
|
|
1306
|
+
self.device_manager.device_module.synchronize()
|
|
1096
1307
|
|
|
1097
1308
|
def init_process_group_for_ranks(
|
|
1098
1309
|
self,
|
|
@@ -1132,7 +1343,11 @@ class ParameterServer:
|
|
|
1132
1343
|
master_addr, master_port, len(ranks), is_master=rank == 0, timeout=timeout
|
|
1133
1344
|
)
|
|
1134
1345
|
dist.init_process_group(
|
|
1135
|
-
backend=
|
|
1346
|
+
backend=self.device_manager.backend,
|
|
1347
|
+
world_size=len(ranks),
|
|
1348
|
+
rank=rank,
|
|
1349
|
+
timeout=timeout,
|
|
1350
|
+
store=store,
|
|
1136
1351
|
)
|
|
1137
1352
|
|
|
1138
1353
|
def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
|
|
@@ -1142,22 +1357,32 @@ class ParameterServer:
|
|
|
1142
1357
|
|
|
1143
1358
|
def _register_parameters_to_p2p_store(self, checkpoint_name: str):
|
|
1144
1359
|
assert self._p2p_store is not None, "p2p store is not initialized"
|
|
1145
|
-
pool = self.
|
|
1360
|
+
pool = self._get_memory_pool(checkpoint_name)
|
|
1146
1361
|
if len(pool) == 0:
|
|
1147
1362
|
return
|
|
1148
1363
|
named_tensors, tensor_ptrs = {}, []
|
|
1364
|
+
register_name = (
|
|
1365
|
+
checkpoint_name
|
|
1366
|
+
if checkpoint_name != self._current_shared_memory_pool_user
|
|
1367
|
+
else self.shared_memory_pool_name
|
|
1368
|
+
)
|
|
1149
1369
|
for idx, memory_buffer in enumerate(pool):
|
|
1150
|
-
named_tensors[f"memory_pool_{
|
|
1370
|
+
named_tensors[f"memory_pool_{register_name}_{idx}"] = memory_buffer.buffer
|
|
1151
1371
|
tensor_ptrs.append((memory_buffer.buffer.data_ptr(), memory_buffer.size))
|
|
1152
1372
|
self._p2p_store.register_named_tensors(named_tensors)
|
|
1153
1373
|
|
|
1154
1374
|
def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
|
|
1155
1375
|
assert self._p2p_store is not None, "p2p store is not initialized"
|
|
1156
|
-
pool = self.
|
|
1376
|
+
pool = self._get_memory_pool(checkpoint_name)
|
|
1157
1377
|
if len(pool) == 0:
|
|
1158
1378
|
return 0
|
|
1379
|
+
unregister_name = (
|
|
1380
|
+
checkpoint_name
|
|
1381
|
+
if checkpoint_name != self._current_shared_memory_pool_user
|
|
1382
|
+
else self.shared_memory_pool_name
|
|
1383
|
+
)
|
|
1159
1384
|
return self._p2p_store.unregister_named_tensors(
|
|
1160
|
-
[f"memory_pool_{
|
|
1385
|
+
[f"memory_pool_{unregister_name}_{idx}" for idx, _ in enumerate(pool)]
|
|
1161
1386
|
)
|
|
1162
1387
|
|
|
1163
1388
|
def _update_per_bucket(
|
|
@@ -1184,7 +1409,7 @@ class ParameterServer:
|
|
|
1184
1409
|
|
|
1185
1410
|
if not need_update:
|
|
1186
1411
|
return
|
|
1187
|
-
# first execute a barrier to avoid subsequent
|
|
1412
|
+
# first execute a barrier to avoid subsequent device oom
|
|
1188
1413
|
dist.barrier()
|
|
1189
1414
|
|
|
1190
1415
|
bucket_size, disable_h2d_buffer = self._detect_bucket_size()
|
|
@@ -1199,7 +1424,7 @@ class ParameterServer:
|
|
|
1199
1424
|
h2d_buffer: torch.Tensor | None = (
|
|
1200
1425
|
None
|
|
1201
1426
|
if disable_h2d_buffer
|
|
1202
|
-
else torch.empty(bucket_size, dtype=torch.uint8, device=
|
|
1427
|
+
else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
|
|
1203
1428
|
)
|
|
1204
1429
|
# p2p store need to register h2d_buffer to let other ranks read
|
|
1205
1430
|
if ranks:
|
|
@@ -1212,7 +1437,9 @@ class ParameterServer:
|
|
|
1212
1437
|
continue
|
|
1213
1438
|
receiver_rank_buckets.append((owner_rank, bucket))
|
|
1214
1439
|
|
|
1215
|
-
buffer = torch.empty(
|
|
1440
|
+
buffer = torch.empty(
|
|
1441
|
+
bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
|
|
1442
|
+
)
|
|
1216
1443
|
handle = reduce_tensor(buffer)
|
|
1217
1444
|
|
|
1218
1445
|
buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
|
|
@@ -1231,52 +1458,66 @@ class ParameterServer:
|
|
|
1231
1458
|
socket.send_pyobj(handle)
|
|
1232
1459
|
|
|
1233
1460
|
gidx = 0
|
|
1461
|
+
ret_code = torch.zeros((), device=self.device_manager.device_type, dtype=torch.int64)
|
|
1234
1462
|
bcast_rank_map = _get_bcast_rank_map(self._world_size, ranks)
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
if
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1463
|
+
try:
|
|
1464
|
+
for i in range(max_len):
|
|
1465
|
+
if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
|
|
1466
|
+
self._copy_to_buffer(
|
|
1467
|
+
checkpoint_name,
|
|
1468
|
+
receiver_rank_buckets[i][1],
|
|
1469
|
+
h2d_buffer,
|
|
1470
|
+
receiver_rank_buckets[i][0] if ranks else None,
|
|
1471
|
+
)
|
|
1472
|
+
for receiver_rank, _buckets in buckets_by_receiver_rank.items():
|
|
1473
|
+
if i >= len(_buckets):
|
|
1474
|
+
continue
|
|
1475
|
+
bucket = _buckets[i]
|
|
1476
|
+
alloc, reserved = (
|
|
1477
|
+
self.device_manager.device_module.memory_allocated() / 1024 / 1024,
|
|
1478
|
+
self.device_manager.device_module.memory_reserved() / 1024 / 1024,
|
|
1479
|
+
)
|
|
1480
|
+
self._logger_rank0(
|
|
1481
|
+
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
|
|
1482
|
+
f"Current device allocated {alloc:.2f} MB, "
|
|
1483
|
+
f"reserved {reserved:.2f} MB."
|
|
1484
|
+
)
|
|
1485
|
+
start = gidx % 2 * bucket_size
|
|
1486
|
+
buffer_b: torch.Tensor = buffer[start : start + bucket.size]
|
|
1487
|
+
if receiver_rank == self._rank:
|
|
1488
|
+
if disable_h2d_buffer:
|
|
1489
|
+
self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
|
|
1490
|
+
else:
|
|
1491
|
+
buffer_b.data.copy_(h2d_buffer[: bucket.size])
|
|
1492
|
+
brank = bcast_rank_map[receiver_rank]
|
|
1493
|
+
dist.broadcast(buffer_b, src=brank)
|
|
1494
|
+
resp = socket.recv()
|
|
1495
|
+
if resp != b"":
|
|
1496
|
+
msg = resp.decode("utf-8")
|
|
1497
|
+
logger.error(
|
|
1498
|
+
f"[rank{self._rank}] receive error response from rank {receiver_rank} for bucket {gidx} in checkpoint {checkpoint_name}: {msg}"
|
|
1499
|
+
)
|
|
1500
|
+
ret_code.fill_(1)
|
|
1501
|
+
dist.all_reduce(ret_code, op=dist.ReduceOp.SUM)
|
|
1502
|
+
self.device_manager.device_module.synchronize()
|
|
1503
|
+
if ret_code.item() != 0:
|
|
1504
|
+
# quit early if any rank failed
|
|
1505
|
+
socket.send_pyobj(RuntimeError("Some workers failed to update weights"))
|
|
1506
|
+
raise RuntimeError("Failed to update weights due to remote errors")
|
|
1507
|
+
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
|
|
1508
|
+
gidx += 1
|
|
1509
|
+
|
|
1510
|
+
socket.recv()
|
|
1511
|
+
socket.send_pyobj(None)
|
|
1512
|
+
socket.recv()
|
|
1513
|
+
finally:
|
|
1514
|
+
req_thread.join()
|
|
1515
|
+
dist.barrier()
|
|
1516
|
+
socket.close()
|
|
1517
|
+
if ranks and h2d_buffer is not None:
|
|
1518
|
+
self._p2p_store.unregister_named_tensors([h2d_buffer_name])
|
|
1519
|
+
|
|
1520
|
+
self.device_manager.device_module.empty_cache()
|
|
1280
1521
|
|
|
1281
1522
|
|
|
1282
1523
|
def _init_api(ps: ParameterServer) -> Any:
|
|
@@ -1294,6 +1535,7 @@ def _init_api(ps: ParameterServer) -> Any:
|
|
|
1294
1535
|
update_url: str | None = None
|
|
1295
1536
|
inference_group_ranks: list[int] = []
|
|
1296
1537
|
timeout: float = 300.0
|
|
1538
|
+
uds: str | None = None
|
|
1297
1539
|
|
|
1298
1540
|
def wrap_exception(func: Callable[[], None]) -> Response:
|
|
1299
1541
|
try:
|
|
@@ -1326,7 +1568,9 @@ def _init_api(ps: ParameterServer) -> Any:
|
|
|
1326
1568
|
return
|
|
1327
1569
|
if req.inference_group_ranks:
|
|
1328
1570
|
socket_paths = [socket_paths[i] for i in req.inference_group_ranks]
|
|
1329
|
-
request_inference_to_update(
|
|
1571
|
+
request_inference_to_update(
|
|
1572
|
+
req.update_url, dict(socket_paths), timeout=req.timeout, uds=req.uds
|
|
1573
|
+
)
|
|
1330
1574
|
|
|
1331
1575
|
return wrap_exception(lambda: ps.update(checkpoint_name, update_func, ranks=req.ranks))
|
|
1332
1576
|
|
checkpoint_engine/worker.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import gc
|
|
2
|
+
import traceback
|
|
2
3
|
from collections.abc import Callable
|
|
3
4
|
from typing import TypedDict
|
|
4
5
|
|
|
5
6
|
import torch
|
|
6
7
|
import zmq
|
|
7
8
|
|
|
9
|
+
from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid
|
|
10
|
+
|
|
8
11
|
|
|
9
12
|
def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
|
|
10
13
|
func, args = handle
|
|
@@ -53,51 +56,107 @@ def update_weights_from_ipc(
|
|
|
53
56
|
socket = zmq_ctx.socket(zmq.REP)
|
|
54
57
|
socket.connect(zmq_handle)
|
|
55
58
|
buffer: torch.Tensor | None = None
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
torch.cuda.synchronize()
|
|
63
|
-
socket.send(b"")
|
|
64
|
-
break
|
|
65
|
-
if isinstance(payload, tuple):
|
|
66
|
-
# an ipc handle that vLLM can use `func, args = handle`
|
|
67
|
-
# and `func(*args)` to rebuild GPU tensor.
|
|
68
|
-
buffer = _rebuild_ipc(payload, device_id)
|
|
69
|
-
assert buffer.dtype == torch.uint8
|
|
70
|
-
socket.send(b"")
|
|
71
|
-
continue
|
|
72
|
-
assert isinstance(payload, list)
|
|
73
|
-
run(_extract_weights(payload, buffer))
|
|
74
|
-
torch.cuda.synchronize()
|
|
59
|
+
device_manager = DeviceManager()
|
|
60
|
+
try:
|
|
61
|
+
ipc_handle: tuple[Callable, tuple] = socket.recv_pyobj()
|
|
62
|
+
assert isinstance(ipc_handle, tuple)
|
|
63
|
+
buffer = _rebuild_ipc(ipc_handle, device_id)
|
|
64
|
+
assert buffer.dtype == torch.uint8
|
|
75
65
|
socket.send(b"")
|
|
66
|
+
except Exception as e:
|
|
67
|
+
msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
|
68
|
+
socket.send_string(msg)
|
|
69
|
+
socket.recv() # wait for ack
|
|
70
|
+
raise
|
|
71
|
+
try:
|
|
72
|
+
while True:
|
|
73
|
+
payload: list[FlattenedTensorMetadata] | Exception | None = socket.recv_pyobj()
|
|
74
|
+
if payload is None: # done signal
|
|
75
|
+
if post_hook is not None:
|
|
76
|
+
post_hook()
|
|
77
|
+
device_manager.device_module.synchronize()
|
|
78
|
+
socket.send(b"")
|
|
79
|
+
break
|
|
80
|
+
if isinstance(payload, list): # still updating weights
|
|
81
|
+
try:
|
|
82
|
+
run(_extract_weights(payload, buffer))
|
|
83
|
+
device_manager.device_module.synchronize()
|
|
84
|
+
socket.send(b"")
|
|
85
|
+
except Exception as e: # noqa: BLE001
|
|
86
|
+
# Send exception back to Parameter Server.
|
|
87
|
+
# Don't raise here. Because all workers should quit in the same way by receiving the exception from PS
|
|
88
|
+
msg = "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
|
89
|
+
socket.send_string(msg)
|
|
90
|
+
elif isinstance(
|
|
91
|
+
payload, Exception
|
|
92
|
+
): # error occurred, got force quit signal from Parameter Server
|
|
93
|
+
raise payload
|
|
94
|
+
else:
|
|
95
|
+
raise TypeError(f"Unexpected payload type: {type(payload)}")
|
|
76
96
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
97
|
+
finally:
|
|
98
|
+
socket.close()
|
|
99
|
+
del buffer
|
|
100
|
+
gc.collect()
|
|
101
|
+
device_manager.device_module.empty_cache()
|
|
81
102
|
|
|
82
103
|
|
|
83
104
|
class VllmColocateWorkerExtension:
|
|
84
105
|
"""
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
106
|
+
Worker extension for vLLM to update weights from checkpoint-engine.
|
|
107
|
+
|
|
108
|
+
This class provides a worker extension mechanism that allows vLLM workers to receive
|
|
109
|
+
and apply weight updates from the checkpoint-engine via IPC (Inter-Process Communication).
|
|
110
|
+
The methods in this worker extension will be injected into the vLLM worker class and
|
|
111
|
+
are callable from the `collective_rpc` API, enabling seamless weight updates for both
|
|
112
|
+
vLLM V0 and V1 versions.
|
|
113
|
+
|
|
114
|
+
Note:
|
|
115
|
+
This class is defined in a separate module. The fully qualified name
|
|
116
|
+
`checkpoint_engine.worker.VllmColocateWorkerExtension` should be passed as the
|
|
117
|
+
`worker_extension_cls` argument when initializing the vLLM worker.
|
|
91
118
|
"""
|
|
92
119
|
|
|
93
120
|
def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
|
|
121
|
+
"""
|
|
122
|
+
Update model weights from checkpoint-engine via IPC communication.
|
|
123
|
+
|
|
124
|
+
This method establishes a ZMQ connection to the checkpoint-engine and receives
|
|
125
|
+
weight updates through a shared memory buffer. The update process includes:
|
|
126
|
+
1. Receiving IPC handles to reconstruct shared memory tensors
|
|
127
|
+
2. Extracting flattened metadata describing tensor weights in the shared memory tensor
|
|
128
|
+
3. Loading weights into the model
|
|
129
|
+
4. Post-processing weights after loading
|
|
130
|
+
|
|
131
|
+
Args:
|
|
132
|
+
zmq_handles: A dictionary mapping device UUIDs to ZMQ socket handles.
|
|
133
|
+
The device UUID is platform-specific:
|
|
134
|
+
- For CUDA: UUID from `current_platform.get_device_uuid()`
|
|
135
|
+
- For NPU: Format "NPU-{generated_uuid}"
|
|
136
|
+
|
|
137
|
+
Raises:
|
|
138
|
+
ValueError: If the device type is not supported (not CUDA or NPU).
|
|
139
|
+
AssertionError: If the device is not properly initialized.
|
|
140
|
+
|
|
141
|
+
Note:
|
|
142
|
+
This method is called by vLLM's collective RPC mechanism. The ZMQ context
|
|
143
|
+
is lazily initialized on first call and reused for subsequent updates.
|
|
144
|
+
"""
|
|
94
145
|
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
|
95
146
|
from vllm.platforms import current_platform
|
|
96
147
|
|
|
148
|
+
# vllm-ascend not init device
|
|
149
|
+
if current_platform.device_type == "npu" and self.device is None:
|
|
150
|
+
self.device = torch.device(f"npu:{self.local_rank}")
|
|
97
151
|
assert self.device is not None
|
|
98
152
|
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
|
|
99
153
|
self._zmq_ctx = zmq.Context()
|
|
100
|
-
|
|
154
|
+
if current_platform.device_type == "cuda":
|
|
155
|
+
device_uuid = current_platform.get_device_uuid(self.device.index)
|
|
156
|
+
elif current_platform.device_type == "npu":
|
|
157
|
+
device_uuid = f"NPU-{npu_generate_uuid()}"
|
|
158
|
+
else:
|
|
159
|
+
raise ValueError(f"Unsupported device type: {current_platform.device_type}")
|
|
101
160
|
update_weights_from_ipc(
|
|
102
161
|
self._zmq_ctx,
|
|
103
162
|
zmq_handles[device_uuid],
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: checkpoint-engine
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.2
|
|
4
4
|
Summary: checkpoint-engine is a lightweight, decoupling and efficient weight update middleware
|
|
5
5
|
Project-URL: Homepage, https://github.com/MoonshotAI/checkpoint-engine
|
|
6
6
|
Project-URL: Repository, https://github.com/MoonshotAI/checkpoint-engine
|
|
@@ -99,17 +99,15 @@ Use the flexible P2P implementation, notice this will install `mooncake-transfer
|
|
|
99
99
|
pip install 'checkpoint-engine[p2p]'
|
|
100
100
|
```
|
|
101
101
|
|
|
102
|
-
If set `NCCL_IB_HCA` env, checkpoint-engine will use it to auto select net devices for different ranks. Available patterns can be found from [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If not set, it will read all RDMA devices and try to divide them into each rank.
|
|
103
|
-
|
|
104
102
|
## Getting Started
|
|
105
103
|
|
|
106
|
-
Prepare an H800 or H20 machine with 8 GPUs with
|
|
104
|
+
Prepare an H800 or H20 machine with 8 GPUs and vLLM installed. Be sure your vLLM build includes the [/collective_rpc API endpoint](https://github.com/vllm-project/vllm/commit/f7cf5b512ee41f36613deb2471a44de5f304f70d) commit (available in the main branch), since checkpoint-engine uses this endpoint to update weights. vLLM version `v0.10.2` is fully tested and recommended.
|
|
107
105
|
|
|
108
106
|
```Bash
|
|
109
|
-
|
|
107
|
+
mkdir -p /opt/vLLM && cd /opt/vLLM
|
|
110
108
|
uv venv --python 3.12 --seed
|
|
111
109
|
source .venv/bin/activate
|
|
112
|
-
|
|
110
|
+
uv pip install vllm==0.10.2
|
|
113
111
|
```
|
|
114
112
|
|
|
115
113
|
Install checkpoint-engine
|
|
@@ -169,13 +167,68 @@ A [PR](https://github.com/vllm-project/vllm/pull/24488) is opened to the vLLM pr
|
|
|
169
167
|
Run a simple correctness test for checkpoint_engine
|
|
170
168
|
|
|
171
169
|
```bash
|
|
172
|
-
|
|
170
|
+
pytest tests/test_update.py
|
|
173
171
|
```
|
|
174
172
|
|
|
175
|
-
|
|
173
|
+
`test_update.py` is designed to run only with `pytest`. Please don't run it directly with `torchrun`.
|
|
174
|
+
|
|
175
|
+
Other unit tests can also be run with pytest. Only `test_update.py` requires GPUs; the other tests can be run on CPUs. To run only the CPU tests, use:
|
|
176
|
+
|
|
177
|
+
```bash
|
|
178
|
+
pytest tests/ -m "not gpu"
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
### Environment Variables
|
|
182
|
+
- `PS_MAX_BUCKET_SIZE_GB`: An integer that sets the maximum bucket size (in GB) for checkpoint-engine. If not set, 8GB is used as the default.
|
|
183
|
+
- `PS_P2P_STORE_RDMA_DEVICES`: Comma-separated names of the RDMA devices to use for P2P transfer. If not set, checkpoint-engine falls back to using `NCCL_IB_HCA` to detect RDMA devices.
|
|
184
|
+
- `NCCL_IB_HCA`: Available patterns can be found in the [NCCL documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8). If this is also not set, all RDMA devices will be used and divided evenly among the ranks.
|
|
185
|
+
|
|
186
|
+
## SGLang Integration
|
|
187
|
+
|
|
188
|
+
Checkpoint Engine provides efficient distributed checkpoint loading for SGLang inference servers, significantly reducing model loading time for large models and multi-node setups.
|
|
189
|
+
|
|
190
|
+
### Quick Start
|
|
191
|
+
|
|
192
|
+
**1. Install checkpoint-engine:**
|
|
193
|
+
```bash
|
|
194
|
+
pip install 'checkpoint-engine[p2p]'
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
**2. Launch SGLang server:**
|
|
198
|
+
```bash
|
|
199
|
+
python -m sglang.launch_server \
|
|
200
|
+
--model-path $MODEL_PATH \
|
|
201
|
+
--tp 8 \
|
|
202
|
+
--load-format dummy \
|
|
203
|
+
--wait-for-initial-weights
|
|
204
|
+
```
|
|
205
|
+
|
|
206
|
+
**3. Run checkpoint engine:**
|
|
207
|
+
```bash
|
|
208
|
+
python -m sglang.srt.checkpoint_engine.update \
|
|
209
|
+
--update-method broadcast \
|
|
210
|
+
--checkpoint-path $MODEL_PATH \
|
|
211
|
+
--inference-parallel-size 8
|
|
212
|
+
```
|
|
213
|
+
|
|
214
|
+
### Multi-Node Setup
|
|
215
|
+
|
|
216
|
+
For a 2-node setup, run the same commands on both nodes with the appropriate `--host` and distributed training parameters.
|
|
217
|
+
|
|
218
|
+
### Key Options
|
|
219
|
+
|
|
220
|
+
**SGLang Server:**
|
|
221
|
+
- `--wait-for-initial-weights`: Wait for checkpoint engine before becoming ready
|
|
222
|
+
- `--load-format dummy`: Enable overlapping initialization tasks
|
|
223
|
+
|
|
224
|
+
**Checkpoint Engine:**
|
|
225
|
+
- `--update-method`: Choose `broadcast`, `p2p`, or `all`
|
|
226
|
+
- `--inference-parallel-size`: Number of parallel processes
|
|
227
|
+
- `--checkpoint-path`: Model checkpoint directory
|
|
228
|
+
|
|
176
229
|
## Limitations and Future Work
|
|
177
230
|
|
|
178
|
-
- This project is currently
|
|
231
|
+
- This project is currently tested with vLLM and SGLang. Integration with other frameworks is planned for future releases.
|
|
179
232
|
- The perfect three-stage pipeline mentioned in our paper is currently not implemented. This could be useful for architectures where H2D and broadcast do not conflict in PCIE.
|
|
180
233
|
|
|
181
234
|
## Acknowledgments
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
|
|
2
|
+
checkpoint_engine/_version.py,sha256=o3ZTescp-19Z9cvBGq9dQnbppljgzdUYUf98Nov0spY,704
|
|
3
|
+
checkpoint_engine/device_utils.py,sha256=iKrof60j3CY3fStRTq3DRTt_kE1vYoEWHhAeyh0lByA,3020
|
|
4
|
+
checkpoint_engine/ps.py,sha256=cu8Qp5daY1iL30iN69jXP4grlHoAKILblngcKQPA5Bg,67692
|
|
5
|
+
checkpoint_engine/worker.py,sha256=f6kS1ushIXxkRCEHXM5wVofUer9OxRiVY03vmKYLzgo,6757
|
|
6
|
+
checkpoint_engine-0.2.2.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
|
|
7
|
+
checkpoint_engine-0.2.2.dist-info/METADATA,sha256=_bBxy27d0GMc7KzuIBAdw-Lno3-UrVLUhH63YDbY1YA,11559
|
|
8
|
+
checkpoint_engine-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
9
|
+
checkpoint_engine-0.2.2.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
10
|
+
checkpoint_engine-0.2.2.dist-info/RECORD,,
|
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
checkpoint_engine/__init__.py,sha256=Zj4I008kn9R6fYr0lVBzcQSnvckLpX2s1ljCOOqV1c8,87
|
|
2
|
-
checkpoint_engine/_version.py,sha256=Dg8AmJomLVpjKL6prJylOONZAPRtB86LOce7dorQS_A,704
|
|
3
|
-
checkpoint_engine/ps.py,sha256=OpGocqJv0TfGgVC1cPKARfz6qehfCLMzQ5KpDQNxb0o,55291
|
|
4
|
-
checkpoint_engine/worker.py,sha256=ZmJTHeNPbnE8sPInfrghj9jeRDkMUSQO906o1UoJv-E,3748
|
|
5
|
-
checkpoint_engine-0.2.0.dist-info/licenses/LICENCE,sha256=D3gPmHKpGtF1yxYNhqjtBtZY_brZjDotJTzpnmClzlY,1067
|
|
6
|
-
checkpoint_engine-0.2.0.dist-info/METADATA,sha256=tbAq45YlRvRAfQHDB0XV8w4ZP0zmVJ3RMTAx_wTm154,9896
|
|
7
|
-
checkpoint_engine-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
8
|
-
checkpoint_engine-0.2.0.dist-info/top_level.txt,sha256=66sik_1eLakLYmcllOEJzFaNbSfjsueuP0tHYEzhMSs,18
|
|
9
|
-
checkpoint_engine-0.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|