modelexpress 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,60 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ ModelExpress - High-performance GPU-to-GPU model weight transfers.
6
+
7
+ This package provides:
8
+ - NIXL-based RDMA transfers for GPU tensors
9
+ - GPUDirect Storage (GDS) for direct file-to-GPU loading
10
+ - vLLM worker extension for serving model weights
11
+ - Custom model loaders for FP8 model support (DeepSeek-V3, etc.)
12
+
13
+ Quick Start:
14
+ # For FP8 models (DeepSeek-V3), use custom loaders:
15
+ from modelexpress import register_modelexpress_loaders
16
+ register_modelexpress_loaders()
17
+
18
+ # vllm serve model --load-format mx
19
+ # Auto-detects: RDMA -> GDS -> disk
20
+ """
21
+
22
+ import logging
23
+
24
_logger = logging.getLogger(__name__)
# Module-level latch: set to True once the loader module has been imported.
_loaders_registered = False


def register_modelexpress_loaders():
    """
    Make the ModelExpress load format known to vLLM.

    The first successful call imports the loader module; every later call
    is a no-op, so callers may invoke this as often as they like.

    Enables:
        --load-format mx  (auto-detect: RDMA -> GDS -> disk)
    """
    global _loaders_registered
    if not _loaders_registered:
        # The import is done purely for its side effect: loading the module
        # runs the @register_model_loader decorators on its classes.
        from . import vllm_loader  # noqa: F401

        _loaders_registered = True
        _logger.debug("ModelExpress loader registered: mx")
47
+
48
+
49
+ from .client import MxClient # noqa: F401
50
+ from .gds_loader import MxGdsLoader # noqa: F401
51
+ from .gds_transfer import GdsTransferManager # noqa: F401
52
+ from .heartbeat import HeartbeatThread # noqa: F401
53
+
54
+ __all__ = [
55
+ "GdsTransferManager",
56
+ "HeartbeatThread",
57
+ "MxClient",
58
+ "MxGdsLoader",
59
+ "register_modelexpress_loaders",
60
+ ]
modelexpress/client.py ADDED
@@ -0,0 +1,168 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ ModelExpress Client for P2P GPU Weight Transfers.
6
+
7
+ Orchestrates NIXL/RDMA transfers between vLLM workers. The client fetches
8
+ NIXL metadata from workers via ZMQ, queries the ModelExpress server for
9
+ existing sources, and instructs workers to receive weights if found.
10
+
11
+ NIXL agents live in vLLM workers (not here) because GPU memory must be
12
+ registered by the owning process for GPUDirect RDMA.
13
+ """
14
+
15
+ import logging
16
+ import os
17
+
18
+ import grpc
19
+
20
+ from . import p2p_pb2
21
+ from . import p2p_pb2_grpc
22
+
23
+ logger = logging.getLogger("modelexpress.client")
24
+
25
+
26
+ def _parse_server_address(address: str) -> str:
27
+ """Strip http:// or https:// prefix from server address for gRPC."""
28
+ if address.startswith("http://"):
29
+ return address[7:]
30
+ elif address.startswith("https://"):
31
+ return address[8:]
32
+ return address
33
+
34
+
35
def _get_server_url(explicit_url: str | None = None) -> str:
    """
    Resolve the ModelExpress server URL.

    Priority (the first non-empty value wins):
    1. Explicit ``server_url`` argument
    2. ``MODEL_EXPRESS_URL`` env var (Dynamo-consistent)
    3. ``MX_SERVER_ADDRESS`` env var (backward compat)
    4. Default ``localhost:8001``

    An environment variable that is set but empty is treated as unset, the
    same way an empty explicit argument is, so it cannot shadow a lower
    priority source with an unusable empty address.

    Args:
        explicit_url: Address supplied by the caller, or ``None``/empty to
            fall back to the environment.

    Returns:
        The selected address after passing through ``_parse_server_address``.
    """
    candidates = (
        explicit_url,
        os.environ.get("MODEL_EXPRESS_URL"),
        os.environ.get("MX_SERVER_ADDRESS"),
    )
    chosen = next((value for value in candidates if value), "localhost:8001")
    return _parse_server_address(chosen)
