modelexpress 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelexpress/__init__.py +60 -0
- modelexpress/client.py +168 -0
- modelexpress/gds_loader.py +332 -0
- modelexpress/gds_transfer.py +238 -0
- modelexpress/heartbeat.py +142 -0
- modelexpress/nixl_transfer.py +652 -0
- modelexpress/p2p_pb2.py +79 -0
- modelexpress/p2p_pb2_grpc.py +341 -0
- modelexpress/transfer_safety.py +228 -0
- modelexpress/types.py +46 -0
- modelexpress/vllm_loader.py +1176 -0
- modelexpress/vllm_worker.py +27 -0
- modelexpress/worker_server.py +108 -0
- modelexpress-0.3.0.dist-info/METADATA +152 -0
- modelexpress-0.3.0.dist-info/RECORD +18 -0
- modelexpress-0.3.0.dist-info/WHEEL +5 -0
- modelexpress-0.3.0.dist-info/entry_points.txt +2 -0
- modelexpress-0.3.0.dist-info/top_level.txt +1 -0
modelexpress/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
ModelExpress - High-performance GPU-to-GPU model weight transfers.
|
|
6
|
+
|
|
7
|
+
This package provides:
|
|
8
|
+
- NIXL-based RDMA transfers for GPU tensors
|
|
9
|
+
- GPUDirect Storage (GDS) for direct file-to-GPU loading
|
|
10
|
+
- vLLM worker extension for serving model weights
|
|
11
|
+
- Custom model loaders for FP8 model support (DeepSeek-V3, etc.)
|
|
12
|
+
|
|
13
|
+
Quick Start:
|
|
14
|
+
# For FP8 models (DeepSeek-V3), use custom loaders:
|
|
15
|
+
from modelexpress import register_modelexpress_loaders
|
|
16
|
+
register_modelexpress_loaders()
|
|
17
|
+
|
|
18
|
+
# vllm serve model --load-format mx
|
|
19
|
+
# Auto-detects: RDMA -> GDS -> disk
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import logging
|
|
23
|
+
|
|
24
|
+
_logger = logging.getLogger(__name__)

# One-shot guard: flipped to True once the vLLM loader module has been imported.
_loaders_registered = False


def register_modelexpress_loaders():
    """
    Register the ModelExpress ``mx`` load format with vLLM.

    Importing ``modelexpress.vllm_loader`` triggers the
    ``@register_model_loader`` decorators on its classes. The module-level
    ``_loaders_registered`` flag turns every call after the first
    successful one into a no-op, so calling this repeatedly is safe.

    Enables:
        --load-format mx (auto-detect: RDMA -> GDS -> disk)
    """
    global _loaders_registered

    if not _loaders_registered:
        # The import itself is the registration side effect.
        from . import vllm_loader  # noqa: F401

        _loaders_registered = True
        _logger.debug("ModelExpress loader registered: mx")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
