openstef-foundation-models 4.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+ """Foundation model support for OpenSTEF."""
5
+
6
+ import logging
7
+
8
# Library convention: give the package logger a NullHandler so its records are not
# routed to logging's last-resort stderr handler when the application has not
# configured logging. The guard keeps a re-executed module from stacking handlers.
root_logger = logging.getLogger(__name__)
if not root_logger.handlers:
    root_logger.addHandler(logging.NullHandler())

# The package root re-exports nothing; public names live in the submodules.
__all__: list[str] = []
@@ -0,0 +1,32 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Inference backends for foundation-model forecasters.
6
+
7
+ An :class:`InferenceBackend` isolates how a checkpoint is executed behind a
8
+ single named-tensor contract; forecasters compose a backend rather than
9
+ inheriting execution behaviour. Only the dependency-free surface (the protocol
10
+ and the execution-provider configs) is re-exported here; a concrete backend
11
+ lives in its own submodule and imports its heavy runtime at module top level.
12
+ """
13
+
14
+ from openstef_foundation_models.inference.backend import InferenceBackend
15
+ from openstef_foundation_models.inference.providers import (
16
+ CoreMLProvider,
17
+ CpuProvider,
18
+ CudaProvider,
19
+ ExecutionProvider,
20
+ SessionOptionsConfig,
21
+ TensorRTProvider,
22
+ )
23
+
24
# Public surface of the ``inference`` package: only the dependency-free protocol and
# the provider configs. Concrete backends are imported from their own submodules.
__all__ = [
    "CoreMLProvider",
    "CpuProvider",
    "CudaProvider",
    "ExecutionProvider",
    "InferenceBackend",
    "SessionOptionsConfig",
    "TensorRTProvider",
]
@@ -0,0 +1,43 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """The :class:`InferenceBackend` contract shared by all execution backends."""
6
+
7
+ from collections.abc import Mapping
8
+ from typing import Protocol, runtime_checkable
9
+
10
+ import numpy as np
11
+
12
+ from openstef_foundation_models.models.checkpoint import CheckpointMetadata
13
+
14
+
15
@runtime_checkable
class InferenceBackend(Protocol):
    """Structural contract for anything that can execute a checkpoint.

    An implementation maps named input arrays to named output arrays. It owns
    whatever runtime state it needs (for example an ONNX Runtime session) and is
    meant to be built once and then reused across an entire backtest. Details of
    the model family are exposed through :attr:`metadata`, not baked into the
    backend.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    InferenceBackend)`` only checks that the members exist, not their signatures.
    """

    @property
    def metadata(self) -> CheckpointMetadata:
        """The metadata of the checkpoint this backend executes."""

    def run(self, inputs: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
        """Run the model once on a batch of named input arrays.

        Args:
            inputs: Input arrays keyed by name; keys must match
                ``metadata.input_names``.

        Returns:
            Output arrays keyed by name, including ``metadata.output_name``.
        """

    def close(self) -> None:
        """Free the runtime resources held by this backend."""
@@ -0,0 +1,237 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """ONNX Runtime execution backend.
6
+
7
+ Importing this module requires ONNX Runtime (the ``[cpu]`` or ``[gpu]`` extra)
8
+ and raises :class:`MissingExtraError` if it is missing.
9
+ """
10
+
11
+ import logging
12
+ from collections.abc import Mapping, Sequence
13
+ from typing import Self
14
+
15
+ import numpy as np
16
+
17
+ from openstef_core.exceptions import MissingExtraError
18
+ from openstef_foundation_models.inference.provider_selection import (
19
+ DefaultProviderPolicy,
20
+ HostCapabilities,
21
+ ProviderPolicy,
22
+ )
23
+ from openstef_foundation_models.inference.providers import (
24
+ ExecutionProvider,
25
+ SessionOptionsConfig,
26
+ )
27
+ from openstef_foundation_models.models.checkpoint import CheckpointMetadata, ResolvedCheckpoint
28
+
29
+ try:
30
+ import onnxruntime as ort
31
+ except ImportError as e:
32
+ raise MissingExtraError("onnxruntime", "openstef-foundation-models", install_extra="cpu") from e
33
+
34
# onnxruntime-gpu ships the CUDA execution-provider plugin but loads the CUDA/cuDNN
# runtime (the nvidia-*-cu12 wheels pulled in by the [gpu] extra) lazily at session
# creation. ``preload_dlls()`` loads them from the nvidia site-packages up front, so
# the CUDA provider can be realized without a system CUDA install or
# LD_LIBRARY_PATH. It is a no-op on the CPU runtime; the attribute lookup covers
# onnxruntime < 1.21, which predates the API.
_preload_dlls = getattr(ort, "preload_dlls", None)
if _preload_dlls is not None:
    _preload_dlls()
del _preload_dlls

logger = logging.getLogger(__name__)
43
+
44
+
45
class OnnxBackend:
    """An :class:`~openstef_foundation_models.inference.backend.InferenceBackend` backed by ONNX Runtime.

    The session is built once on construction and reused for every
    :meth:`run` call, so a single backend instance can be shared across an
    entire backtest. Users may either let the backend build a session from a
    resolved checkpoint and provider chain, or pass a pre-built session they own.

    The backend is also a context manager: leaving a ``with`` block calls
    :meth:`close`, so the session reference is dropped deterministically.
    """

    def __init__(
        self,
        metadata: CheckpointMetadata,
        session: ort.InferenceSession,
    ) -> None:
        """Wrap a pre-built ONNX Runtime session.

        Prefer :meth:`from_checkpoint` unless you need to own the session
        lifecycle yourself.

        Args:
            metadata: Metadata describing the checkpoint the session executes.
            session: A pre-built ONNX Runtime inference session.
        """
        self._metadata = metadata
        # ``None`` marks the backend as closed; see :meth:`close`.
        self._session: ort.InferenceSession | None = session

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: ResolvedCheckpoint,
        providers: Sequence[ExecutionProvider] | None = None,
        session_options: SessionOptionsConfig | None = None,
        *,
        policy: ProviderPolicy | None = None,
    ) -> Self:
        """Build a backend by loading a checkpoint into a new ONNX Runtime session.

        With ``providers=None``, *policy* selects a chain from the checkpoint
        and the host. An explicit ``providers`` list is used as given, and
        *policy* is ignored. See
        :class:`~openstef_foundation_models.inference.provider_selection.ProviderPolicy`
        for how a chain is chosen and how strictly its realization is enforced.

        Args:
            checkpoint: The resolved checkpoint (weights + metadata) to load.
            providers: Ordered execution providers to try. ``None`` lets the policy
                pick a host-appropriate chain from the checkpoint metadata.
            session_options: Optional ONNX Runtime session options.
            policy: Selection policy used when ``providers is None``. Defaults to
                :class:`DefaultProviderPolicy`.

        Returns:
            A backend wrapping the newly built session.

        Raises:
            RuntimeError: If an explicit ``providers`` chain requests an
                accelerator that ONNX Runtime did not realize.
        """
        metadata = checkpoint.metadata
        if providers is not None:
            provider_configs = list(providers)
            explicit = True
            logger.info(
                "Using explicit execution-provider chain %s for checkpoint '%s'.",
                [config.to_ort()[0] for config in provider_configs],
                metadata.model_family,
            )
        else:
            selector = policy or DefaultProviderPolicy()
            host = HostCapabilities.detect()
            provider_configs = selector.select(metadata, host)
            explicit = False
            logger.debug(
                "Detected host: platform=%s, available_providers=%s",
                host.platform,
                sorted(host.available_providers),
            )
            logger.info(
                "%s selected execution-provider chain %s for checkpoint '%s' (precision=%s, static_shapes=%s) on %s.",
                type(selector).__name__,
                [config.to_ort()[0] for config in provider_configs],
                metadata.model_family,
                metadata.precision,
                metadata.static_shapes,
                host.platform,
            )
        ort_providers = [config.to_ort() for config in provider_configs]
        # The question is "were options supplied?", not the truthiness of a
        # pydantic model, so compare against None explicitly.
        so = _build_session_options(session_options) if session_options is not None else None

        session = ort.InferenceSession(
            str(checkpoint.weights_path),
            sess_options=so,
            providers=ort_providers,
        )
        realized = session.get_providers()
        logger.info("ONNX Runtime session built; realized providers: %s.", realized)
        # Explicit chains are enforced strictly (any missing accelerator raises).
        # Policy-selected chains only warn on a full fallback to CPU.
        _check_provider_fallback(
            requested=provider_configs,
            realized=realized,
            strict=explicit,
        )
        return cls(metadata=metadata, session=session)

    @property
    def metadata(self) -> CheckpointMetadata:
        """Metadata describing the checkpoint this backend executes."""
        return self._metadata

    def run(self, inputs: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
        """Execute the ONNX model on a batch of named input tensors.

        Args:
            inputs: Named input tensors. Keys must match ``metadata.input_names``.

        Returns:
            Named output tensors keyed by the model's output names. Each value is
            passed through ``np.asarray``.

        Raises:
            RuntimeError: If the backend has been closed.
        """
        if self._session is None:
            msg = "OnnxBackend has been closed."
            raise RuntimeError(msg)
        output_names = [out.name for out in self._session.get_outputs()]
        results = self._session.run(output_names, dict(inputs))
        return {name: np.asarray(result) for name, result in zip(output_names, results, strict=True)}

    def close(self) -> None:
        """Release the underlying ONNX Runtime session.

        ONNX Runtime frees native resources on garbage collection, so dropping
        the reference is the supported way to release them. Calling this more
        than once is harmless. After closing, :meth:`run` raises ``RuntimeError``.
        A session the caller passed to ``__init__`` stays alive while the caller
        still references it.
        """
        self._session = None

    def __enter__(self) -> Self:
        """Return the backend itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the backend on leaving the ``with`` block; exceptions propagate."""
        self.close()
173
+
174
+
175
def _build_session_options(config: SessionOptionsConfig) -> ort.SessionOptions:
    """Convert a typed :class:`SessionOptionsConfig` into ONNX Runtime session options.

    Args:
        config: The session-options configuration to translate.

    Returns:
        A new ``ort.SessionOptions`` carrying the configured graph optimization
        level and any thread counts that were explicitly set.
    """
    # The config stores the level name without ORT's "ORT_" enum prefix.
    level_name = "ORT_" + config.graph_optimization_level
    options = ort.SessionOptions()
    options.graph_optimization_level = getattr(ort.GraphOptimizationLevel, level_name)

    # Apply only the thread counts the caller set; None keeps ONNX Runtime's default.
    thread_settings = {
        "intra_op_num_threads": config.intra_op_num_threads,
        "inter_op_num_threads": config.inter_op_num_threads,
    }
    for attribute, count in thread_settings.items():
        if count is not None:
            setattr(options, attribute, count)
    return options
194
+
195
+
196
def _check_provider_fallback(
    requested: Sequence[ExecutionProvider],
    realized: Sequence[str],
    *,
    strict: bool,
) -> None:
    """Report when ONNX Runtime did not realize the requested accelerators.

    ONNX Runtime can quietly skip providers it fails to initialize, so the
    realized chain may differ from the requested one. This compares the two and
    applies the strict-vs-graceful contract described on
    :class:`~openstef_foundation_models.inference.provider_selection.ProviderPolicy`.

    Args:
        requested: The execution-provider configs that were asked for.
        realized: Provider names the session reports as actually loaded.
        strict: If ``True``, any missing requested accelerator raises. If
            ``False``, a warning is logged only when none of the requested
            accelerators was realized.

    Raises:
        RuntimeError: In strict mode, when at least one requested accelerator
            is absent from ``realized``.
    """
    accelerators = {config.to_ort()[0] for config in requested}
    accelerators.discard("CPUExecutionProvider")
    # A CPU-only request has nothing to fall back from.
    if not accelerators:
        return

    loaded = set(realized)
    if strict:
        missing = accelerators - loaded
        if not missing:
            return
        msg = (
            f"Requested execution provider(s) {sorted(missing)} were not realized; "
            f"ONNX Runtime fell back to {realized}."
        )
        raise RuntimeError(msg)

    # Graceful mode: realizing any one requested accelerator is acceptable.
    if accelerators.isdisjoint(loaded):
        logger.warning(
            "No requested accelerator (%s) was realized; ONNX Runtime fell back to %s. Inference will run on CPU.",
            sorted(accelerators),
            realized,
        )
@@ -0,0 +1,142 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Metadata-driven execution-provider selection.
6
+
7
+ Which ONNX Runtime execution provider is fastest — and even which is *usable* —
8
+ depends on both the host (Apple CoreML vs NVIDIA CUDA/TensorRT vs CPU) and on
9
+ properties of the checkpoint itself (precision, whether its graph has static
10
+ shapes). This module keeps that knowledge in **one replaceable component** rather
11
+ than scattered platform ``if``-ladders:
12
+
13
+ * :class:`HostCapabilities` carries the host facts as injectable data, with a
14
+ single impure :meth:`HostCapabilities.detect` classmethod.
15
+ * :class:`ProviderPolicy` is the port; :class:`DefaultProviderPolicy` is the
16
+ adapter that maps ``(checkpoint, host)`` to an ordered provider chain. Users
17
+ with exotic hardware implement their own policy.
18
+
19
+ Importing this module requires ONNX Runtime (the ``[cpu]`` or ``[gpu]`` extra)
20
+ and raises :class:`MissingExtraError` if it is missing.
21
+ """
22
+
23
+ import platform
24
+ from typing import Literal, Protocol, Self
25
+
26
+ from pydantic import Field
27
+
28
+ from openstef_core.base_model import BaseConfig
29
+ from openstef_core.exceptions import MissingExtraError
30
+ from openstef_foundation_models.inference.providers import (
31
+ CoreMLProvider,
32
+ CpuProvider,
33
+ CudaProvider,
34
+ ExecutionProvider,
35
+ )
36
+ from openstef_foundation_models.models.checkpoint import CheckpointMetadata
37
+
38
+ try:
39
+ import onnxruntime as ort
40
+ except ImportError as e:
41
+ raise MissingExtraError("onnxruntime", "openstef-foundation-models", install_extra="cpu") from e
42
+
43
+
44
class HostCapabilities(BaseConfig):
    """Host facts that matter for execution-provider choice, held as plain data.

    A policy receives these facts rather than calling ``platform.system()`` or
    ONNX Runtime itself. Selection therefore stays a pure function of its
    inputs, and tests can substitute a fabricated host.
    """

    # frozen=True: instances cannot be mutated after construction.
    model_config = BaseConfig.model_config | {"frozen": True}

    platform: str = Field(
        description="OS identifier, lower-cased (e.g. 'darwin', 'linux', 'windows').",
    )
    available_providers: frozenset[str] = Field(
        description="Execution provider names ONNX Runtime reports as available on this host.",
    )

    @classmethod
    def detect(cls) -> Self:
        """Probe the running host through the ``platform`` module and ONNX Runtime.

        This is the only impure step in provider selection. The policy itself
        only consumes the data returned here.

        Returns:
            The capabilities of the current host.
        """
        # Inside a method, ``platform`` resolves to the stdlib module (module
        # globals), not to the ``platform`` field declared in the class body.
        os_name = platform.system().lower()
        providers = frozenset(ort.get_available_providers())
        return cls(platform=os_name, available_providers=providers)