52
+
53
+
54
+ class MxClient:
55
+ """
56
+ Lightweight gRPC client for ModelExpress server communication.
57
+
58
+ Provides typed methods for every P2P RPC (``PublishMetadata``,
59
+ ``ListSources``, ``GetMetadata``, ``UpdateStatus``) so that callers
60
+ (loaders, coordinators) never need to create gRPC channels or
61
+ stubs directly.
62
+
63
+ The connection is created lazily on first use.
64
+
65
+ Args:
66
+ server_url: Explicit server address (``host:port``). When
67
+ *None* the address is resolved via ``MODEL_EXPRESS_URL``
68
+ or ``MX_SERVER_ADDRESS`` env vars, falling back to
69
+ ``localhost:8001``.
70
+ max_message_size: Max send/receive message size in bytes.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ server_url: str | None = None,
76
+ max_message_size: int = 100 * 1024 * 1024, # 100 MB
77
+ ):
78
+ self.server_url = _get_server_url(server_url)
79
+ self._max_message_size = max_message_size
80
+ self._channel: grpc.Channel | None = None
81
+ self._stub: p2p_pb2_grpc.P2pServiceStub | None = None
82
+
83
+ # -- connection management ------------------------------------------------
84
+
85
+ @property
86
+ def stub(self) -> p2p_pb2_grpc.P2pServiceStub:
87
+ """Return (and lazily create) the gRPC stub."""
88
+ if self._channel is None:
89
+ options = [
90
+ ("grpc.max_send_message_length", self._max_message_size),
91
+ ("grpc.max_receive_message_length", self._max_message_size),
92
+ ]
93
+ self._channel = grpc.insecure_channel(self.server_url, options=options)
94
+ self._stub = p2p_pb2_grpc.P2pServiceStub(self._channel)
95
+ logger.debug("MxClient connected to %s", self.server_url)
96
+ return self._stub
97
+
98
+ def close(self) -> None:
99
+ """Close the underlying gRPC channel."""
100
+ if self._channel is not None:
101
+ self._channel.close()
102
+ self._channel = None
103
+ self._stub = None
104
+
105
+ # -- RPC wrappers ---------------------------------------------------------
106
+
107
+ def publish_metadata(
108
+ self,
109
+ identity: "p2p_pb2.SourceIdentity",
110
+ worker: "p2p_pb2.WorkerMetadata",
111
+ worker_id: str,
112
+ ) -> str:
113
+ """Publish metadata for one worker so targets can discover this source.
114
+
115
+ Returns the *mx_source_id* (16-char hex) on success, raises on failure.
116
+ """
117
+ request = p2p_pb2.PublishMetadataRequest(
118
+ identity=identity,
119
+ worker=worker,
120
+ worker_id=worker_id,
121
+ )
122
+ response = self.stub.PublishMetadata(request, timeout=30)
123
+ if not response.success:
124
+ raise RuntimeError(f"PublishMetadata failed: {response.message}")
125
+ return response.mx_source_id
126
+
127
+ def list_sources(
128
+ self,
129
+ identity: "p2p_pb2.SourceIdentity | None" = None,
130
+ status_filter: "p2p_pb2.SourceStatus | None" = None,
131
+ ) -> "p2p_pb2.ListSourcesResponse":
132
+ """List available source workers, optionally filtered by identity and status."""
133
+ request = p2p_pb2.ListSourcesRequest(
134
+ identity=identity,
135
+ status_filter=status_filter,
136
+ )
137
+ return self.stub.ListSources(request, timeout=30)
138
+
139
+ def get_metadata(
140
+ self,
141
+ mx_source_id: str,
142
+ worker_id: str,
143
+ ) -> "p2p_pb2.GetMetadataResponse":
144
+ """Fetch full tensor metadata for one specific worker."""
145
+ request = p2p_pb2.GetMetadataRequest(
146
+ mx_source_id=mx_source_id,
147
+ worker_id=worker_id,
148
+ )
149
+ return self.stub.GetMetadata(request, timeout=30)
150
+
151
+ def update_status(
152
+ self,
153
+ mx_source_id: str,
154
+ worker_id: str,
155
+ worker_rank: int,
156
+ status: "p2p_pb2.SourceStatus",
157
+ ) -> bool:
158
+ """Update worker status. Returns *True* on success."""
159
+ request = p2p_pb2.UpdateStatusRequest(
160
+ mx_source_id=mx_source_id,
161
+ worker_id=worker_id,
162
+ worker_rank=worker_rank,
163
+ status=status,
164
+ )
165
+ response = self.stub.UpdateStatus(request, timeout=30)
166
+ if not response.success:
167
+ logger.error("UpdateStatus failed: %s", response.message)
168
+ return response.success
@@ -0,0 +1,332 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Framework-agnostic GDS model loader.
6
+
7
+ Loads model weights from safetensors files directly to GPU memory via NIXL's
8
+ GDS (GPUDirect Storage) backend, bypassing CPU bounce buffers entirely.
9
+
10
+ The target GPU is determined from torch.cuda.current_device(), matching
11
+ the behavior of vLLM/sglang default loaders.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import logging
18
+ import os
19
+ import struct
20
+ import time
21
+ import uuid
22
+ from collections import defaultdict
23
+ from concurrent.futures import ThreadPoolExecutor
24
+ from pathlib import Path
25
+ from typing import Iterator
26
+
27
+ import torch
28
+
29
+ from .gds_transfer import GdsTransferManager, is_gds_available
30
+
31
+ logger = logging.getLogger("modelexpress.gds_loader")
32
+
33
+ # Complete dtype mapping from the safetensors spec:
34
+ # https://huggingface.co/docs/safetensors/metadata_parsing#accepted-dtypes
35
+ SAFETENSORS_DTYPE_MAP: dict[str, torch.dtype] = {
36
+ "F64": torch.float64,
37
+ "F32": torch.float32,
38
+ "F16": torch.float16,
39
+ "BF16": torch.bfloat16,
40
+ "F8_E4M3": torch.float8_e4m3fn,
41
+ "F8_E5M2": torch.float8_e5m2,
42
+ "I64": torch.int64,
43
+ "I32": torch.int32,
44
+ "I16": torch.int16,
45
+ "I8": torch.int8,
46
+ "U8": torch.uint8,
47
+ "BOOL": torch.bool,
48
+ }
49
+
50
+
51
class MxGdsLoader:
    """
    Load model weights from safetensors files directly to GPU via GDS.

    Framework-agnostic. Can be used from vLLM, sglang, or standalone.

    Usage::

        loader = MxGdsLoader()
        tensors = loader.load("/path/to/model")

        # Or stream per-file:
        for name, tensor in loader.load_iter("/path/to/model"):
            process(name, tensor)

    Call ``shutdown()`` when done to release the GDS transfer manager.
    """

    # Largest JSON header accepted by ``_parse_safetensors_header``; guards
    # against allocating a huge buffer for a corrupt length prefix.
    _MAX_HEADER_BYTES = 100 * 1024 * 1024

    def __init__(self):
        # Created on the first load and reused until ``shutdown()``.
        self._gds_manager: GdsTransferManager | None = None
        # CUDA device index captured at the start of each ``load_iter`` run.
        self._device_id: int | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load(self, model_path: str) -> dict[str, torch.Tensor]:
        """Load all tensors from model_path to GPU.

        Returns:
            Mapping of tensor name to tensor, in the order ``load_iter``
            yields them.
        """
        return dict(self.load_iter(model_path))

    def load_iter(
        self,
        model_path: str,
        *,
        use_tqdm: bool = True,
        revision: str | None = None,
    ) -> Iterator[tuple[str, torch.Tensor]]:
        """
        Yield (tensor_name, gpu_tensor) pairs loaded via GDS.

        Each safetensors file is batch-loaded through a single GDS
        transfer, then its tensors are yielded one by one. While the
        tensors of one file are being consumed, the next file is already
        loading on a background thread.

        This is a generator: nothing (including path resolution and the
        GDS availability check) happens until the first item is requested.

        Args:
            model_path: Local model directory, or an identifier passed to
                ``huggingface_hub.snapshot_download``.
            use_tqdm: Show a per-file progress bar.
            revision: Revision forwarded to ``snapshot_download``; unused
                when ``model_path`` is an existing directory.

        Raises:
            RuntimeError: If GDS is not available.
        """
        load_start = time.perf_counter()
        model_path = self._resolve_model_path(model_path, revision=revision)

        if not is_gds_available():
            raise RuntimeError(
                "GDS is not available. Check nvidia_fs module and libcufile."
            )

        self._device_id = torch.cuda.current_device()
        self._ensure_gds_manager()

        file_tensor_map = self._resolve_safetensors_files(model_path)

        # One job per file: (path, {tensor_name: header entry}).
        file_jobs = []
        for file_path, tensor_names in file_tensor_map.items():
            header_info = self._parse_safetensors_header(file_path)
            absent = [name for name in tensor_names if name not in header_info]
            if absent:
                # The index promised tensors this shard does not contain.
                # They are skipped, but no longer silently.
                logger.warning(
                    "%d tensor(s) listed for %s are missing from its header "
                    "and will be skipped (first: %s)",
                    len(absent), Path(file_path).name, absent[0],
                )
            file_tensors = {
                name: header_info[name]
                for name in tensor_names
                if name in header_info
            }
            if file_tensors:
                file_jobs.append((file_path, file_tensors))

        if not file_jobs:
            return

        # Prefetch pipeline: load file[i+1] while yielding file[i]
        total_files = len(file_jobs)
        pbar = None
        if use_tqdm:
            from tqdm import tqdm
            pbar = tqdm(
                total=total_files,
                desc="Loading safetensors via GDS",
                unit="file",
            )

        pool = ThreadPoolExecutor(max_workers=1)
        try:
            pending = pool.submit(self._load_file_tensors, *file_jobs[0])

            for i in range(total_files):
                loaded = pending.result()
                if pbar is not None:
                    pbar.update(1)

                # Kick off the next file before handing out this one's
                # tensors so the transfer overlaps with the consumer's work.
                if i + 1 < total_files:
                    pending = pool.submit(
                        self._load_file_tensors, *file_jobs[i + 1]
                    )

                yield from loaded.items()

            logger.info("GDS load complete in %.2fs", time.perf_counter() - load_start)
        finally:
            # Also runs when the consumer abandons the generator early.
            if pbar is not None:
                pbar.close()
            pool.shutdown(wait=True)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_model_path(
        model_path: str, revision: str | None = None
    ) -> str:
        """Resolve model_path to a local directory.

        An existing directory is returned as an absolute, resolved path.
        Anything else is handed to ``huggingface_hub.snapshot_download``
        together with ``revision``.
        """
        p = Path(model_path)
        if p.is_dir():
            return str(p.resolve())

        from huggingface_hub import snapshot_download
        local_dir = snapshot_download(model_path, revision=revision)
        logger.info("Resolved HF model '%s' -> %s", model_path, local_dir)
        return local_dir

    def _ensure_gds_manager(self) -> None:
        """Lazily create and initialize the GDS transfer manager.

        The manager is stored on the instance only after ``initialize()``
        returned, so a failed initialization is retried on the next call
        instead of leaving a half-constructed manager behind.
        """
        if self._gds_manager is not None:
            return

        agent_name = f"mx-gds-{self._device_id}-{uuid.uuid4().hex[:8]}"
        manager = GdsTransferManager(agent_name=agent_name)
        manager.initialize()
        self._gds_manager = manager
        logger.info("GDS manager initialized for device %d", self._device_id)

    def _resolve_safetensors_files(
        self, model_path: str
    ) -> dict[str, list[str]]:
        """
        Discover safetensors files and map each to its tensor names.

        Lookup order:
        1. ``model.safetensors.index.json`` (sharded layout)
        2. ``model.safetensors`` (single file)
        3. every ``*.safetensors`` in the directory, sorted by name

        Returns:
            ``{absolute_file_path: [tensor_name, ...]}``

        Raises:
            RuntimeError: If the index exists but its ``weight_map`` is empty.
            FileNotFoundError: If no safetensors file is found at all.
        """
        model_dir = Path(model_path)

        # Try sharded index first
        index_path = model_dir / "model.safetensors.index.json"
        if index_path.exists():
            with open(index_path, "r") as f:
                index = json.load(f)

            weight_map: dict[str, str] = index.get("weight_map", {})
            if not weight_map:
                raise RuntimeError(f"Empty weight_map in {index_path}")

            # Invert tensor -> file into file -> [tensors].
            file_tensors: dict[str, list[str]] = defaultdict(list)
            for tensor_name, filename in weight_map.items():
                abs_path = str(model_dir / filename)
                file_tensors[abs_path].append(tensor_name)

            logger.info(
                "Found sharded model: %d files, %d tensors",
                len(file_tensors), len(weight_map),
            )
            return dict(file_tensors)

        # Try single file
        single_path = model_dir / "model.safetensors"
        if single_path.exists():
            header_info = self._parse_safetensors_header(str(single_path))
            tensor_names = list(header_info.keys())
            logger.info("Found single safetensors file: %d tensors", len(tensor_names))
            return {str(single_path): tensor_names}

        # Fallback: glob
        st_files = sorted(model_dir.glob("*.safetensors"))
        if not st_files:
            raise FileNotFoundError(f"No .safetensors files found in {model_path}")

        file_tensors_map: dict[str, list[str]] = {}
        for st_file in st_files:
            header_info = self._parse_safetensors_header(str(st_file))
            file_tensors_map[str(st_file)] = list(header_info.keys())

        total = sum(len(v) for v in file_tensors_map.values())
        logger.info(
            "Found %d safetensors files via glob: %d tensors",
            len(file_tensors_map), total,
        )
        return file_tensors_map

    def _parse_safetensors_header(self, file_path: str) -> dict[str, dict]:
        """
        Parse a safetensors file header without loading tensor data.

        The file starts with an 8-byte little-endian length, followed by
        that many bytes of JSON. Tensor ``data_offsets`` in the JSON are
        relative to the end of the header and are converted to absolute
        file offsets here. The ``__metadata__`` entry is skipped.

        Returns:
            {tensor_name: {"file_offset": int, "size": int, "dtype": str, "shape": list}}

        Raises:
            RuntimeError: If the length prefix is incomplete, the declared
                header exceeds ``_MAX_HEADER_BYTES``, or the file ends
                before the declared header does.
        """
        with open(file_path, "rb") as f:
            raw = f.read(8)
            if len(raw) < 8:
                raise RuntimeError(f"Invalid safetensors file: {file_path}")

            header_size = struct.unpack("<Q", raw)[0]

            if header_size > self._MAX_HEADER_BYTES:
                raise RuntimeError(
                    f"Safetensors header too large ({header_size} bytes): {file_path}"
                )

            header_bytes = f.read(header_size)

        if len(header_bytes) != header_size:
            # A short read means the file was cut off inside the header;
            # whatever JSON prefix happens to parse must not be trusted.
            raise RuntimeError(
                f"Truncated safetensors header (expected {header_size} bytes, "
                f"got {len(header_bytes)}): {file_path}"
            )

        header = json.loads(header_bytes)
        data_start = 8 + header_size

        result: dict[str, dict] = {}
        for name, info in header.items():
            if name == "__metadata__":
                continue

            begin, end = info["data_offsets"]
            result[name] = {
                "file_offset": data_start + begin,
                "size": end - begin,
                "dtype": info["dtype"],
                "shape": info["shape"],
            }

        return result

    def _load_file_tensors(
        self,
        file_path: str,
        tensor_infos: dict[str, dict],
    ) -> dict[str, torch.Tensor]:
        """
        Load all tensors from one safetensors file via GDS.

        All tensors are submitted in a single batch so GDS_MT threads
        work in parallel. Reads go directly into result tensors.

        Args:
            file_path: Path of the safetensors file.
            tensor_infos: Header entries as produced by
                ``_parse_safetensors_header`` for the tensors to load.

        Returns:
            ``{tensor_name: tensor}`` ordered by ascending file offset.

        Raises:
            RuntimeError: If a tensor uses a dtype tag that is not in
                ``SAFETENSORS_DTYPE_MAP``.
        """
        device = torch.device("cuda", self._device_id)

        # Submit reads in file order.
        sorted_names = sorted(
            tensor_infos.keys(),
            key=lambda n: tensor_infos[n]["file_offset"],
        )

        tensor_list = []  # (file_offset, size) per tensor, for the transfer
        tensor_meta = []  # (name, torch dtype, shape) per tensor, same order
        for name in sorted_names:
            info = tensor_infos[name]
            st_dtype = info["dtype"]
            torch_dtype = SAFETENSORS_DTYPE_MAP.get(st_dtype)
            if torch_dtype is None:
                # Validate every dtype before opening the file or touching
                # the GPU, so a bad shard fails without side effects.
                raise RuntimeError(
                    f"Unsupported safetensors dtype '{st_dtype}' "
                    f"for tensor '{name}'"
                )
            tensor_list.append((info["file_offset"], info["size"]))
            tensor_meta.append((name, torch_dtype, info["shape"]))

        fd = os.open(file_path, os.O_RDONLY)
        try:
            # fstat lives inside the try so the descriptor is closed even
            # when the size query itself fails.
            file_size = os.fstat(fd).st_size
            raw_tensors = self._gds_manager.batch_load_file(
                fd, file_size, tensor_list, device,
            )
        finally:
            os.close(fd)

        # NOTE(review): assumes batch_load_file returns one flat byte buffer
        # per requested range, in request order -- confirm in gds_transfer.
        result: dict[str, torch.Tensor] = {}
        for raw, (name, torch_dtype, shape) in zip(raw_tensors, tensor_meta, strict=True):
            result[name] = raw.view(torch_dtype).reshape(shape)

        logger.info("Loaded %s", Path(file_path).name)
        return result

    def shutdown(self) -> None:
        """Release GDS resources.

        Safe to call repeatedly; a later load creates a new manager.
        """
        if self._gds_manager is not None:
            self._gds_manager.shutdown()
            self._gds_manager = None