from .client import MxClient # noqa: F401
|
|
50
|
+
from .gds_loader import MxGdsLoader # noqa: F401
|
|
51
|
+
from .gds_transfer import GdsTransferManager # noqa: F401
|
|
52
|
+
from .heartbeat import HeartbeatThread # noqa: F401
|
|
53
|
+
|
|
54
|
+
# Names exported by ``from modelexpress import *``; these are the objects
# imported above plus the loader-registration entry point.
__all__ = [
    "GdsTransferManager",
    "HeartbeatThread",
    "MxClient",
    "MxGdsLoader",
    "register_modelexpress_loaders",
]
|
modelexpress/client.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
ModelExpress Client for P2P GPU Weight Transfers.
|
|
6
|
+
|
|
7
|
+
Orchestrates NIXL/RDMA transfers between vLLM workers. The client fetches
|
|
8
|
+
NIXL metadata from workers via ZMQ, queries the ModelExpress server for
|
|
9
|
+
existing sources, and instructs workers to receive weights if found.
|
|
10
|
+
|
|
11
|
+
NIXL agents live in vLLM workers (not here) because GPU memory must be
|
|
12
|
+
registered by the owning process for GPUDirect RDMA.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
import os
|
|
17
|
+
|
|
18
|
+
import grpc
|
|
19
|
+
|
|
20
|
+
from . import p2p_pb2
|
|
21
|
+
from . import p2p_pb2_grpc
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger("modelexpress.client")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _parse_server_address(address: str) -> str:
|
|
27
|
+
"""Strip http:// or https:// prefix from server address for gRPC."""
|
|
28
|
+
if address.startswith("http://"):
|
|
29
|
+
return address[7:]
|
|
30
|
+
elif address.startswith("https://"):
|
|
31
|
+
return address[8:]
|
|
32
|
+
return address
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _get_server_url(explicit_url: str | None = None) -> str:
|
|
36
|
+
"""
|
|
37
|
+
Resolve the ModelExpress server URL.
|
|
38
|
+
|
|
39
|
+
Priority:
|
|
40
|
+
1. Explicit ``server_url`` argument
|
|
41
|
+
2. ``MODEL_EXPRESS_URL`` env var (Dynamo-consistent)
|
|
42
|
+
3. ``MX_SERVER_ADDRESS`` env var (backward compat)
|
|
43
|
+
4. Default ``localhost:8001``
|
|
44
|
+
"""
|
|
45
|
+
if explicit_url:
|
|
46
|
+
return _parse_server_address(explicit_url)
|
|
47
|
+
url = os.environ.get(
|
|
48
|
+
"MODEL_EXPRESS_URL",
|
|
49
|
+
os.environ.get("MX_SERVER_ADDRESS", "localhost:8001"),
|
|
50
|
+
)
|
|
51
|
+
return _parse_server_address(url)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class MxClient:
    """
    Lightweight gRPC client for ModelExpress server communication.

    Provides typed methods for every P2P RPC (``PublishMetadata``,
    ``ListSources``, ``GetMetadata``, ``UpdateStatus``) so that callers
    (loaders, coordinators) never need to create gRPC channels or
    stubs directly.

    The connection is created lazily on first use. The client may also be
    used as a context manager, which closes the channel on exit::

        with MxClient() as client:
            client.list_sources()

    Args:
        server_url: Explicit server address (``host:port``). When
            *None* the address is resolved via ``MODEL_EXPRESS_URL``
            or ``MX_SERVER_ADDRESS`` env vars, falling back to
            ``localhost:8001``.
        max_message_size: Max send/receive message size in bytes.
        timeout: Deadline in seconds passed to every RPC (defaults to 30,
            the value previously hard-coded in each wrapper).
    """

    def __init__(
        self,
        server_url: str | None = None,
        max_message_size: int = 100 * 1024 * 1024,  # 100 MB
        timeout: float = 30.0,
    ):
        self.server_url = _get_server_url(server_url)
        self._max_message_size = max_message_size
        self._timeout = timeout
        self._channel: grpc.Channel | None = None
        self._stub: p2p_pb2_grpc.P2pServiceStub | None = None

    # -- connection management ------------------------------------------------

    @property
    def stub(self) -> p2p_pb2_grpc.P2pServiceStub:
        """Return (and lazily create) the gRPC stub."""
        if self._channel is None:
            options = [
                ("grpc.max_send_message_length", self._max_message_size),
                ("grpc.max_receive_message_length", self._max_message_size),
            ]
            self._channel = grpc.insecure_channel(self.server_url, options=options)
            self._stub = p2p_pb2_grpc.P2pServiceStub(self._channel)
            logger.debug("MxClient connected to %s", self.server_url)
        return self._stub

    def close(self) -> None:
        """Close the underlying gRPC channel (no-op if never opened)."""
        if self._channel is not None:
            self._channel.close()
            self._channel = None
            self._stub = None

    def __enter__(self) -> "MxClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the channel; exceptions propagate unchanged.
        self.close()

    # -- RPC wrappers ---------------------------------------------------------

    def publish_metadata(
        self,
        identity: "p2p_pb2.SourceIdentity",
        worker: "p2p_pb2.WorkerMetadata",
        worker_id: str,
    ) -> str:
        """Publish metadata for one worker so targets can discover this source.

        Returns the *mx_source_id* (16-char hex) on success.

        Raises:
            RuntimeError: if the server reports ``success == False``.
        """
        request = p2p_pb2.PublishMetadataRequest(
            identity=identity,
            worker=worker,
            worker_id=worker_id,
        )
        response = self.stub.PublishMetadata(request, timeout=self._timeout)
        if not response.success:
            raise RuntimeError(f"PublishMetadata failed: {response.message}")
        return response.mx_source_id

    def list_sources(
        self,
        identity: "p2p_pb2.SourceIdentity | None" = None,
        status_filter: "p2p_pb2.SourceStatus | None" = None,
    ) -> "p2p_pb2.ListSourcesResponse":
        """List available source workers, optionally filtered by identity and status."""
        request = p2p_pb2.ListSourcesRequest(
            identity=identity,
            status_filter=status_filter,
        )
        return self.stub.ListSources(request, timeout=self._timeout)

    def get_metadata(
        self,
        mx_source_id: str,
        worker_id: str,
    ) -> "p2p_pb2.GetMetadataResponse":
        """Fetch full tensor metadata for one specific worker."""
        request = p2p_pb2.GetMetadataRequest(
            mx_source_id=mx_source_id,
            worker_id=worker_id,
        )
        return self.stub.GetMetadata(request, timeout=self._timeout)

    def update_status(
        self,
        mx_source_id: str,
        worker_id: str,
        worker_rank: int,
        status: "p2p_pb2.SourceStatus",
    ) -> bool:
        """Update worker status.

        Returns *True* on success. On failure the server message is logged
        at error level and *False* is returned (no exception is raised).
        """
        request = p2p_pb2.UpdateStatusRequest(
            mx_source_id=mx_source_id,
            worker_id=worker_id,
            worker_rank=worker_rank,
            status=status,
        )
        response = self.stub.UpdateStatus(request, timeout=self._timeout)
        if not response.success:
            logger.error("UpdateStatus failed: %s", response.message)
        return response.success
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Framework-agnostic GDS model loader.
|
|
6
|
+
|
|
7
|
+
Loads model weights from safetensors files directly to GPU memory via NIXL's
|
|
8
|
+
GDS (GPUDirect Storage) backend, bypassing CPU bounce buffers entirely.
|
|
9
|
+
|
|
10
|
+
The target GPU is determined from torch.cuda.current_device(), matching
|
|
11
|
+
the behavior of vLLM/sglang default loaders.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
import struct
|
|
20
|
+
import time
|
|
21
|
+
import uuid
|
|
22
|
+
from collections import defaultdict
|
|
23
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Iterator
|
|
26
|
+
|
|
27
|
+
import torch
|
|
28
|
+
|
|
29
|
+
from .gds_transfer import GdsTransferManager, is_gds_available
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger("modelexpress.gds_loader")
|
|
32
|
+
|
|
33
|
+
# Complete dtype mapping from the safetensors spec:
|
|
34
|
+
# https://huggingface.co/docs/safetensors/metadata_parsing#accepted-dtypes
|
|
35
|
+
SAFETENSORS_DTYPE_MAP: dict[str, torch.dtype] = {
|
|
36
|
+
"F64": torch.float64,
|
|
37
|
+
"F32": torch.float32,
|
|
38
|
+
"F16": torch.float16,
|
|
39
|
+
"BF16": torch.bfloat16,
|
|
40
|
+
"F8_E4M3": torch.float8_e4m3fn,
|
|
41
|
+
"F8_E5M2": torch.float8_e5m2,
|
|
42
|
+
"I64": torch.int64,
|
|
43
|
+
"I32": torch.int32,
|
|
44
|
+
"I16": torch.int16,
|
|
45
|
+
"I8": torch.int8,
|
|
46
|
+
"U8": torch.uint8,
|
|
47
|
+
"BOOL": torch.bool,
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class MxGdsLoader:
    """
    Load model weights from safetensors files directly to GPU via GDS.

    Framework-agnostic. Can be used from vLLM, sglang, or standalone.

    Usage::

        loader = MxGdsLoader()
        tensors = loader.load("/path/to/model")

        # Or stream per-file:
        for name, tensor in loader.load_iter("/path/to/model"):
            process(name, tensor)

        loader.shutdown()  # release GDS resources when done
    """

    def __init__(self):
        # Created lazily by _ensure_gds_manager() on the first load.
        self._gds_manager: GdsTransferManager | None = None
        # CUDA device index captured from torch.cuda.current_device() per load.
        self._device_id: int | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load(
        self,
        model_path: str,
        *,
        use_tqdm: bool = True,
        revision: str | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Load all tensors from ``model_path`` to GPU and return them by name.

        Args:
            model_path: Local model directory, or a Hugging Face repo id that
                is fetched with ``snapshot_download``.
            use_tqdm: Show a per-file progress bar.
            revision: Hugging Face revision, used only when ``model_path`` is
                not an existing local directory.
        """
        return dict(
            self.load_iter(model_path, use_tqdm=use_tqdm, revision=revision)
        )

    def load_iter(
        self,
        model_path: str,
        *,
        use_tqdm: bool = True,
        revision: str | None = None,
    ) -> Iterator[tuple[str, torch.Tensor]]:
        """
        Yield ``(tensor_name, gpu_tensor)`` pairs loaded via GDS.

        Each safetensors file is batch-loaded through a single GDS
        transfer, then its tensors are yielded one by one. While the tensors
        of file ``i`` are being yielded, file ``i + 1`` is already loading on
        a single background thread.

        Args:
            model_path: Local model directory or Hugging Face repo id.
            use_tqdm: Show a per-file progress bar.
            revision: Hugging Face revision for repo ids.

        Raises:
            RuntimeError: GDS is unavailable, a header is invalid, or a
                tensor has an unsupported dtype.
            FileNotFoundError: no ``.safetensors`` files were found.
        """
        load_start = time.perf_counter()
        model_path = self._resolve_model_path(model_path, revision=revision)

        if not is_gds_available():
            raise RuntimeError(
                "GDS is not available. Check nvidia_fs module and libcufile."
            )

        self._device_id = torch.cuda.current_device()
        self._ensure_gds_manager()

        file_tensor_map = self._resolve_safetensors_files(model_path)

        file_jobs: list[tuple[str, dict[str, dict]]] = []
        for file_path, tensor_names in file_tensor_map.items():
            header_info = self._parse_safetensors_header(file_path)
            file_tensors = {
                name: header_info[name]
                for name in tensor_names
                if name in header_info
            }
            missing = len(tensor_names) - len(file_tensors)
            if missing:
                # An index entry pointing at a tensor the shard does not
                # contain would otherwise be dropped without any trace.
                logger.warning(
                    "%d tensor(s) listed for %s are missing from its header; skipping",
                    missing,
                    Path(file_path).name,
                )
            if file_tensors:
                file_jobs.append((file_path, file_tensors))

        if not file_jobs:
            return

        # Prefetch pipeline: load file[i+1] while yielding file[i]
        total_files = len(file_jobs)
        pbar = None
        if use_tqdm:
            from tqdm import tqdm

            pbar = tqdm(
                total=total_files,
                desc="Loading safetensors via GDS",
                unit="file",
            )

        try:
            # Exiting the pool waits for any in-flight prefetch, including
            # when the consumer stops iterating early.
            with ThreadPoolExecutor(max_workers=1) as pool:
                pending = pool.submit(self._load_file_tensors, *file_jobs[0])

                for i in range(total_files):
                    loaded = pending.result()
                    if pbar is not None:
                        pbar.update(1)

                    if i + 1 < total_files:
                        pending = pool.submit(
                            self._load_file_tensors, *file_jobs[i + 1]
                        )

                    yield from loaded.items()

                logger.info(
                    "GDS load complete in %.2fs", time.perf_counter() - load_start
                )
        finally:
            if pbar is not None:
                pbar.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_model_path(
        model_path: str, revision: str | None = None
    ) -> str:
        """
        Resolve ``model_path`` to a local directory.

        An existing directory is returned as an absolute path; anything else
        is treated as a Hugging Face repo id and downloaded with
        ``snapshot_download`` (``revision`` is passed through).
        """
        p = Path(model_path)
        if p.is_dir():
            return str(p.resolve())

        from huggingface_hub import snapshot_download

        local_dir = snapshot_download(model_path, revision=revision)
        logger.info("Resolved HF model '%s' -> %s", model_path, local_dir)
        return local_dir

    def _ensure_gds_manager(self) -> None:
        """
        Lazily create and initialize the GDS transfer manager.

        The manager is cached only after ``initialize()`` succeeds, so a
        failed initialization is retried on the next load instead of leaving
        an uninitialized manager cached forever.
        """
        if self._gds_manager is not None:
            return

        agent_name = f"mx-gds-{self._device_id}-{uuid.uuid4().hex[:8]}"
        manager = GdsTransferManager(agent_name=agent_name)
        manager.initialize()
        self._gds_manager = manager
        logger.info("GDS manager initialized for device %d", self._device_id)

    def _resolve_safetensors_files(
        self, model_path: str
    ) -> dict[str, list[str]]:
        """
        Discover safetensors files and map each to its tensor names.

        Lookup order: ``model.safetensors.index.json`` (sharded), then
        ``model.safetensors`` (single file), then any ``*.safetensors`` in
        the directory.

        Raises:
            RuntimeError: the index exists but its ``weight_map`` is empty.
            FileNotFoundError: no ``.safetensors`` files are present.
        """
        model_dir = Path(model_path)

        # Try sharded index first
        index_path = model_dir / "model.safetensors.index.json"
        if index_path.exists():
            with open(index_path, "r") as f:
                index = json.load(f)

            weight_map: dict[str, str] = index.get("weight_map", {})
            if not weight_map:
                raise RuntimeError(f"Empty weight_map in {index_path}")

            file_tensors: dict[str, list[str]] = defaultdict(list)
            for tensor_name, filename in weight_map.items():
                abs_path = str(model_dir / filename)
                file_tensors[abs_path].append(tensor_name)

            logger.info(
                "Found sharded model: %d files, %d tensors",
                len(file_tensors), len(weight_map),
            )
            return dict(file_tensors)

        # Try single file
        single_path = model_dir / "model.safetensors"
        if single_path.exists():
            header_info = self._parse_safetensors_header(str(single_path))
            tensor_names = list(header_info.keys())
            logger.info("Found single safetensors file: %d tensors", len(tensor_names))
            return {str(single_path): tensor_names}

        # Fallback: glob
        st_files = sorted(model_dir.glob("*.safetensors"))
        if not st_files:
            raise FileNotFoundError(f"No .safetensors files found in {model_path}")

        file_tensors_map: dict[str, list[str]] = {}
        for st_file in st_files:
            header_info = self._parse_safetensors_header(str(st_file))
            file_tensors_map[str(st_file)] = list(header_info.keys())

        total = sum(len(v) for v in file_tensors_map.values())
        logger.info(
            "Found %d safetensors files via glob: %d tensors",
            len(file_tensors_map), total,
        )
        return file_tensors_map

    def _parse_safetensors_header(self, file_path: str) -> dict[str, dict]:
        """
        Parse a safetensors file header without loading tensor data.

        Returns:
            {tensor_name: {"file_offset": int, "size": int, "dtype": str, "shape": list}}
            where ``file_offset`` is absolute within the file. The
            ``__metadata__`` entry is skipped.

        Raises:
            RuntimeError: the file is shorter than the 8-byte length prefix,
                the declared header is larger than 100 MiB, the header is
                truncated, or it is not a JSON object.
        """
        with open(file_path, "rb") as f:
            raw = f.read(8)
            if len(raw) < 8:
                raise RuntimeError(f"Invalid safetensors file: {file_path}")

            # Little-endian u64 header length prefix.
            header_size = struct.unpack("<Q", raw)[0]

            # Refuse absurd lengths before reading them into memory.
            if header_size > 100 * 1024 * 1024:
                raise RuntimeError(
                    f"Safetensors header too large ({header_size} bytes): {file_path}"
                )

            header_bytes = f.read(header_size)

        # A short read means the file ends inside the header; without this
        # check a truncated file could still parse as (partial) JSON.
        if len(header_bytes) != header_size:
            raise RuntimeError(
                f"Truncated safetensors header ({len(header_bytes)} of "
                f"{header_size} bytes): {file_path}"
            )

        header = json.loads(header_bytes)
        if not isinstance(header, dict):
            raise RuntimeError(
                f"Invalid safetensors header (not a JSON object): {file_path}"
            )
        data_start = 8 + header_size

        result: dict[str, dict] = {}
        for name, info in header.items():
            if name == "__metadata__":
                continue

            offsets = info["data_offsets"]
            result[name] = {
                "file_offset": data_start + offsets[0],
                "size": offsets[1] - offsets[0],
                "dtype": info["dtype"],
                "shape": info["shape"],
            }

        return result

    def _load_file_tensors(
        self,
        file_path: str,
        tensor_infos: dict[str, dict],
    ) -> dict[str, torch.Tensor]:
        """
        Load all tensors from one safetensors file via GDS.

        All tensors are submitted in a single batch (ordered by file offset)
        so GDS_MT threads work in parallel. Each raw buffer returned by the
        transfer manager is reinterpreted with ``view(dtype)`` and reshaped.

        Raises:
            RuntimeError: a tensor's dtype is not in ``SAFETENSORS_DTYPE_MAP``.
        """
        device = torch.device("cuda", self._device_id)

        sorted_names = sorted(
            tensor_infos.keys(),
            key=lambda n: tensor_infos[n]["file_offset"],
        )

        tensor_list = []
        tensor_meta = []
        for name in sorted_names:
            info = tensor_infos[name]
            st_dtype = info["dtype"]
            torch_dtype = SAFETENSORS_DTYPE_MAP.get(st_dtype)
            if torch_dtype is None:
                raise RuntimeError(
                    f"Unsupported safetensors dtype '{st_dtype}' "
                    f"for tensor '{name}'"
                )
            tensor_list.append((info["file_offset"], info["size"]))
            tensor_meta.append((name, torch_dtype, info["shape"]))

        fd = os.open(file_path, os.O_RDONLY)
        try:
            # fstat is inside the try so the descriptor is closed even if it fails.
            file_size = os.fstat(fd).st_size
            raw_tensors = self._gds_manager.batch_load_file(
                fd, file_size, tensor_list, device,
            )
        finally:
            os.close(fd)

        result: dict[str, torch.Tensor] = {}
        for raw, (name, torch_dtype, shape) in zip(raw_tensors, tensor_meta, strict=True):
            result[name] = raw.view(torch_dtype).reshape(shape)

        logger.info("Loaded %s", Path(file_path).name)
        return result

    def shutdown(self) -> None:
        """Release GDS resources (safe to call more than once)."""
        if self._gds_manager is not None:
            self._gds_manager.shutdown()
            self._gds_manager = None
|