75
+
76
+
77
class ProviderPolicy(Protocol):
    """Strategy that turns a checkpoint and a host into an ordered provider chain.

    Write your own implementation when the default rules do not cover your
    hardware, and pass it to the backend or to
    :class:`~openstef_foundation_models.presets.forecasting_workflow.OnnxBackendConfig`.

    A chain chosen by a policy is enforced *gracefully*. ONNX Runtime silently
    drops accelerators it cannot initialize and falls back to CPU. Realizing, say,
    CoreML out of a ``[CoreML, CPU]`` chain is the intended outcome, so a warning
    is logged only when execution ends up entirely on CPU. A chain the caller
    passes explicitly is enforced *strictly*: any requested accelerator that is
    not realized raises.
    """

    def select(self, metadata: CheckpointMetadata, host: HostCapabilities) -> list[ExecutionProvider]:
        """Return the ordered provider chain to try for *metadata* on *host*."""
95
+
96
+
97
class DefaultProviderPolicy(BaseConfig):
    """Built-in policy that picks a provider chain from checkpoint precision/shape and host.

    Every rule encodes a *measured* hardware conclusion; see the design doc
    ``design-docs/0001`` and the provider benchmark for the rationale. Chains
    list the preferred provider first and always end with CPU as the fallback.
    """

    kind: Literal["default"] = Field(default="default", description="Discriminator tag for the policy type.")

    def select(self, metadata: CheckpointMetadata, host: HostCapabilities) -> list[ExecutionProvider]:
        """Choose an ordered provider chain for *metadata* on *host*.

        The policy is stateless. ``select`` is an instance method only so the
        class satisfies the :class:`ProviderPolicy` protocol.

        Args:
            metadata: The checkpoint's metadata (precision, static-shape-ness).
            host: The detected host capabilities.

        Returns:
            An ordered execution-provider chain, preferred first, CPU last.
        """
        available = host.available_providers
        has_cuda = "CUDAExecutionProvider" in available

        if metadata.precision == "int8":
            # Quantized (QDQ) graphs run fast on CPU. CoreML cannot accelerate the
            # quantized ops, so it is never considered. CUDA handles int8 fine
            # when a GPU is present.
            preferred: list[ExecutionProvider] = [CudaProvider()] if has_cuda else []
        elif host.platform == "darwin" and "CoreMLExecutionProvider" in available and metadata.static_shapes:
            # Static-shape fp16/fp32 graphs run on CoreML, restricted to CPU+GPU.
            # Allowing the Neural Engine (MLComputeUnits=ALL/ANE) triggers a
            # multi-minute compile for no measured inference win.
            preferred = [CoreMLProvider(compute_units="CPUAndGPU")]
        elif has_cuda:
            # NVIDIA: CUDA first. TensorRT stays opt-in (engine-build cost and
            # fp16 caveats), so the default never selects it.
            preferred = [CudaProvider()]
        else:
            preferred = []
        return [*preferred, CpuProvider()]
