modelexpress 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. modelexpress-0.4.0/PKG-INFO +157 -0
  2. modelexpress-0.4.0/README.md +119 -0
  3. modelexpress-0.4.0/modelexpress/__init__.py +84 -0
  4. modelexpress-0.4.0/modelexpress/adapter.py +176 -0
  5. modelexpress-0.4.0/modelexpress/client.py +225 -0
  6. modelexpress-0.4.0/modelexpress/engines/__init__.py +4 -0
  7. modelexpress-0.4.0/modelexpress/engines/sglang/__init__.py +18 -0
  8. modelexpress-0.4.0/modelexpress/engines/sglang/adapter.py +344 -0
  9. modelexpress-0.4.0/modelexpress/engines/sglang/loader.py +407 -0
  10. modelexpress-0.4.0/modelexpress/engines/vllm/__init__.py +29 -0
  11. modelexpress-0.4.0/modelexpress/engines/vllm/adapter.py +249 -0
  12. modelexpress-0.4.0/modelexpress/engines/vllm/loader.py +152 -0
  13. modelexpress-0.4.0/modelexpress/engines/vllm/registration.py +69 -0
  14. modelexpress-0.4.0/modelexpress/gds_loader.py +332 -0
  15. modelexpress-0.4.0/modelexpress/gds_transfer.py +238 -0
  16. modelexpress-0.4.0/modelexpress/lifecycle.py +98 -0
  17. modelexpress-0.4.0/modelexpress/load_strategy/__init__.py +132 -0
  18. modelexpress-0.4.0/modelexpress/load_strategy/base.py +269 -0
  19. modelexpress-0.4.0/modelexpress/load_strategy/context.py +71 -0
  20. modelexpress-0.4.0/modelexpress/load_strategy/default_strategy.py +38 -0
  21. modelexpress-0.4.0/modelexpress/load_strategy/gds_strategy.py +64 -0
  22. modelexpress-0.4.0/modelexpress/load_strategy/model_streamer_strategy.py +96 -0
  23. modelexpress-0.4.0/modelexpress/load_strategy/rdma_strategy.py +307 -0
  24. modelexpress-0.4.0/modelexpress/metadata/__init__.py +4 -0
  25. modelexpress-0.4.0/modelexpress/metadata/client_factory.py +61 -0
  26. modelexpress-0.4.0/modelexpress/metadata/heartbeat.py +142 -0
  27. modelexpress-0.4.0/modelexpress/metadata/k8s_service_client.py +318 -0
  28. modelexpress-0.4.0/modelexpress/metadata/publish.py +295 -0
  29. modelexpress-0.4.0/modelexpress/metadata/source_id.py +67 -0
  30. modelexpress-0.4.0/modelexpress/metadata/worker_server.py +141 -0
  31. modelexpress-0.4.0/modelexpress/nixl_transfer.py +612 -0
  32. modelexpress-0.4.0/modelexpress/p2p_pb2.py +79 -0
  33. modelexpress-0.4.0/modelexpress/p2p_pb2_grpc.py +341 -0
  34. modelexpress-0.4.0/modelexpress/rank_utils.py +30 -0
  35. modelexpress-0.4.0/modelexpress/tensor_utils.py +299 -0
  36. modelexpress-0.4.0/modelexpress/tracing.py +18 -0
  37. modelexpress-0.4.0/modelexpress/transfer_safety.py +202 -0
  38. modelexpress-0.4.0/modelexpress/trtllm_live_transfer.py +626 -0
  39. modelexpress-0.4.0/modelexpress/types.py +46 -0
  40. modelexpress-0.4.0/modelexpress/ucx_utils.py +426 -0
  41. modelexpress-0.4.0/modelexpress/vllm_loader.py +8 -0
  42. modelexpress-0.4.0/modelexpress/vllm_worker.py +26 -0
  43. modelexpress-0.4.0/modelexpress/vmm/__init__.py +58 -0
  44. modelexpress-0.4.0/modelexpress/vmm/_alloc_ext.cpp +710 -0
  45. modelexpress-0.4.0/modelexpress/vmm/arena.py +378 -0
  46. modelexpress-0.4.0/modelexpress/vmm/backend.py +299 -0
  47. modelexpress-0.4.0/modelexpress/vmm/hook.py +209 -0
  48. modelexpress-0.4.0/modelexpress/vmm/runtime.py +250 -0
  49. modelexpress-0.4.0/modelexpress.egg-info/PKG-INFO +157 -0
  50. modelexpress-0.4.0/modelexpress.egg-info/SOURCES.txt +74 -0
  51. modelexpress-0.4.0/modelexpress.egg-info/dependency_links.txt +1 -0
  52. modelexpress-0.4.0/modelexpress.egg-info/entry_points.txt +2 -0
  53. modelexpress-0.4.0/modelexpress.egg-info/requires.txt +23 -0
  54. modelexpress-0.4.0/modelexpress.egg-info/top_level.txt +1 -0
  55. modelexpress-0.4.0/pyproject.toml +68 -0
  56. modelexpress-0.4.0/setup.cfg +4 -0
  57. modelexpress-0.4.0/setup.py +96 -0
  58. modelexpress-0.4.0/tests/test_adapter.py +53 -0
  59. modelexpress-0.4.0/tests/test_gds_loader.py +324 -0
  60. modelexpress-0.4.0/tests/test_heartbeat.py +185 -0
  61. modelexpress-0.4.0/tests/test_k8s_service_client.py +414 -0
  62. modelexpress-0.4.0/tests/test_lifecycle.py +167 -0
  63. modelexpress-0.4.0/tests/test_model_streamer_strategy.py +464 -0
  64. modelexpress-0.4.0/tests/test_nixl_backend.py +62 -0
  65. modelexpress-0.4.0/tests/test_pool_registration.py +248 -0
  66. modelexpress-0.4.0/tests/test_sglang_loader.py +519 -0
  67. modelexpress-0.4.0/tests/test_source_id.py +100 -0
  68. modelexpress-0.4.0/tests/test_tensor_utils.py +336 -0
  69. modelexpress-0.4.0/tests/test_tracing.py +78 -0
  70. modelexpress-0.4.0/tests/test_transfer_safety.py +191 -0
  71. modelexpress-0.4.0/tests/test_types.py +126 -0
  72. modelexpress-0.4.0/tests/test_vllm_adapter.py +150 -0
  73. modelexpress-0.4.0/tests/test_vllm_loader.py +1483 -0
  74. modelexpress-0.4.0/tests/test_vmm_arena.py +672 -0
  75. modelexpress-0.4.0/tests/test_vmm_backend.py +279 -0
  76. modelexpress-0.4.0/tests/test_vmm_hook.py +459 -0
