openstef-foundation-models 4.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstef_foundation_models/__init__.py +12 -0
- openstef_foundation_models/inference/__init__.py +32 -0
- openstef_foundation_models/inference/backend.py +43 -0
- openstef_foundation_models/inference/onnx_backend.py +237 -0
- openstef_foundation_models/inference/provider_selection.py +142 -0
- openstef_foundation_models/inference/providers.py +158 -0
- openstef_foundation_models/integrations/__init__.py +5 -0
- openstef_foundation_models/integrations/beam.py +203 -0
- openstef_foundation_models/models/__init__.py +17 -0
- openstef_foundation_models/models/catalog.py +105 -0
- openstef_foundation_models/models/checkpoint.py +212 -0
- openstef_foundation_models/models/forecasting/__init__.py +21 -0
- openstef_foundation_models/models/forecasting/chronos2_forecaster.py +283 -0
- openstef_foundation_models/presets/__init__.py +23 -0
- openstef_foundation_models/presets/forecasting_workflow.py +205 -0
- openstef_foundation_models-4.1.0.dist-info/METADATA +89 -0
- openstef_foundation_models-4.1.0.dist-info/RECORD +18 -0
- openstef_foundation_models-4.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: MPL-2.0
|
|
4
|
+
"""Foundation model support for OpenSTEF."""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
# Package-level logger. A NullHandler is attached only when no handler is
# configured yet, so handlers installed by the application are left untouched.
root_logger = logging.getLogger(__name__)
if not root_logger.handlers:
    _null_handler = logging.NullHandler()
    root_logger.addHandler(_null_handler)