136
+
137
+
138
# Public API of the provider-selection module: the host facts, the policy port
# and its default implementation.
__all__ = [
    "DefaultProviderPolicy",
    "HostCapabilities",
    "ProviderPolicy",
]
@@ -0,0 +1,158 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Typed ONNX Runtime execution-provider configuration.
6
+
7
+ Each provider is a small pydantic config that compiles to an ONNX Runtime
8
+ ``(name, options)`` tuple via its ``to_ort`` method. Keeping
9
+ providers as typed configs (rather than raw strings) lets users opt into
10
+ hardware acceleration — CUDA, TensorRT FP16, CoreML/ANE — without touching
11
+ model code, and keeps the options validated and discoverable.
12
+ """
13
+
14
+ from pathlib import Path
15
+ from typing import Annotated, Literal
16
+
17
+ from pydantic import Field
18
+
19
+ from openstef_core.base_model import BaseConfig
20
+
21
+ #: An ONNX Runtime provider specification: ``(provider_name, provider_options)``.
22
+ type OrtProvider = tuple[str, dict[str, object]]
23
+
24
+
25
+ class CpuProvider(BaseConfig):
26
+ """The default CPU execution provider."""
27
+
28
+ kind: Literal["cpu"] = Field(default="cpu", description="Discriminator tag for execution-provider type.")
29
+
30
+ def to_ort(self) -> OrtProvider:
31
+ """Compile to an ONNX Runtime provider tuple.
32
+
33
+ Returns:
34
+ The ``CPUExecutionProvider`` with no options.
35
+ """
36
+ return ("CPUExecutionProvider", {})
37
+
38
+
39
class CudaProvider(BaseConfig):
    """ONNX Runtime's CUDA execution provider for NVIDIA GPUs."""

    kind: Literal["cuda"] = Field(description="Discriminator tag for execution-provider type.", default="cuda")
    device_id: int = Field(description="CUDA device index to run on.", default=0, ge=0)

    def to_ort(self) -> OrtProvider:
        """Build the ONNX Runtime ``(name, options)`` tuple for this provider.

        Returns:
            ``"CUDAExecutionProvider"`` with the configured ``device_id`` option.
        """
        options: dict[str, object] = {"device_id": self.device_id}
        return "CUDAExecutionProvider", options