@@ -0,0 +1,157 @@
1
+ Metadata-Version: 2.4
2
+ Name: modelexpress
3
+ Version: 0.4.0
4
+ Summary: Python client for ModelExpress P2P GPU transfer service
5
+ Author-email: NVIDIA Corporation <sw-dl-dynamo@nvidia.com>
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://github.com/ai-dynamo/modelexpress
8
+ Project-URL: Repository, https://github.com/ai-dynamo/modelexpress.git
9
+ Project-URL: Documentation, https://github.com/ai-dynamo/modelexpress/tree/main/docs
10
+ Keywords: llm,gpu,transfer,rdma,nixl,vllm
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Requires-Python: >=3.10
19
+ Description-Content-Type: text/markdown
20
+ Requires-Dist: grpcio>=1.66.2
21
+ Requires-Dist: huggingface_hub>=0.20.0
22
+ Requires-Dist: nixl[cu12]; sys_platform == "linux"
23
+ Requires-Dist: numpy>=1.24.0
24
+ Requires-Dist: protobuf<6.0.0,>=5.27.0
25
+ Requires-Dist: pydantic>=2.0.0
26
+ Requires-Dist: torch>=2.6.0
27
+ Requires-Dist: runai-model-streamer[azure,gcs,s3]; sys_platform == "linux"
28
+ Provides-Extra: dev
29
+ Requires-Dist: grpcio-tools<=1.66.2,>=1.60.0; extra == "dev"
30
+ Requires-Dist: opentelemetry-api>=1.41.1; extra == "dev"
31
+ Requires-Dist: opentelemetry-sdk>=1.41.1; extra == "dev"
32
+ Requires-Dist: pytest>=7.0.0; extra == "dev"
33
+ Requires-Dist: pytest-asyncio>=0.21.0; extra == "dev"
34
+ Provides-Extra: otel
35
+ Requires-Dist: opentelemetry-api>=1.20.0; extra == "otel"
36
+ Provides-Extra: vmm
37
+ Requires-Dist: cuda-python>=12.0; extra == "vmm"
38
+
39
+ # ModelExpress Python Client
40
+
41
+ Python client for ModelExpress -- high-performance GPU-to-GPU model weight transfers using NVIDIA NIXL over RDMA/InfiniBand.
42
+
43
+ Instead of each vLLM instance loading model weights from storage, one "source" instance loads the model and transfers weights directly to "target" instances via GPUDirect RDMA, bypassing the CPU entirely.
44
+
45
+ ## Installation
46
+
47
+ ```bash
48
+ # From PyPI (coming soon)
49
+ pip install modelexpress
50
+
51
+ # Editable install from source
52
+ pip install -e .
53
+
54
+ # With dev dependencies (pytest, grpcio-tools)
55
+ pip install -e ".[dev]"
56
+ ```
57
+
58
+ ### Requirements
59
+
60
+ - Python >= 3.10
61
+ - NVIDIA GPUs with RDMA/InfiniBand support
62
+ - [NIXL](https://github.com/ai-dynamo/nixl) (NVIDIA Interconnect eXchange Library)
63
+ - A running [ModelExpress server](https://github.com/ai-dynamo/modelexpress/tree/main/modelexpress_server) (Rust gRPC service backed by Redis)
64
+
65
+ ## Quick Start with vLLM
66
+
67
+ ModelExpress integrates with vLLM via custom model loaders. vLLM can discover the package through its `vllm.general_plugins` entrypoint; set `VLLM_PLUGINS=modelexpress` if your vLLM deployment requires explicit plugin selection. For manual registration, call `register_modelexpress_loaders()` in your code.
68
+
69
+ ```bash
70
+ export MX_SERVER_ADDRESS="modelexpress-server:8001"
71
+
72
+ vllm serve deepseek-ai/DeepSeek-V3 \
73
+ --load-format modelexpress \
74
+ --tensor-parallel-size 8
75
+ ```
76
+
77
+ Starting the vLLM engine with the `modelexpress` load format on the source worker will load the weights from disk and register/publish the NIXL and tensor metadata to the MX server. The `mx` load format is kept as a backward-compatible alias.
78
+ On the target worker, it retrieves this metadata from the MX server and streams weights over RDMA from GPU to GPU.
79
+
80
+ ## Programmatic Usage
81
+
82
+ ### MxClient
83
+
84
+ `MxClient` is a lightweight gRPC client for communicating with the ModelExpress server:
85
+
86
+ ```python
87
+ from modelexpress import MxClient
88
+
89
+ client = MxClient(server_url="modelexpress-server:8001")
90
+
91
+ # Query for a source model
92
+ response = client.get_metadata("deepseek-ai/DeepSeek-V3")
93
+ if response.found:
94
+ for worker in response.workers:
95
+ print(f"Worker rank {worker.worker_rank}: {len(worker.tensors)} tensors")
96
+
97
+ # Wait for source readiness (blocks until ready or timeout)
98
+ success, session_id, metadata_hash = client.wait_for_ready(
99
+ model_name="deepseek-ai/DeepSeek-V3",
100
+ worker_id=0,
101
+ timeout_seconds=7200,
102
+ )
103
+
104
+ client.close()
105
+ ```
106
+
107
+ ### Registering Loaders Manually
108
+
109
+ ```python
110
+ from modelexpress import register_modelexpress_loaders
111
+
112
+ register_modelexpress_loaders()
113
+ # Now vLLM recognizes --load-format modelexpress and mx
114
+ ```
115
+
116
+ ## Environment Variables
117
+
118
+ | Variable | Default | Description |
119
+ |----------|---------|-------------|
120
+ | `MX_SERVER_ADDRESS` | `localhost:8001` | ModelExpress gRPC server address (recommended) |
121
+ | `MODEL_EXPRESS_URL` | `localhost:8001` | Deprecated, pending removal in a future release. Still read by all client paths and takes precedence when both are set; keep setting it during the transition. |
122
+ | `MX_EXPECTED_WORKERS` | Auto-detected from TP size | Number of GPU workers to coordinate |
123
+ | `MX_SYNC_PUBLISH` | `0` | Source: wait for all workers before publishing metadata |
124
+ | `MX_SYNC_START` | `1` | Target: wait for all source workers before transferring |
125
+ | `MX_POOL_REG` | `0` | Allocation-level NIXL registration (registers cudaMalloc blocks instead of individual tensors) |
126
+
127
+ ### UCX/NIXL Tuning
128
+
129
+ | Variable | Recommended | Description |
130
+ |----------|-------------|-------------|
131
+ | `UCX_RNDV_SCHEME` | `get_zcopy` | Zero-copy RDMA reads |
132
+ | `UCX_RNDV_THRESH` | `0` | Force rendezvous for all transfers |
133
+ | `NIXL_LOG_LEVEL` | `INFO` | NIXL logging level |
134
+
135
+ ## Package Structure
136
+
137
+ | Module | Description |
138
+ |--------|-------------|
139
+ | `modelexpress.client` | `MxClient` -- gRPC client for the ModelExpress server |
140
+ | `modelexpress.metadata` | Metadata clients, source identity, heartbeat, and worker manifest serving |
141
+ | `modelexpress.engines.vllm.loader` | `MxModelLoader` -- vLLM integration |
142
+ | `modelexpress.vllm_loader` | Compatibility shim for the vLLM loader |
143
+ | `modelexpress.nixl_transfer` | `NixlTransferManager` -- NIXL agent lifecycle and RDMA transfers |
144
+ | `modelexpress.types` | `TensorDescriptor`, `WorkerMetadata` -- core data types |
145
+ | `modelexpress.vllm_worker` | Compatibility worker extension for older manual-registration workflows |
146
+
147
+ ## How It Works
148
+
149
+ 1. **Source** loads weights from disk, registers raw tensors with NIXL *before* FP8 processing, and publishes metadata to the ModelExpress server.
150
+ 2. **Target** creates dummy weights, waits for the source ready flag, then pulls raw tensors via RDMA read.
151
+ 3. Both source and target run `process_weights_after_loading()` independently, producing identical FP8-transformed weights.
152
+
153
+ This pre-processing transfer strategy is critical for FP8 models (e.g., DeepSeek-V3) where `weight_scale_inv` tensors are renamed and transformed during processing.
154
+
155
+ ## License
156
+
157
+ Apache-2.0
@@ -0,0 +1,119 @@
1
+ # ModelExpress Python Client
2
+
3
+ Python client for ModelExpress -- high-performance GPU-to-GPU model weight transfers using NVIDIA NIXL over RDMA/InfiniBand.
4
+
5
+ Instead of each vLLM instance loading model weights from storage, one "source" instance loads the model and transfers weights directly to "target" instances via GPUDirect RDMA, bypassing the CPU entirely.
6
+
7
+ ## Installation
8
+
9
+ ```bash
10
+ # From PyPI (coming soon)
11
+ pip install modelexpress
12
+
13
+ # Editable install from source
14
+ pip install -e .
15
+
16
+ # With dev dependencies (pytest, grpcio-tools)
17
+ pip install -e ".[dev]"
18
+ ```
19
+
20
+ ### Requirements
21
+
22
+ - Python >= 3.10
23
+ - NVIDIA GPUs with RDMA/InfiniBand support
24
+ - [NIXL](https://github.com/ai-dynamo/nixl) (NVIDIA Interconnect eXchange Library)
25
+ - A running [ModelExpress server](https://github.com/ai-dynamo/modelexpress/tree/main/modelexpress_server) (Rust gRPC service backed by Redis)
26
+
27
+ ## Quick Start with vLLM
28
+
29
+ ModelExpress integrates with vLLM via custom model loaders. vLLM can discover the package through its `vllm.general_plugins` entrypoint; set `VLLM_PLUGINS=modelexpress` if your vLLM deployment requires explicit plugin selection. For manual registration, call `register_modelexpress_loaders()` in your code.
30
+
31
+ ```bash
32
+ export MX_SERVER_ADDRESS="modelexpress-server:8001"
33
+
34
+ vllm serve deepseek-ai/DeepSeek-V3 \
35
+ --load-format modelexpress \
36
+ --tensor-parallel-size 8
37
+ ```
38
+
39
+ Starting the vLLM engine with the `modelexpress` load format on the source worker will load the weights from disk and register/publish the NIXL and tensor metadata to the MX server. The `mx` load format is kept as a backward-compatible alias.
40
+ On the target worker, it retrieves this metadata from the MX server and streams weights over RDMA from GPU to GPU.
41
+
42
+ ## Programmatic Usage
43
+
44
+ ### MxClient
45
+
46
+ `MxClient` is a lightweight gRPC client for communicating with the ModelExpress server:
47
+
48
+ ```python
49
+ from modelexpress import MxClient
50
+
51
+ client = MxClient(server_url="modelexpress-server:8001")
52
+
53
+ # Query for a source model
54
+ response = client.get_metadata("deepseek-ai/DeepSeek-V3")
55
+ if response.found:
56
+ for worker in response.workers:
57
+ print(f"Worker rank {worker.worker_rank}: {len(worker.tensors)} tensors")
58
+
59
+ # Wait for source readiness (blocks until ready or timeout)
60
+ success, session_id, metadata_hash = client.wait_for_ready(
61
+ model_name="deepseek-ai/DeepSeek-V3",
62
+ worker_id=0,
63
+ timeout_seconds=7200,
64
+ )
65
+
66
+ client.close()
67
+ ```
68
+
69
+ ### Registering Loaders Manually
70
+
71
+ ```python
72
+ from modelexpress import register_modelexpress_loaders
73
+
74
+ register_modelexpress_loaders()
75
+ # Now vLLM recognizes --load-format modelexpress and mx
76
+ ```
77
+
78
+ ## Environment Variables
79
+
80
+ | Variable | Default | Description |
81
+ |----------|---------|-------------|
82
+ | `MX_SERVER_ADDRESS` | `localhost:8001` | ModelExpress gRPC server address (recommended) |
83
+ | `MODEL_EXPRESS_URL` | `localhost:8001` | Deprecated, pending removal in a future release. Still read by all client paths and takes precedence when both are set; keep setting it during the transition. |
84
+ | `MX_EXPECTED_WORKERS` | Auto-detected from TP size | Number of GPU workers to coordinate |
85
+ | `MX_SYNC_PUBLISH` | `0` | Source: wait for all workers before publishing metadata |
86
+ | `MX_SYNC_START` | `1` | Target: wait for all source workers before transferring |
87
+ | `MX_POOL_REG` | `0` | Allocation-level NIXL registration (registers cudaMalloc blocks instead of individual tensors) |
88
+
89
+ ### UCX/NIXL Tuning
90
+
91
+ | Variable | Recommended | Description |
92
+ |----------|-------------|-------------|
93
+ | `UCX_RNDV_SCHEME` | `get_zcopy` | Zero-copy RDMA reads |
94
+ | `UCX_RNDV_THRESH` | `0` | Force rendezvous for all transfers |
95
+ | `NIXL_LOG_LEVEL` | `INFO` | NIXL logging level |
96
+
97
+ ## Package Structure
98
+
99
+ | Module | Description |
100
+ |--------|-------------|
101
+ | `modelexpress.client` | `MxClient` -- gRPC client for the ModelExpress server |
102
+ | `modelexpress.metadata` | Metadata clients, source identity, heartbeat, and worker manifest serving |
103
+ | `modelexpress.engines.vllm.loader` | `MxModelLoader` -- vLLM integration |
104
+ | `modelexpress.vllm_loader` | Compatibility shim for the vLLM loader |
105
+ | `modelexpress.nixl_transfer` | `NixlTransferManager` -- NIXL agent lifecycle and RDMA transfers |
106
+ | `modelexpress.types` | `TensorDescriptor`, `WorkerMetadata` -- core data types |
107
+ | `modelexpress.vllm_worker` | Compatibility worker extension for older manual-registration workflows |
108
+
109
+ ## How It Works
110
+
111
+ 1. **Source** loads weights from disk, registers raw tensors with NIXL *before* FP8 processing, and publishes metadata to the ModelExpress server.
112
+ 2. **Target** creates dummy weights, waits for the source ready flag, then pulls raw tensors via RDMA read.
113
+ 3. Both source and target run `process_weights_after_loading()` independently, producing identical FP8-transformed weights.
114
+
115
+ This pre-processing transfer strategy is critical for FP8 models (e.g., DeepSeek-V3) where `weight_scale_inv` tensors are renamed and transformed during processing.
116
+
117
+ ## License
118
+
119
+ Apache-2.0
@@ -0,0 +1,84 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ ModelExpress - High-performance GPU-to-GPU model weight transfers.
6
+
7
+ This package provides:
8
+ - NIXL-based RDMA transfers for GPU tensors
9
+ - GPUDirect Storage (GDS) for direct file-to-GPU loading
10
+ - vLLM worker extension for serving model weights
11
+ - Custom model loaders for FP8 model support (DeepSeek-V3, etc.)
12
+
13
+ Quick Start (vLLM):
14
+ from modelexpress import register_modelexpress_loaders
15
+ register_modelexpress_loaders()
16
+
17
+ # vllm serve model --load-format modelexpress
18
+ # Auto-detects: RDMA -> GDS -> disk
19
+ """
20
+
21
+ import logging
22
+ import os
23
+
24
# Package-root logger; used by register_modelexpress_loaders().
_logger = logging.getLogger(name=__name__)
# Flipped to True by register_modelexpress_loaders() once vLLM registration
# has succeeded, turning later calls into no-ops.
_loaders_registered = False
26
+
27
+
28
def configure_vllm_logging():
    """Ensure modelexpress loggers are visible in vLLM worker subprocesses.

    vLLM only attaches log handlers to the "vllm" namespace. Without this,
    all "modelexpress.*" output is silently dropped in EngineCore worker
    processes. This copies vLLM's handlers onto the "modelexpress" parent
    logger so every child inherits them via propagation.

    Level selection:
      * If MODEL_EXPRESS_LOG_LEVEL names a standard logging level
        (case-insensitive, surrounding whitespace ignored), that level is used.
      * Otherwise, if the "vllm" logger has an explicit level, it is mirrored.
      * Unrecognized values are ignored instead of raising.

    Idempotent: once the "modelexpress" logger has handlers, later calls
    return immediately.
    """
    mx_root = logging.getLogger("modelexpress")
    if mx_root.handlers:
        return
    vllm_logger = logging.getLogger("vllm")
    for handler in vllm_logger.handlers:
        mx_root.addHandler(handler)

    # Accept only names that resolve to an integer level on the logging
    # module. A bare hasattr() check also matches non-level constants such as
    # logging.BASIC_FORMAT (a format string), which setLevel() rejects with
    # ValueError. That would crash worker startup on a typo.
    mx_level_name = os.environ.get("MODEL_EXPRESS_LOG_LEVEL", "").strip().upper()
    mx_level = getattr(logging, mx_level_name, None) if mx_level_name else None
    if isinstance(mx_level, int):
        mx_root.setLevel(mx_level)
    elif vllm_logger.level != logging.NOTSET:
        mx_root.setLevel(vllm_logger.level)
47
+
48
+
49
def register_modelexpress_loaders():
    """Register the ModelExpress model loaders with vLLM.

    Safe to call repeatedly: a module-level flag records a completed
    registration, and later calls return immediately. The flag is set only
    after the vLLM registration call returns. If that call raises, a later
    call will try again.

    After registration, vLLM accepts:
        --load-format modelexpress   (auto-detect: RDMA -> GDS -> disk)
        --load-format mx             (backward-compatible alias)
    """
    global _loaders_registered
    if not _loaders_registered:
        from .engines.vllm import (
            register_modelexpress_loaders as _register_with_vllm,
        )

        _register_with_vllm()
        _loaders_registered = True
        _logger.debug("ModelExpress loaders registered")
70
+
71
+
72
+ from .client import MxClient # noqa: F401
73
+ from .gds_loader import MxGdsLoader # noqa: F401
74
+ from .gds_transfer import GdsTransferManager # noqa: F401
75
+ from .metadata.heartbeat import HeartbeatThread # noqa: F401
76
+
77
__all__ = [
    # Classes re-exported from submodules.
    "GdsTransferManager", "HeartbeatThread", "MxClient", "MxGdsLoader",
    # Functions defined in this module.
    "configure_vllm_logging", "register_modelexpress_loaders",
]
@@ -0,0 +1,176 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Engine adapter contract for ModelExpress loading strategies."""
5
+
6
+ from __future__ import annotations
7
+
8
+ import functools
9
+ import os
10
+ from typing import TYPE_CHECKING, Iterator
11
+
12
+ import torch
13
+
14
+ from . import p2p_pb2
15
+
16
+ if TYPE_CHECKING:
17
+ from .load_strategy.context import LoadResult
18
+
19
+
20
class UnsupportedCapability(NotImplementedError):
    """Signal that an engine adapter does not provide an optional capability.

    This is what the inherited default body of every method wrapped with
    gated_capability raises, with the method name as the message.
    """
22
+
23
+
24
class StrategyFailed(RuntimeError):
    """A recoverable strategy miss; the chain moves on to the next strategy.

    Pass ``mutated=True`` when the failing strategy may already have altered
    model weights or model structure. In that case the chain first runs
    rollback() for strategy-owned cleanup, then has the adapter
    re-initialize the model before the next strategy is attempted.
    """

    def __init__(self, message: str, *, mutated: bool = False):
        # Whether model weights/structure may have been changed before failing.
        self.mutated = mutated
        super().__init__(message)
35
+
36
+
37
def gated_capability(method):
    """Turn an adapter method stub into an opt-in capability.

    The returned function copies the stub's metadata (name, docstring,
    ``__wrapped__``) and is tagged with ``__gated_capability__ = True``.
    Strategies compare their required EngineAdapter methods by object
    identity. An engine that merely inherits this default is treated as not
    supporting the capability; an engine that overrides it is eligible.
    Calling the inherited default raises UnsupportedCapability with the
    method's name as the message.
    """

    def _unsupported(self, *args, **kwargs):
        raise UnsupportedCapability(method.__name__)

    stub = functools.wraps(method)(_unsupported)
    stub.__gated_capability__ = True
    return stub
52
+
53
+
54
class EngineAdapter:
    """Boundary between shared loading strategies and one specific engine.

    Responsibilities are split as follows. Load strategies decide
    ModelExpress policy: fallback order, metadata publishing, RDMA transfer,
    and retry decisions. Engine adapters implement the operations that
    depend on a framework's model object, device mapping, and post-load
    processing rules.

    Methods wrapped with @gated_capability are optional. Each strategy lists
    the adapter methods it needs in its `requires` tuple and is eligible only
    when the concrete engine overrides every one of them. The plain lifecycle
    hooks further down default to no-ops because they are safe extension
    points around a strategy that has already been selected.
    """

    # ------------------------------------------------------------------
    # Gated capabilities. The inherited default for each of these raises
    # UnsupportedCapability without calling the stub body, so an engine must
    # override a method to opt in to it.
    # ------------------------------------------------------------------

    @gated_capability
    def build_identity(self) -> p2p_pb2.SourceIdentity:
        """Return the stable identity used to pair compatible source workers."""
        ...

    @gated_capability
    def get_worker_rank(self) -> int:
        """Return the model-shard key used to match source and target workers.

        Data-parallel replicas normally share one key because their weights
        are interchangeable. Tensor-, pipeline-, and expert-parallel shards
        should return distinct keys when they hold distinct model weights.
        """
        ...

    @gated_capability
    def get_global_rank(self) -> int:
        """Return the engine's global distributed rank (for logging/metadata)."""
        ...

    @gated_capability
    def get_device_id(self) -> int:
        """Return the local CUDA device id that this adapter instance owns."""
        ...

    @gated_capability
    def discover_tensors(self, result: LoadResult) -> dict[str, torch.Tensor]:
        """Collect the publishable tensors of the loaded engine model.

        Called by strategies once weights are in place. Implementations may
        perform engine-specific tensor adoption or normalization before
        collecting, but are expected to leave the model weights unchanged.
        """
        ...

    @gated_capability
    def apply_weight_iter(
        self,
        result: LoadResult,
        weights_iter: Iterator[tuple[str, torch.Tensor]],
    ) -> LoadResult:
        """Feed a stream of named tensors into the engine model.

        This hook is allowed to mutate model weights. A strategy that catches
        an exception from it should raise StrategyFailed(mutated=True) so the
        chain re-initializes the model before another strategy runs.
        """
        ...

    @gated_capability
    def build_model_streamer_weight_iter(
        self,
        model_uri: str,
        model: torch.nn.Module | None = None,
    ) -> Iterator[tuple[str, torch.Tensor]]:
        """Return the engine-native ModelStreamer weight iterator.

        Fallback and registration policy live in the shared strategy; the
        engine owns the concrete ModelStreamer integration because native
        loader APIs and distributed device selection differ per framework.
        Some engine-native loaders use the initialized model to find
        secondary weight sources.
        """
        ...

    @gated_capability
    def load_via_native(self, result: LoadResult) -> LoadResult:
        """Load weights with the engine's own disk/checkpoint loader."""
        ...

    @gated_capability
    def reinit_for_retry(self, result: LoadResult) -> LoadResult:
        """Swap a possibly-mutated model for a freshly built engine model."""
        ...

    # ------------------------------------------------------------------
    # Defaults that engines may override.
    # ------------------------------------------------------------------

    def get_unique_id(self) -> str:
        """Return a best-effort unique id for engines lacking a custom identity.

        Yields ``rank-<rank>`` when torch.distributed is available and
        initialized, and ``pid-<process id>`` otherwise.
        """
        dist = torch.distributed
        if not (dist.is_available() and dist.is_initialized()):
            return f"pid-{os.getpid()}"
        return f"rank-{dist.get_rank()}"

    def get_target_device(self) -> torch.device:
        """Return the CUDA device that target model state should live on."""
        device_index = self.get_device_id()
        return torch.device(f"cuda:{device_index}")

    def is_cuda_alike(self) -> bool:
        """Report whether the engine runs on a CUDA-like platform (default: no)."""
        return False

    # ------------------------------------------------------------------
    # Lifecycle hooks: each returns its input unchanged by default.
    # ------------------------------------------------------------------

    def prepare_rdma_target(self, result: LoadResult) -> LoadResult:
        """Set up target-side model storage ahead of receiving RDMA weights."""
        return result

    def before_rdma_receive(self, result: LoadResult) -> LoadResult:
        """Run any engine post-processing that must precede RDMA writes."""
        return result

    def after_rdma_receive(self, result: LoadResult) -> LoadResult:
        """Run engine post-processing once RDMA weights have arrived."""
        return result

    def after_weight_iter_load(self, result: LoadResult) -> LoadResult:
        """Run engine post-processing after apply_weight_iter() succeeds."""
        return result

    def after_native_load(self, result: LoadResult) -> LoadResult:
        """Run engine post-processing after load_via_native() succeeds."""
        return result