|
|
11
|
+
|
|
12
|
+
__all__: list[str] = []
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: MPL-2.0
|
|
4
|
+
|
|
5
|
+
"""Inference backends for foundation-model forecasters.
|
|
6
|
+
|
|
7
|
+
An :class:`InferenceBackend` isolates how a checkpoint is executed behind a
|
|
8
|
+
single named-tensor contract; forecasters compose a backend rather than
|
|
9
|
+
inheriting execution behaviour. Only the dependency-free surface (the protocol
|
|
10
|
+
and the execution-provider configs) is re-exported here; a concrete backend
|
|
11
|
+
lives in its own submodule and imports its heavy runtime at module top level.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from openstef_foundation_models.inference.backend import InferenceBackend
|
|
15
|
+
from openstef_foundation_models.inference.providers import (
|
|
16
|
+
CoreMLProvider,
|
|
17
|
+
CpuProvider,
|
|
18
|
+
CudaProvider,
|
|
19
|
+
ExecutionProvider,
|
|
20
|
+
SessionOptionsConfig,
|
|
21
|
+
TensorRTProvider,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"CoreMLProvider",
|
|
26
|
+
"CpuProvider",
|
|
27
|
+
"CudaProvider",
|
|
28
|
+
"ExecutionProvider",
|
|
29
|
+
"InferenceBackend",
|
|
30
|
+
"SessionOptionsConfig",
|
|
31
|
+
"TensorRTProvider",
|
|
32
|
+
]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: MPL-2.0
|
|
4
|
+
|
|
5
|
+
"""The :class:`InferenceBackend` contract shared by all execution backends."""
|
|
6
|
+
|
|
7
|
+
from collections.abc import Mapping
|
|
8
|
+
from typing import Protocol, runtime_checkable
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from openstef_foundation_models.models.checkpoint import CheckpointMetadata
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class InferenceBackend(Protocol):
    """Structural contract that every execution backend satisfies.

    Implementations map named input tensors to named output tensors and hold
    whatever runtime resources that requires (for example an ONNX Runtime
    session). A backend is meant to be loaded once and reused for a whole
    backtest. Anything specific to a model family is described by
    :attr:`metadata` rather than by the backend.
    """

    @property
    def metadata(self) -> CheckpointMetadata:
        """Return the metadata of the checkpoint this backend executes."""
        ...

    def run(self, inputs: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
        """Run the model on one batch of named input tensors.

        Args:
            inputs: Input tensors by name. Keys must match ``metadata.input_names``.

        Returns:
            Output tensors by name, including ``metadata.output_name``.
        """
        ...

    def close(self) -> None:
        """Free the runtime resources owned by the backend."""
        ...
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: MPL-2.0
|
|
4
|
+
|
|
5
|
+
"""ONNX Runtime execution backend.
|
|
6
|
+
|
|
7
|
+
Importing this module requires ONNX Runtime (the ``[cpu]`` or ``[gpu]`` extra)
|
|
8
|
+
and raises :class:`MissingExtraError` if it is missing.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from collections.abc import Mapping, Sequence
|
|
13
|
+
from typing import Self
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
from openstef_core.exceptions import MissingExtraError
|
|
18
|
+
from openstef_foundation_models.inference.provider_selection import (
|
|
19
|
+
DefaultProviderPolicy,
|
|
20
|
+
HostCapabilities,
|
|
21
|
+
ProviderPolicy,
|
|
22
|
+
)
|
|
23
|
+
from openstef_foundation_models.inference.providers import (
|
|
24
|
+
ExecutionProvider,
|
|
25
|
+
SessionOptionsConfig,
|
|
26
|
+
)
|
|
27
|
+
from openstef_foundation_models.models.checkpoint import CheckpointMetadata, ResolvedCheckpoint
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
import onnxruntime as ort
|
|
31
|
+
except ImportError as e:
|
|
32
|
+
raise MissingExtraError("onnxruntime", "openstef-foundation-models", install_extra="cpu") from e
|
|
33
|
+
|
|
34
|
+
# The onnxruntime-gpu wheel bundles the CUDA execution-provider plugin, but the
# CUDA/cuDNN runtime itself (the nvidia-*-cu12 wheels pulled by the [gpu] extra)
# is only loaded lazily when a session is created. Calling preload_dlls() loads
# those libraries from the nvidia site-packages, so the CUDA provider can be
# realized without a system CUDA install or LD_LIBRARY_PATH. On the CPU runtime
# the call is a no-op. onnxruntime < 1.21 predates the API, hence the lookup
# with a default instead of a direct attribute access.
_preload_dlls = getattr(ort, "preload_dlls", None)
if _preload_dlls is not None:
    _preload_dlls()
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class OnnxBackend:
    """An :class:`~openstef_foundation_models.inference.backend.InferenceBackend` backed by ONNX Runtime.

    The session is built once on construction and reused for every
    :meth:`run` call, so a single backend instance can be shared across an
    entire backtest. Users may either let the backend build a session from a
    resolved checkpoint and provider chain, or pass a pre-built session they own.

    The backend can also be used as a context manager; leaving the ``with``
    block calls :meth:`close`.
    """

    def __init__(
        self,
        metadata: CheckpointMetadata,
        session: ort.InferenceSession,
    ) -> None:
        """Wrap a pre-built ONNX Runtime session.

        Prefer :meth:`from_checkpoint` unless you need to own the session
        lifecycle yourself.

        Args:
            metadata: Metadata describing the checkpoint the session executes.
            session: A pre-built ONNX Runtime inference session.
        """
        self._metadata = metadata
        # Set to ``None`` by :meth:`close`; :meth:`run` checks for that.
        self._session: ort.InferenceSession | None = session

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: ResolvedCheckpoint,
        providers: Sequence[ExecutionProvider] | None = None,
        session_options: SessionOptionsConfig | None = None,
        *,
        policy: ProviderPolicy | None = None,
    ) -> Self:
        """Build a backend by loading a checkpoint into a new ONNX Runtime session.

        With ``providers=None`` the *policy* selects a chain from the checkpoint
        and host; an explicit ``providers`` list is used as given and *policy* is
        ignored. See :class:`~openstef_foundation_models.inference.provider_selection.ProviderPolicy`
        for how a chain is chosen and how strictly its realization is enforced.

        Args:
            checkpoint: The resolved checkpoint (weights + metadata) to load.
            providers: Ordered execution providers to try. ``None`` lets the policy
                pick a host-appropriate chain from the checkpoint metadata.
            session_options: Optional ONNX Runtime session options.
            policy: Selection policy used when ``providers is None``. Defaults to
                :class:`DefaultProviderPolicy`.

        Returns:
            A backend wrapping the newly built session.
        """
        metadata = checkpoint.metadata
        if providers is not None:
            provider_configs = list(providers)
            explicit = True
            logger.info(
                "Using explicit execution-provider chain %s for checkpoint '%s'.",
                [config.to_ort()[0] for config in provider_configs],
                metadata.model_family,
            )
        else:
            # Compare against None rather than relying on truthiness, so a
            # caller-supplied policy object is never replaced by the default.
            selector = policy if policy is not None else DefaultProviderPolicy()
            host = HostCapabilities.detect()
            provider_configs = selector.select(metadata, host)
            explicit = False
            logger.debug(
                "Detected host: platform=%s, available_providers=%s",
                host.platform,
                sorted(host.available_providers),
            )
            logger.info(
                "%s selected execution-provider chain %s for checkpoint '%s' (precision=%s, static_shapes=%s) on %s.",
                type(selector).__name__,
                [config.to_ort()[0] for config in provider_configs],
                metadata.model_family,
                metadata.precision,
                metadata.static_shapes,
                host.platform,
            )
        ort_providers = [config.to_ort() for config in provider_configs]
        so = _build_session_options(session_options) if session_options is not None else None

        session = ort.InferenceSession(
            str(checkpoint.weights_path),
            sess_options=so,
            providers=ort_providers,
        )
        # Query the realized providers once; the same list is logged and checked.
        realized = session.get_providers()
        logger.info("ONNX Runtime session built; realized providers: %s.", realized)
        # An explicit chain is enforced strictly, a policy-selected one gracefully.
        _check_provider_fallback(
            requested=provider_configs,
            realized=realized,
            strict=explicit,
        )
        return cls(metadata=metadata, session=session)

    @property
    def metadata(self) -> CheckpointMetadata:
        """Metadata describing the checkpoint this backend executes."""
        return self._metadata

    def run(self, inputs: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
        """Execute the ONNX model on a batch of named input tensors.

        Args:
            inputs: Named input tensors. Keys must match ``metadata.input_names``.

        Returns:
            Named output tensors keyed by the model's output names.

        Raises:
            RuntimeError: If the backend has been closed.
        """
        if self._session is None:
            msg = "OnnxBackend has been closed."
            raise RuntimeError(msg)
        output_names = [out.name for out in self._session.get_outputs()]
        results = self._session.run(output_names, dict(inputs))
        # strict=True guards against a mismatch between requested and returned outputs.
        return {name: np.asarray(result) for name, result in zip(output_names, results, strict=True)}

    def close(self) -> None:
        """Release the underlying ONNX Runtime session.

        ONNX Runtime frees native resources on garbage collection, so dropping
        the reference is the supported way to release them.
        """
        self._session = None

    def __enter__(self) -> Self:
        """Return the backend itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the backend when the ``with`` block exits; exceptions propagate."""
        self.close()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _build_session_options(config: SessionOptionsConfig) -> ort.SessionOptions:
    """Convert a :class:`SessionOptionsConfig` into ONNX Runtime ``SessionOptions``.

    Args:
        config: The typed session-options configuration.

    Returns:
        A new ``SessionOptions`` carrying the configured optimization level and,
        where set, the thread counts.
    """
    options = ort.SessionOptions()
    # The config stores the level without ONNX Runtime's ``ORT_`` enum prefix.
    level_name = f"ORT_{config.graph_optimization_level}"
    options.graph_optimization_level = getattr(ort.GraphOptimizationLevel, level_name)
    # ``None`` leaves the corresponding ONNX Runtime default in place.
    for thread_option in ("intra_op_num_threads", "inter_op_num_threads"):
        thread_count = getattr(config, thread_option)
        if thread_count is not None:
            setattr(options, thread_option, thread_count)
    return options
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _check_provider_fallback(
    requested: Sequence[ExecutionProvider],
    realized: Sequence[str],
    *,
    strict: bool,
) -> None:
    """Report when ONNX Runtime silently fell back from a requested accelerator.

    The names of the requested providers are compared with the names ONNX
    Runtime actually realized. See
    :class:`~openstef_foundation_models.inference.provider_selection.ProviderPolicy`
    for the strict-vs-graceful contract this enforces.

    Args:
        requested: The execution providers that were requested.
        realized: The provider names ONNX Runtime actually loaded.
        strict: When ``True``, raise on any missing accelerator; otherwise warn
            only on a full fallback to CPU.

    Raises:
        RuntimeError: If ``strict`` is set and a requested accelerator is missing.
    """
    wanted = {provider.to_ort()[0] for provider in requested}
    # CPU is the fallback target itself, never an accelerator.
    wanted.discard("CPUExecutionProvider")
    if not wanted:
        return

    loaded = set(realized)
    if strict:
        unrealized = wanted - loaded
        if unrealized:
            msg = (
                f"Requested execution provider(s) {sorted(unrealized)} were not realized; "
                f"ONNX Runtime fell back to {realized}."
            )
            raise RuntimeError(msg)
        return

    # Graceful mode: one realized accelerator is enough to stay silent.
    if wanted.isdisjoint(loaded):
        logger.warning(
            "No requested accelerator (%s) was realized; ONNX Runtime fell back to %s. Inference will run on CPU.",
            sorted(wanted),
            realized,
        )
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: MPL-2.0
|
|
4
|
+
|
|
5
|
+
"""Metadata-driven execution-provider selection.
|
|
6
|
+
|
|
7
|
+
Which ONNX Runtime execution provider is fastest — and even which is *usable* —
|
|
8
|
+
depends on both the host (Apple CoreML vs NVIDIA CUDA/TensorRT vs CPU) and on
|
|
9
|
+
properties of the checkpoint itself (precision, whether its graph has static
|
|
10
|
+
shapes). This module keeps that knowledge in **one replaceable component** rather
|
|
11
|
+
than scattered platform ``if``-ladders:
|
|
12
|
+
|
|
13
|
+
* :class:`HostCapabilities` carries the host facts as injectable data, with a
|
|
14
|
+
single impure :meth:`HostCapabilities.detect` classmethod.
|
|
15
|
+
* :class:`ProviderPolicy` is the port; :class:`DefaultProviderPolicy` is the
|
|
16
|
+
adapter that maps ``(checkpoint, host)`` to an ordered provider chain. Users
|
|
17
|
+
with exotic hardware implement their own policy.
|
|
18
|
+
|
|
19
|
+
Importing this module requires ONNX Runtime (the ``[cpu]`` or ``[gpu]`` extra)
|
|
20
|
+
and raises :class:`MissingExtraError` if it is missing.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import platform
|
|
24
|
+
from typing import Literal, Protocol, Self
|
|
25
|
+
|
|
26
|
+
from pydantic import Field
|
|
27
|
+
|
|
28
|
+
from openstef_core.base_model import BaseConfig
|
|
29
|
+
from openstef_core.exceptions import MissingExtraError
|
|
30
|
+
from openstef_foundation_models.inference.providers import (
|
|
31
|
+
CoreMLProvider,
|
|
32
|
+
CpuProvider,
|
|
33
|
+
CudaProvider,
|
|
34
|
+
ExecutionProvider,
|
|
35
|
+
)
|
|
36
|
+
from openstef_foundation_models.models.checkpoint import CheckpointMetadata
|
|
37
|
+
|
|
38
|
+
try:
|
|
39
|
+
import onnxruntime as ort
|
|
40
|
+
except ImportError as e:
|
|
41
|
+
raise MissingExtraError("onnxruntime", "openstef-foundation-models", install_extra="cpu") from e
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class HostCapabilities(BaseConfig):
    """Host facts relevant to provider selection, held as plain injectable data.

    A policy receives these facts as an argument instead of calling
    ``platform.system()`` itself, which keeps selection a pure function of its
    inputs and lets tests construct a fake host.
    """

    # Same settings as BaseConfig, with instances made immutable.
    model_config = {**BaseConfig.model_config, "frozen": True}

    platform: str = Field(description="OS identifier, lower-cased (e.g. 'darwin', 'linux', 'windows').")
    available_providers: frozenset[str] = Field(
        description="Execution provider names ONNX Runtime reports as available on this host.",
    )

    @classmethod
    def detect(cls) -> Self:
        """Build an instance from the running platform and ONNX Runtime.

        This is the only impure step of the selection path; keeping it here lets
        the policy remain a pure function of injected facts.

        Returns:
            The detected host capabilities.
        """
        os_name = platform.system().lower()
        provider_names = frozenset(ort.get_available_providers())
        return cls(platform=os_name, available_providers=provider_names)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class ProviderPolicy(Protocol):
    """Port that turns a checkpoint and a host into an ordered provider chain.

    Provide your own implementation to encode selection rules for hardware the
    default policy does not cover, and hand it to the backend or to
    :class:`~openstef_foundation_models.presets.forecasting_workflow.OnnxBackendConfig`.

    Enforcement depends on where the chain came from. A chain chosen by a policy
    is enforced *gracefully*: ONNX Runtime silently drops accelerators it cannot
    initialize and falls back to CPU, and a chain such as ``[CoreML, CPU]``
    realizing CoreML is the intended outcome, so a warning is logged only when
    execution falls all the way to CPU. A chain passed explicitly by the caller
    is enforced *strictly*: any requested accelerator that is not realized raises.
    """

    def select(self, metadata: CheckpointMetadata, host: HostCapabilities) -> list[ExecutionProvider]:
        """Return the provider chain to try, in order, for *metadata* on *host*."""
        ...
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class DefaultProviderPolicy(BaseConfig):
    """Built-in policy deriving a provider chain from checkpoint precision/shape and host.

    Every rule reflects a *measured* hardware conclusion; the rationale is in the
    design doc ``design-docs/0001`` and the provider benchmark. The returned chain
    lists the preferred provider first and always ends with CPU.
    """

    kind: Literal["default"] = Field(default="default", description="Discriminator tag for the policy type.")

    # Kept as an instance method so the class satisfies ProviderPolicy; it reads no state.
    def select(self, metadata: CheckpointMetadata, host: HostCapabilities) -> list[ExecutionProvider]:
        """Choose the ordered provider chain for *metadata* on *host*.

        Args:
            metadata: The checkpoint's metadata (precision, static-shape-ness).
            host: The detected host capabilities.

        Returns:
            An ordered execution-provider chain, preferred-first, CPU last.
        """
        available = host.available_providers

        # CoreML is only considered on macOS for a static-shape graph that is not
        # int8: int8 (QDQ) runs fast on CPU and CoreML cannot accelerate the
        # quantized ops, so it is skipped for those checkpoints.
        prefer_coreml = (
            metadata.precision != "int8"
            and host.platform == "darwin"
            and "CoreMLExecutionProvider" in available
            and metadata.static_shapes
        )

        chain: list[ExecutionProvider] = []
        if prefer_coreml:
            # GPU only: MLComputeUnits=ALL/ANE triggers a multi-minute
            # Neural-Engine compile for no inference win (measured).
            chain.append(CoreMLProvider(compute_units="CPUAndGPU"))
        elif "CUDAExecutionProvider" in available:
            # NVIDIA, including int8. TensorRT stays opt-in (engine-build cost and
            # fp16 caveats), so this policy never selects it.
            chain.append(CudaProvider())
        chain.append(CpuProvider())
        return chain
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
__all__ = [
|
|
139
|
+
"DefaultProviderPolicy",
|
|
140
|
+
"HostCapabilities",
|
|
141
|
+
"ProviderPolicy",
|
|
142
|
+
]
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: MPL-2.0
|
|
4
|
+
|
|
5
|
+
"""Typed ONNX Runtime execution-provider configuration.
|
|
6
|
+
|
|
7
|
+
Each provider is a small pydantic config that compiles to an ONNX Runtime
|
|
8
|
+
``(name, options)`` tuple via :meth:`ExecutionProviderConfig.to_ort`. Keeping
|
|
9
|
+
providers as typed configs (rather than raw strings) lets users opt into
|
|
10
|
+
hardware acceleration — CUDA, TensorRT FP16, CoreML/ANE — without touching
|
|
11
|
+
model code, and keeps the options validated and discoverable.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import Annotated, Literal
|
|
16
|
+
|
|
17
|
+
from pydantic import Field
|
|
18
|
+
|
|
19
|
+
from openstef_core.base_model import BaseConfig
|
|
20
|
+
|
|
21
|
+
#: An ONNX Runtime provider specification: ``(provider_name, provider_options)``.
|
|
22
|
+
type OrtProvider = tuple[str, dict[str, object]]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CpuProvider(BaseConfig):
    """Configuration for ONNX Runtime's CPU execution provider."""

    kind: Literal["cpu"] = Field(default="cpu", description="Discriminator tag for execution-provider type.")

    def to_ort(self) -> OrtProvider:
        """Build the ONNX Runtime provider tuple for this config.

        Returns:
            ``CPUExecutionProvider`` paired with an empty options dict.
        """
        no_options: dict[str, object] = {}
        return ("CPUExecutionProvider", no_options)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class CudaProvider(BaseConfig):
    """Configuration for the CUDA (NVIDIA GPU) execution provider."""

    kind: Literal["cuda"] = Field(default="cuda", description="Discriminator tag for execution-provider type.")
    device_id: int = Field(default=0, ge=0, description="CUDA device index to run on.")

    def to_ort(self) -> OrtProvider:
        """Build the ONNX Runtime provider tuple for this config.

        Returns:
            ``CUDAExecutionProvider`` paired with the configured device id.
        """
        cuda_options: dict[str, object] = {"device_id": self.device_id}
        return ("CUDAExecutionProvider", cuda_options)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class TensorRTProvider(BaseConfig):
    """Configuration for the TensorRT execution provider (NVIDIA, ahead-of-time engine build).

    The recommended production setup on NVIDIA hardware is FP16 with a persistent
    engine cache: the engine is built on the first run and loaded from the cache
    on later runs.
    """

    kind: Literal["tensorrt"] = Field(default="tensorrt", description="Discriminator tag for execution-provider type.")
    device_id: int = Field(default=0, ge=0, description="CUDA device index to run on.")
    fp16: bool = Field(default=True, description="Enable FP16 precision for faster inference.")
    engine_cache_dir: Path | None = Field(
        default=None,
        description="Directory to persist built TensorRT engines. When set, engine caching is enabled.",
    )

    def to_ort(self) -> OrtProvider:
        """Build the ONNX Runtime provider tuple for this config.

        Returns:
            ``TensorrtExecutionProvider`` paired with the device, precision and,
            when a cache directory is set, the engine-cache options.
        """
        cache_dir = self.engine_cache_dir
        # The cache keys are only emitted when a directory was configured.
        cache_options: dict[str, object] = (
            {}
            if cache_dir is None
            else {"trt_engine_cache_enable": True, "trt_engine_cache_path": str(cache_dir)}
        )
        trt_options: dict[str, object] = {
            "device_id": self.device_id,
            "trt_fp16_enable": self.fp16,
            **cache_options,
        }
        return ("TensorrtExecutionProvider", trt_options)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class CoreMLProvider(BaseConfig):
    """Configuration for the CoreML execution provider (Apple, GPU/Neural Engine).

    Modern CoreML needs ``ModelFormat=MLProgram``; with the legacy
    ``NeuralNetwork`` format the graph is fragmented and many ops silently fall
    back to CPU.
    """

    kind: Literal["coreml"] = Field(default="coreml", description="Discriminator tag for execution-provider type.")
    model_format: Literal["MLProgram", "NeuralNetwork"] = Field(
        default="MLProgram",
        description="CoreML model format. MLProgram is required for modern op coverage.",
    )
    compute_units: Literal["CPUOnly", "CPUAndGPU", "CPUAndNeuralEngine", "ALL"] = Field(
        default="ALL",
        description="Which compute units CoreML may dispatch to. Prefer 'CPUAndGPU' for large transformer "
        "graphs: allowing the Neural Engine ('ALL'/'CPUAndNeuralEngine') can make CoreML's ahead-of-time "
        "compile run for many minutes for no inference win.",
    )
    cache_dir: Path | None = Field(
        default=None,
        description="Directory for CoreML's compiled-model cache (ORT 'ModelCacheDirectory'). CoreML compiles "
        "the graph ahead of time on session build, which is slow; caching it cuts a warm rebuild from tens of "
        "seconds to a few. The cache is ORT-version/OS/hardware-specific — a local speedup, not a portable "
        "artifact. Requires 'MLProgram' format.",
    )

    def to_ort(self) -> OrtProvider:
        """Build the ONNX Runtime provider tuple for this config.

        Returns:
            ``CoreMLExecutionProvider`` paired with the format and compute-unit
            options, plus the cache directory when one is set.
        """
        coreml_options: dict[str, object] = {
            "ModelFormat": self.model_format,
            "MLComputeUnits": self.compute_units,
        }
        cache_location = self.cache_dir
        if cache_location is not None:
            # ONNX Runtime takes the cache directory as a string option.
            coreml_options["ModelCacheDirectory"] = str(cache_location)
        return ("CoreMLExecutionProvider", coreml_options)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
#: An execution-provider config, discriminated by its ``kind`` tag.
|
|
126
|
+
ExecutionProvider = Annotated[
|
|
127
|
+
CpuProvider | CudaProvider | TensorRTProvider | CoreMLProvider,
|
|
128
|
+
Field(discriminator="kind"),
|
|
129
|
+
]
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class SessionOptionsConfig(BaseConfig):
    """Typed configuration for a subset of ONNX Runtime ``SessionOptions``."""

    # Stored without the ``ORT_`` prefix of ONNX Runtime's enum members.
    graph_optimization_level: Literal["DISABLE_ALL", "ENABLE_BASIC", "ENABLE_EXTENDED", "ENABLE_ALL"] = Field(
        default="ENABLE_ALL",
        description="Graph optimization level applied when loading the model.",
    )
    # Both thread counts are optional; ``None`` keeps the ONNX Runtime default.
    intra_op_num_threads: int | None = Field(
        default=None,
        ge=0,
        description="Threads used within a single operator. None uses the ONNX Runtime default.",
    )
    inter_op_num_threads: int | None = Field(
        default=None,
        ge=0,
        description="Threads used across operators. None uses the ONNX Runtime default.",
    )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
__all__ = [
|
|
152
|
+
"CoreMLProvider",
|
|
153
|
+
"CpuProvider",
|
|
154
|
+
"CudaProvider",
|
|
155
|
+
"ExecutionProvider",
|
|
156
|
+
"SessionOptionsConfig",
|
|
157
|
+
"TensorRTProvider",
|
|
158
|
+
]
|