52
+
53
+
54
class TensorRTProvider(BaseConfig):
    """ONNX Runtime's TensorRT execution provider (NVIDIA, ahead-of-time engine build).

    FP16 combined with a persistent engine cache is the recommended production
    path on NVIDIA hardware. The first run pays the engine-build cost, and later
    runs load the cached engine.
    """

    kind: Literal["tensorrt"] = Field(description="Discriminator tag for execution-provider type.", default="tensorrt")
    device_id: int = Field(description="CUDA device index to run on.", default=0, ge=0)
    fp16: bool = Field(description="Enable FP16 precision for faster inference.", default=True)
    engine_cache_dir: Path | None = Field(
        description="Directory to persist built TensorRT engines. When set, engine caching is enabled.",
        default=None,
    )

    def to_ort(self) -> OrtProvider:
        """Build the ONNX Runtime ``(name, options)`` tuple for this provider.

        Returns:
            ``"TensorrtExecutionProvider"`` with device, precision and, when a cache
            directory is configured, engine-cache options.
        """
        options: dict[str, object] = {"device_id": self.device_id, "trt_fp16_enable": self.fp16}
        # Configuring a cache directory implicitly switches engine caching on.
        if self.engine_cache_dir is not None:
            options |= {
                "trt_engine_cache_enable": True,
                "trt_engine_cache_path": str(self.engine_cache_dir),
            }
        return "TensorrtExecutionProvider", options
84
+
85
+
86
class CoreMLProvider(BaseConfig):
    """ONNX Runtime's CoreML execution provider (Apple GPU / Neural Engine).

    Modern CoreML needs ``ModelFormat=MLProgram``. The legacy ``NeuralNetwork``
    format fragments the graph and silently falls back to CPU for many ops.
    """

    kind: Literal["coreml"] = Field(description="Discriminator tag for execution-provider type.", default="coreml")
    model_format: Literal["MLProgram", "NeuralNetwork"] = Field(
        description="CoreML model format. MLProgram is required for modern op coverage.",
        default="MLProgram",
    )
    compute_units: Literal["CPUOnly", "CPUAndGPU", "CPUAndNeuralEngine", "ALL"] = Field(
        description="Which compute units CoreML may dispatch to. Prefer 'CPUAndGPU' for large transformer "
        "graphs: allowing the Neural Engine ('ALL'/'CPUAndNeuralEngine') can make CoreML's ahead-of-time "
        "compile run for many minutes for no inference win.",
        default="ALL",
    )
    cache_dir: Path | None = Field(
        description="Directory for CoreML's compiled-model cache (ORT 'ModelCacheDirectory'). CoreML compiles "
        "the graph ahead of time on session build, which is slow; caching it cuts a warm rebuild from tens of "
        "seconds to a few. The cache is ORT-version/OS/hardware-specific — a local speedup, not a portable "
        "artifact. Requires 'MLProgram' format.",
        default=None,
    )

    def to_ort(self) -> OrtProvider:
        """Build the ONNX Runtime ``(name, options)`` tuple for this provider.

        Returns:
            ``"CoreMLExecutionProvider"`` with format and compute-unit options,
            plus ``ModelCacheDirectory`` when a cache directory is configured.
        """
        cache_option = {} if self.cache_dir is None else {"ModelCacheDirectory": str(self.cache_dir)}
        options: dict[str, object] = {
            "ModelFormat": self.model_format,
            "MLComputeUnits": self.compute_units,
            **cache_option,
        }
        return "CoreMLExecutionProvider", options
123
+
124
+
125
#: An execution-provider config, discriminated by its ``kind`` tag.
#: When validating, pydantic reads each member's ``kind`` literal to pick the
#: concrete config class.
ExecutionProvider = Annotated[
    CpuProvider | CudaProvider | TensorRTProvider | CoreMLProvider,
    Field(discriminator="kind"),
]
130
+
131
+
132
class SessionOptionsConfig(BaseConfig):
    """Typed subset of ONNX Runtime ``SessionOptions``.

    Thread counts left at ``None`` keep ONNX Runtime's own defaults.
    """

    # Values are ``ort.GraphOptimizationLevel`` member names without the "ORT_" prefix.
    graph_optimization_level: Literal["DISABLE_ALL", "ENABLE_BASIC", "ENABLE_EXTENDED", "ENABLE_ALL"] = Field(
        description="Graph optimization level applied when loading the model.",
        default="ENABLE_ALL",
    )
    intra_op_num_threads: int | None = Field(
        description="Threads used within a single operator. None uses the ONNX Runtime default.",
        default=None,
        ge=0,
    )
    inter_op_num_threads: int | None = Field(
        description="Threads used across operators. None uses the ONNX Runtime default.",
        default=None,
        ge=0,
    )
149
+
150
+
151
# Public API of the providers module: the per-provider configs, their
# discriminated union, and the session-options config.
__all__ = [
    "CoreMLProvider",
    "CpuProvider",
    "CudaProvider",
    "ExecutionProvider",
    "SessionOptionsConfig",
    "TensorRTProvider",
]
@@ -0,0 +1,5 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Integrations with neighbouring OpenSTEF frameworks, one module per framework."""