openstef-foundation-models 4.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. openstef_foundation_models-4.1.0/.gitignore +147 -0
  2. openstef_foundation_models-4.1.0/PKG-INFO +89 -0
  3. openstef_foundation_models-4.1.0/README.md +56 -0
  4. openstef_foundation_models-4.1.0/pyproject.toml +75 -0
  5. openstef_foundation_models-4.1.0/src/openstef_foundation_models/__init__.py +12 -0
  6. openstef_foundation_models-4.1.0/src/openstef_foundation_models/inference/__init__.py +32 -0
  7. openstef_foundation_models-4.1.0/src/openstef_foundation_models/inference/backend.py +43 -0
  8. openstef_foundation_models-4.1.0/src/openstef_foundation_models/inference/onnx_backend.py +237 -0
  9. openstef_foundation_models-4.1.0/src/openstef_foundation_models/inference/provider_selection.py +142 -0
  10. openstef_foundation_models-4.1.0/src/openstef_foundation_models/inference/providers.py +158 -0
  11. openstef_foundation_models-4.1.0/src/openstef_foundation_models/integrations/__init__.py +5 -0
  12. openstef_foundation_models-4.1.0/src/openstef_foundation_models/integrations/beam.py +203 -0
  13. openstef_foundation_models-4.1.0/src/openstef_foundation_models/models/__init__.py +17 -0
  14. openstef_foundation_models-4.1.0/src/openstef_foundation_models/models/catalog.py +105 -0
  15. openstef_foundation_models-4.1.0/src/openstef_foundation_models/models/checkpoint.py +212 -0
  16. openstef_foundation_models-4.1.0/src/openstef_foundation_models/models/forecasting/__init__.py +21 -0
  17. openstef_foundation_models-4.1.0/src/openstef_foundation_models/models/forecasting/chronos2_forecaster.py +283 -0
  18. openstef_foundation_models-4.1.0/src/openstef_foundation_models/presets/__init__.py +23 -0
  19. openstef_foundation_models-4.1.0/src/openstef_foundation_models/presets/forecasting_workflow.py +205 -0
  20. openstef_foundation_models-4.1.0/tests/__init__.py +3 -0
  21. openstef_foundation_models-4.1.0/tests/integration/__init__.py +3 -0
  22. openstef_foundation_models-4.1.0/tests/integration/conftest.py +60 -0
  23. openstef_foundation_models-4.1.0/tests/integration/test_beam_integration.py +77 -0
  24. openstef_foundation_models-4.1.0/tests/integration/test_chronos2_onnx.py +218 -0
  25. openstef_foundation_models-4.1.0/tests/unit/__init__.py +3 -0
  26. openstef_foundation_models-4.1.0/tests/unit/inference/__init__.py +3 -0
  27. openstef_foundation_models-4.1.0/tests/unit/inference/test_onnx_backend.py +105 -0
  28. openstef_foundation_models-4.1.0/tests/unit/inference/test_provider_selection.py +98 -0
  29. openstef_foundation_models-4.1.0/tests/unit/inference/test_providers.py +49 -0
  30. openstef_foundation_models-4.1.0/tests/unit/integrations/__init__.py +3 -0
  31. openstef_foundation_models-4.1.0/tests/unit/integrations/test_beam.py +442 -0
  32. openstef_foundation_models-4.1.0/tests/unit/models/__init__.py +3 -0
  33. openstef_foundation_models-4.1.0/tests/unit/models/forecasting/__init__.py +3 -0
  34. openstef_foundation_models-4.1.0/tests/unit/models/forecasting/test_chronos2_forecaster.py +521 -0
  35. openstef_foundation_models-4.1.0/tests/unit/models/test_catalog.py +38 -0
  36. openstef_foundation_models-4.1.0/tests/unit/presets/__init__.py +3 -0
  37. openstef_foundation_models-4.1.0/tests/unit/presets/test_forecasting_workflow.py +196 -0
@@ -0,0 +1,147 @@
1
+ # SPDX-FileCopyrightText: 2017-2025 Contributors to the OpenSTEF project <openstef@lfenergy.org> # noqa E501
2
+ # SPDX-License-Identifier: MPL-2.0
3
+
4
+ # Core
5
+ config.user.yaml
6
+ git-template
7
+ .DS_Store
8
+ tmp
9
+
10
+ # Python bytecode
11
+ __pycache__/
12
+ *.py[cod]
13
+ *$py.class
14
+
15
+ # C extensions
16
+ *.so
17
+
18
+ # Packaging and build artifacts
19
+ .Python
20
+ build/
21
+ dist/
22
+ downloads/
23
+ wheels/
24
+ share/python-wheels/
25
+ sdist/
26
+ .eggs/
27
+ *.egg-info/
28
+ *.egg
29
+ .installed.cfg
30
+ MANIFEST
31
+
32
+ # uv
33
+ .uv/
34
+
35
+ # Ruff
36
+ .ruff_cache/
37
+
38
+ # Test, coverage, tox
39
+ .pytest_cache/
40
+ .coverage
41
+ .coverage.*
42
+ htmlcov/
43
+ .cover/
44
+ cover/
45
+ pytest-report.xml
46
+ .tox/
47
+ .nox/
48
+ coverage.xml
49
+
50
+ # Hypothesis
51
+ .hypothesis/
52
+
53
+ # Cython
54
+ cython_debug/
55
+
56
+ # Type checkers (misc)
57
+ .mypy_cache/
58
+ .dmypy.json
59
+ dmypy.json
60
+ .pytype/
61
+ .pyre/
62
+
63
+ # Sphinx
64
+ docs/_build/
65
+ docs/source/api/generated/
66
+ docs/source/tutorials/
67
+ docs/source/benchmarks/
68
+ # Community health files materialized from OpenSTEF/.github at build time
69
+ docs/source/contribute/_community/
70
+ docs/source/user_guide/**/quick_start_tutorial.py
71
+ docs/source/user_guide/**/feature_engineering_tutorial.py
72
+ docs/source/user_guide/**/datasets_tutorial.py
73
+ docs/source/user_guide/**/backtesting_tutorial.py
74
+
75
+ # docs/_doctrees/
76
+ # docs/_static_gen/
77
+
78
+ # Editors: JetBrains
79
+ .idea/
80
+
81
+ # Editors: VS Code (allow shared configs)
82
+ .vscode/*
83
+
84
+ # Jupyter / IPython
85
+ .ipynb_checkpoints
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # Spyder / Rope
90
+ .spyderproject
91
+ .spyproject
92
+ .ropeproject
93
+
94
+ # Environments
95
+ .env
96
+ .venv
97
+ env/
98
+ venv/
99
+ ENV/
100
+ env.bak/
101
+ venv.bak/
102
+
103
+ # PEP 582
104
+ __pypackages__/
105
+
106
+ # web frameworks / services
107
+ *.sqlite
108
+ *.sqlite3
109
+ *.log
110
+ instance/
111
+ .webassets-cache
112
+ celerybeat-schedule
113
+ celerybeat.pid
114
+ *.sage.py
115
+ tmp/
116
+
117
+ # PyInstaller
118
+ *.manifest
119
+ *.spec
120
+
121
+ # Project outputs
122
+ output/
123
+ prof/
124
+ certificates/
125
+
126
+ # Output artifacts
127
+ *.html
128
+ *.pkl
129
+
130
+ # Benchmark outputs
131
+ benchmark_results*/
132
+
133
+ # Local dataset files
134
+ liander_dataset/
135
+
136
+ # Deployment example run artifacts (MLflow store, forecasts, dataset, Celery/Airflow state)
137
+ openstef_deployment_runs/
138
+
139
+ # Mlflow
140
+ /mlflow
141
+ /mlflow_artifacts_local
142
+
143
+ .github/instructions
144
+
145
+ # Jupyter notebook cache (myst-nb execution outputs)
146
+ .jupyter_cache/
147
+ docs/build.zip
@@ -0,0 +1,89 @@
1
+ Metadata-Version: 2.4
2
+ Name: openstef-foundation-models
3
+ Version: 4.1.0
4
+ Summary: Foundation model support for OpenSTEF
5
+ Project-URL: Documentation, https://openstef.github.io/openstef/index.html
6
+ Project-URL: Homepage, https://lfenergy.org/projects/openstef/
7
+ Project-URL: Issues, https://github.com/OpenSTEF/openstef/issues
8
+ Project-URL: Repository, https://github.com/OpenSTEF/openstef
9
+ Author-email: "Alliander N.V" <openstef@lfenergy.org>
10
+ License-Expression: MPL-2.0
11
+ Keywords: energy,forecasting,foundationmodels,machinelearning,onnx
12
+ Classifier: Development Status :: 3 - Alpha
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Programming Language :: Python :: 3 :: Only
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Classifier: Programming Language :: Python :: 3.14
18
+ Requires-Python: <4.0,>=3.12
19
+ Requires-Dist: huggingface-hub>=1.2.2
20
+ Requires-Dist: openstef-beam<5,>=4
21
+ Requires-Dist: openstef-core<5,>=4
22
+ Requires-Dist: openstef-models<5,>=4
23
+ Provides-Extra: cpu
24
+ Requires-Dist: onnxruntime>=1.20; extra == 'cpu'
25
+ Provides-Extra: gpu
26
+ Requires-Dist: nvidia-cublas-cu12<13,>=12; extra == 'gpu'
27
+ Requires-Dist: nvidia-cuda-runtime-cu12<13,>=12; extra == 'gpu'
28
+ Requires-Dist: nvidia-cudnn-cu12<10,>=9; extra == 'gpu'
29
+ Requires-Dist: nvidia-cufft-cu12<12,>=11; extra == 'gpu'
30
+ Requires-Dist: nvidia-curand-cu12<11,>=10; extra == 'gpu'
31
+ Requires-Dist: onnxruntime-gpu>=1.21; extra == 'gpu'
32
+ Description-Content-Type: text/markdown
33
+
34
+ <!--
35
+ SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
36
+
37
+ SPDX-License-Identifier: MPL-2.0
38
+ -->
39
+
40
+ # openstef-foundation-models
41
+
42
+ Foundation model support for OpenSTEF — bringing pre-trained, ONNX-based forecasting models to the OpenSTEF ecosystem.
43
+
44
+ ## Installation
45
+
46
+ Pick exactly one ONNX runtime — `[cpu]` and `[gpu]` are mutually exclusive.
47
+
48
+ CPU (default — the meta-package `openstef` installs this for you):
49
+
50
+ ```bash
51
+ pip install "openstef-foundation-models[cpu]"
52
+ ```
53
+
54
+ GPU (CUDA):
55
+
56
+ ```bash
57
+ pip install "openstef-foundation-models[gpu]"
58
+ ```
59
+
60
+ > **Note:** Do **not** install both `[cpu]` and `[gpu]` in the same environment —
61
+ > `onnxruntime` and `onnxruntime-gpu` collide. They're declared as conflicting
62
+ > extras so uv enforces the choice; `pip` does not, so pick one yourself.
63
+
64
+ ## Selecting a checkpoint
65
+
66
+ OpenSTEF publishes its checkpoints to the Hugging Face Hub. Pick a model size and
67
+ variant from the catalog instead of hand-writing repository IDs and filenames:
68
+
69
+ ```python
70
+ from openstef_foundation_models.models import Chronos2, CheckpointVariant
71
+ from openstef_foundation_models.presets import ForecastingWorkflowConfig
72
+
73
+ # Default: the base Chronos-2, dynamic shapes — runs on any provider.
74
+ config = ForecastingWorkflowConfig()
75
+
76
+ # The smaller model, static shapes — enables the CoreML provider on macOS.
77
+ config = ForecastingWorkflowConfig(
78
+ checkpoint=Chronos2.SMALL.checkpoint(CheckpointVariant.STATIC),
79
+ )
80
+
81
+ # Let the host decide: static on macOS, dynamic elsewhere.
82
+ config = ForecastingWorkflowConfig(
83
+ checkpoint=Chronos2.BASE.checkpoint(CheckpointVariant.recommended()),
84
+ )
85
+ ```
86
+
87
+ Available sizes are `Chronos2.BASE` (`chronos-2`) and `Chronos2.SMALL`
88
+ (`chronos-2-small`); variants are `DYNAMIC` (portable) and `STATIC` (frozen shapes,
89
+ the macOS/CoreML path in the default provider fallback chain).
@@ -0,0 +1,56 @@
1
+ <!--
2
+ SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
3
+
4
+ SPDX-License-Identifier: MPL-2.0
5
+ -->
6
+
7
+ # openstef-foundation-models
8
+
9
+ Foundation model support for OpenSTEF — bringing pre-trained, ONNX-based forecasting models to the OpenSTEF ecosystem.
10
+
11
+ ## Installation
12
+
13
+ Pick exactly one ONNX runtime — `[cpu]` and `[gpu]` are mutually exclusive.
14
+
15
+ CPU (default — the meta-package `openstef` installs this for you):
16
+
17
+ ```bash
18
+ pip install "openstef-foundation-models[cpu]"
19
+ ```
20
+
21
+ GPU (CUDA):
22
+
23
+ ```bash
24
+ pip install "openstef-foundation-models[gpu]"
25
+ ```
26
+
27
+ > **Note:** Do **not** install both `[cpu]` and `[gpu]` in the same environment —
28
+ > `onnxruntime` and `onnxruntime-gpu` collide. They're declared as conflicting
29
+ > extras so uv enforces the choice; `pip` does not, so pick one yourself.
30
+
31
+ ## Selecting a checkpoint
32
+
33
+ OpenSTEF publishes its checkpoints to the Hugging Face Hub. Pick a model size and
34
+ variant from the catalog instead of hand-writing repository IDs and filenames:
35
+
36
+ ```python
37
+ from openstef_foundation_models.models import Chronos2, CheckpointVariant
38
+ from openstef_foundation_models.presets import ForecastingWorkflowConfig
39
+
40
+ # Default: the base Chronos-2, dynamic shapes — runs on any provider.
41
+ config = ForecastingWorkflowConfig()
42
+
43
+ # The smaller model, static shapes — enables the CoreML provider on macOS.
44
+ config = ForecastingWorkflowConfig(
45
+ checkpoint=Chronos2.SMALL.checkpoint(CheckpointVariant.STATIC),
46
+ )
47
+
48
+ # Let the host decide: static on macOS, dynamic elsewhere.
49
+ config = ForecastingWorkflowConfig(
50
+ checkpoint=Chronos2.BASE.checkpoint(CheckpointVariant.recommended()),
51
+ )
52
+ ```
53
+
54
+ Available sizes are `Chronos2.BASE` (`chronos-2`) and `Chronos2.SMALL`
55
+ (`chronos-2-small`); variants are `DYNAMIC` (portable) and `STATIC` (frozen shapes,
56
+ the macOS/CoreML path in the default provider fallback chain).
@@ -0,0 +1,75 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+ [build-system]
5
+ build-backend = "hatchling.build"
6
+ requires = [ "hatchling" ]
7
+
8
+ [project]
9
+ name = "openstef-foundation-models"
10
+ version = "4.1.0"
11
+ description = "Foundation model support for OpenSTEF"
12
+ readme = "README.md"
13
+ keywords = [ "energy", "forecasting", "foundationmodels", "machinelearning", "onnx" ]
14
+ license = "MPL-2.0"
15
+ authors = [
16
+ { name = "Alliander N.V", email = "openstef@lfenergy.org" },
17
+ ]
18
+ requires-python = ">=3.12,<4.0"
19
+ classifiers = [
20
+ "Development Status :: 3 - Alpha",
21
+ "Intended Audience :: Developers",
22
+ "Programming Language :: Python :: 3 :: Only",
23
+ "Programming Language :: Python :: 3.12",
24
+ "Programming Language :: Python :: 3.13",
25
+ "Programming Language :: Python :: 3.14",
26
+ ]
27
+ dependencies = [
28
+ # huggingface-hub resolves published checkpoints from the Hub — the default
29
+ # checkpoint source — and is pure Python, so it is a base dependency rather
30
+ # than an extra. openstef-beam (the backtesting integration target) is declared
30
+ # directly because this package uses it, even though openstef-models already pulls it in.
32
+ "huggingface-hub>=1.2.2",
33
+ "openstef-beam>=4,<5",
34
+ "openstef-core>=4,<5",
35
+ "openstef-models>=4,<5",
36
+ ]
37
+ # Exactly one of [cpu] or [gpu] must be installed — they are declared as
38
+ # conflicting extras (see [tool.uv]) so uv refuses to resolve both at once.
39
+ # The root `openstef` package (and the dev group) depend on
40
+ # `openstef-foundation-models[cpu]`, so the default install gets the CPU
41
+ # runtime without the caller having to choose.
42
+ optional-dependencies.cpu = [
43
+ "onnxruntime>=1.20",
44
+ ]
45
+ # onnxruntime-gpu ships only the CUDA execution-provider plugin, not the CUDA/cuDNN
46
+ # runtime it loads at session creation. We pull the matching nvidia CUDA-12 / cuDNN-9
47
+ # wheels explicitly (rather than via onnxruntime-gpu's own [cuda,cudnn] extras, whose
48
+ # nested-extra markers don't resolve through our conflicting cpu/gpu extras); >=1.21
49
+ # auto-preloads them from site-packages, so GPU works without a system CUDA install.
50
+ optional-dependencies.gpu = [
51
+ "nvidia-cublas-cu12>=12,<13",
52
+ "nvidia-cuda-runtime-cu12>=12,<13",
53
+ "nvidia-cudnn-cu12>=9,<10",
54
+ "nvidia-cufft-cu12>=11,<12",
55
+ "nvidia-curand-cu12>=10,<11",
56
+ "onnxruntime-gpu>=1.21",
57
+ ]
58
+ urls.Documentation = "https://openstef.github.io/openstef/index.html"
59
+ urls.Homepage = "https://lfenergy.org/projects/openstef/"
60
+ urls.Issues = "https://github.com/OpenSTEF/openstef/issues"
61
+ urls.Repository = "https://github.com/OpenSTEF/openstef"
62
+
63
+ [tool.hatch]
64
+ build.targets.wheel.packages = [ "src/openstef_foundation_models" ]
65
+
66
+ [tool.uv]
67
+ sources.openstef-beam = { workspace = true }
68
+ sources.openstef-core = { workspace = true }
69
+ sources.openstef-models = { workspace = true }
70
+ conflicts = [
71
+ [
72
+ { extra = "cpu" },
73
+ { extra = "gpu" },
74
+ ],
75
+ ]
@@ -0,0 +1,12 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+ """Foundation model support for OpenSTEF."""
5
+
6
+ import logging
7
+
8
# The package root re-exports nothing; ``from openstef_foundation_models import *``
# binds no names.
__all__: list[str] = []

# Logger named after this package (despite the variable name it is the package
# logger, not logging's root logger); loggers of submodules propagate to it.
# Attaching a NullHandler is the standard-library convention for libraries: when
# the application configures no handlers at all, records from this package do not
# reach logging's last-resort stderr handler, while they still propagate to any
# handlers the application does configure. The emptiness check avoids stacking a
# second handler if one is already attached.
root_logger = logging.getLogger(__name__)
if not root_logger.handlers:
    root_logger.addHandler(logging.NullHandler())
@@ -0,0 +1,32 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Inference backends for foundation-model forecasters.
6
+
7
+ An :class:`InferenceBackend` isolates how a checkpoint is executed behind a
8
+ single named-tensor contract; forecasters compose a backend rather than
9
+ inheriting execution behaviour. Only the dependency-free surface (the protocol
10
+ and the execution-provider configs) is re-exported here; a concrete backend
11
+ lives in its own submodule and imports its heavy runtime at module top level.
12
+ """
13
+
14
+ from openstef_foundation_models.inference.backend import InferenceBackend
15
+ from openstef_foundation_models.inference.providers import (
16
+ CoreMLProvider,
17
+ CpuProvider,
18
+ CudaProvider,
19
+ ExecutionProvider,
20
+ SessionOptionsConfig,
21
+ TensorRTProvider,
22
+ )
23
+
24
# Public surface of this subpackage: the backend protocol plus the
# execution-provider and session-option configuration types imported above.
# Concrete backends are not listed here; they live in their own submodules.
__all__ = [
    "CoreMLProvider", "CpuProvider", "CudaProvider", "ExecutionProvider",
    "InferenceBackend", "SessionOptionsConfig", "TensorRTProvider",
]
@@ -0,0 +1,43 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """The :class:`InferenceBackend` contract shared by all execution backends."""
6
+
7
+ from collections.abc import Mapping
8
+ from typing import Protocol, runtime_checkable
9
+
10
+ import numpy as np
11
+
12
+ from openstef_foundation_models.models.checkpoint import CheckpointMetadata
13
+
14
+
15
@runtime_checkable
class InferenceBackend(Protocol):
    """Structural interface shared by every execution backend.

    An implementation maps named input tensors to named output tensors and
    owns whatever runtime resources it needs (for example an ONNX Runtime
    session). It is meant to be loaded once and reused for a whole backtest.
    Everything specific to a model family is carried by :attr:`metadata`, so
    the backend itself stays model-agnostic.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks
    that ``metadata``, ``run`` and ``close`` are present, not their signatures.
    """

    @property
    def metadata(self) -> CheckpointMetadata:
        """Describe the checkpoint that this backend executes."""

    def run(self, inputs: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
        """Run the model on one batch of named tensors.

        Args:
            inputs: Input tensors keyed by name; the keys are expected to match
                ``metadata.input_names``.

        Returns:
            Output tensors keyed by name; ``metadata.output_name`` is among them.
        """

    def close(self) -> None:
        """Free the runtime resources owned by this backend."""
@@ -0,0 +1,237 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """ONNX Runtime execution backend.
6
+
7
+ Importing this module requires ONNX Runtime (the ``[cpu]`` or ``[gpu]`` extra)
8
+ and raises :class:`MissingExtraError` if it is missing.
9
+ """
10
+
11
+ import logging
12
+ from collections.abc import Mapping, Sequence
13
+ from typing import Self
14
+
15
+ import numpy as np
16
+
17
+ from openstef_core.exceptions import MissingExtraError
18
+ from openstef_foundation_models.inference.provider_selection import (
19
+ DefaultProviderPolicy,
20
+ HostCapabilities,
21
+ ProviderPolicy,
22
+ )
23
+ from openstef_foundation_models.inference.providers import (
24
+ ExecutionProvider,
25
+ SessionOptionsConfig,
26
+ )
27
+ from openstef_foundation_models.models.checkpoint import CheckpointMetadata, ResolvedCheckpoint
28
+
29
try:
    import onnxruntime as ort
except ImportError as e:
    raise MissingExtraError("onnxruntime", "openstef-foundation-models", install_extra="cpu") from e

# onnxruntime-gpu ships the CUDA execution-provider plugin but loads the CUDA/cuDNN
# runtime (the nvidia-*-cu12 wheels the [gpu] extra pulls) lazily at session creation.
# preload_dlls() loads them from the nvidia site-packages so the CUDA provider can be
# realized without a system CUDA install or LD_LIBRARY_PATH.
#
# Only call it when this onnxruntime build actually ships the CUDA provider. On the
# CPU-only runtime preload_dlls() is not a guaranteed no-op: current onnxruntime
# releases print a "not built with CUDA 12.x support" warning to stdout on
# Linux/Windows, which would fire on every import for users of the [cpu] extra.
# get_available_providers() lists the providers compiled into the build and does not
# load any CUDA library itself. The hasattr guard covers onnxruntime < 1.21, which
# predates preload_dlls().
if hasattr(ort, "preload_dlls") and "CUDAExecutionProvider" in ort.get_available_providers():
    ort.preload_dlls()

logger = logging.getLogger(__name__)
43
+
44
+
45
class OnnxBackend:
    """An :class:`~openstef_foundation_models.inference.backend.InferenceBackend` backed by ONNX Runtime.

    The session is built once on construction and reused for every
    :meth:`run` call, so a single backend instance can be shared across an
    entire backtest. Users may either let the backend build a session from a
    resolved checkpoint and provider chain, or pass a pre-built session they own.

    The backend can also be used as a context manager; leaving the ``with``
    block calls :meth:`close`.
    """

    def __init__(
        self,
        metadata: CheckpointMetadata,
        session: ort.InferenceSession,
    ) -> None:
        """Wrap a pre-built ONNX Runtime session.

        Prefer :meth:`from_checkpoint` unless you need to own the session
        lifecycle yourself.

        Args:
            metadata: Metadata describing the checkpoint the session executes.
            session: A pre-built ONNX Runtime inference session.
        """
        self._metadata = metadata
        # ``None`` marks a closed backend (see :meth:`close`).
        self._session: ort.InferenceSession | None = session

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: ResolvedCheckpoint,
        providers: Sequence[ExecutionProvider] | None = None,
        session_options: SessionOptionsConfig | None = None,
        *,
        policy: ProviderPolicy | None = None,
    ) -> Self:
        """Build a backend by loading a checkpoint into a new ONNX Runtime session.

        With ``providers=None`` the *policy* selects a chain from the checkpoint
        and host; an explicit ``providers`` list is used as given and *policy* is
        ignored. See :class:`~openstef_foundation_models.inference.provider_selection.ProviderPolicy`
        for how a chain is chosen and how strictly its realization is enforced.

        Args:
            checkpoint: The resolved checkpoint (weights + metadata) to load.
            providers: Ordered execution providers to try. ``None`` lets the policy
                pick a host-appropriate chain from the checkpoint metadata.
            session_options: Optional ONNX Runtime session options.
            policy: Selection policy used when ``providers is None``. Defaults to
                :class:`DefaultProviderPolicy`.

        Returns:
            A backend wrapping the newly built session.

        Raises:
            RuntimeError: If an explicit ``providers`` chain requested an
                accelerator that ONNX Runtime did not realize.
        """
        metadata = checkpoint.metadata
        provider_configs, explicit = cls._resolve_provider_chain(metadata, providers, policy)

        # Compare against None explicitly so a supplied config is never skipped
        # merely because the object happens to evaluate as falsy.
        ort_session_options = (
            _build_session_options(session_options) if session_options is not None else None
        )

        session = ort.InferenceSession(
            str(checkpoint.weights_path),
            sess_options=ort_session_options,
            providers=[config.to_ort() for config in provider_configs],
        )
        realized = session.get_providers()
        logger.info("ONNX Runtime session built; realized providers: %s.", realized)
        # An explicit chain is enforced strictly; a policy-chosen chain only warns.
        _check_provider_fallback(
            requested=provider_configs,
            realized=realized,
            strict=explicit,
        )
        return cls(metadata=metadata, session=session)

    @staticmethod
    def _resolve_provider_chain(
        metadata: CheckpointMetadata,
        providers: Sequence[ExecutionProvider] | None,
        policy: ProviderPolicy | None,
    ) -> tuple[list[ExecutionProvider], bool]:
        """Decide which execution-provider chain to request, and log the choice.

        Args:
            metadata: Metadata of the checkpoint being loaded.
            providers: Caller-supplied chain, or ``None`` to let *policy* choose.
            policy: Selection policy used when ``providers is None``; defaults to
                :class:`DefaultProviderPolicy`.

        Returns:
            The ordered provider configs, and ``True`` when they were given
            explicitly by the caller or ``False`` when a policy selected them.
        """
        if providers is not None:
            chain = list(providers)
            logger.info(
                "Using explicit execution-provider chain %s for checkpoint '%s'.",
                [config.to_ort()[0] for config in chain],
                metadata.model_family,
            )
            return chain, True

        selector = policy if policy is not None else DefaultProviderPolicy()
        host = HostCapabilities.detect()
        # Materialize once: the chain is iterated for logging, session creation
        # and the fallback check, which would silently break on a one-shot iterator.
        chain = list(selector.select(metadata, host))
        logger.debug(
            "Detected host: platform=%s, available_providers=%s",
            host.platform,
            sorted(host.available_providers),
        )
        logger.info(
            "%s selected execution-provider chain %s for checkpoint '%s' (precision=%s, static_shapes=%s) on %s.",
            type(selector).__name__,
            [config.to_ort()[0] for config in chain],
            metadata.model_family,
            metadata.precision,
            metadata.static_shapes,
            host.platform,
        )
        return chain, False

    @property
    def metadata(self) -> CheckpointMetadata:
        """Metadata describing the checkpoint this backend executes."""
        return self._metadata

    def run(self, inputs: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
        """Execute the ONNX model on a batch of named input tensors.

        Args:
            inputs: Named input tensors. Keys must match ``metadata.input_names``.

        Returns:
            Named output tensors keyed by the model's output names.

        Raises:
            RuntimeError: If the backend has been closed.
        """
        if self._session is None:
            msg = "OnnxBackend has been closed."
            raise RuntimeError(msg)
        output_names = [out.name for out in self._session.get_outputs()]
        results = self._session.run(output_names, dict(inputs))
        return {name: np.asarray(result) for name, result in zip(output_names, results, strict=True)}

    def close(self) -> None:
        """Release the underlying ONNX Runtime session.

        ONNX Runtime frees native resources on garbage collection, so dropping
        the reference is the supported way to release them. Calling this more
        than once is harmless.
        """
        self._session = None

    def __enter__(self) -> Self:
        """Return the backend itself for use in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the backend on leaving a ``with`` block; exceptions propagate."""
        self.close()
173
+
174
+
175
def _build_session_options(config: SessionOptionsConfig) -> ort.SessionOptions:
    """Convert a typed :class:`SessionOptionsConfig` into ``ort.SessionOptions``.

    The optimization level is looked up by name on ``ort.GraphOptimizationLevel``
    as ``"ORT_"`` followed by the configured level. Thread counts are copied only
    when they are set, so unset counts keep ONNX Runtime's own defaults.

    Args:
        config: The typed session-options configuration.

    Returns:
        A freshly created ONNX Runtime ``SessionOptions`` object.

    Raises:
        AttributeError: If ``ORT_<level>`` is not a member of
            ``ort.GraphOptimizationLevel``.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = getattr(
        ort.GraphOptimizationLevel,
        f"ORT_{config.graph_optimization_level}",
    )
    # The config fields share their names with the SessionOptions attributes.
    for attr_name in ("intra_op_num_threads", "inter_op_num_threads"):
        thread_count = getattr(config, attr_name)
        if thread_count is not None:
            setattr(options, attr_name, thread_count)
    return options
194
+
195
+
196
def _check_provider_fallback(
    requested: Sequence[ExecutionProvider],
    realized: Sequence[str],
    *,
    strict: bool,
) -> None:
    """Compare the requested provider chain with what ONNX Runtime actually loaded.

    ONNX Runtime can realize fewer providers than were requested and continue on
    the CPU provider without complaint. This check surfaces that. See
    :class:`~openstef_foundation_models.inference.provider_selection.ProviderPolicy`
    for the strict-vs-graceful contract it enforces.

    Args:
        requested: The execution providers that were requested.
        realized: The provider names ONNX Runtime actually loaded.
        strict: If ``True``, every requested accelerator must be realized.
            Otherwise only a complete loss of all accelerators is reported, as a
            warning.

    Raises:
        RuntimeError: In strict mode, when any requested accelerator is missing.
    """
    wanted_accelerators = {config.to_ort()[0] for config in requested} - {"CPUExecutionProvider"}
    if not wanted_accelerators:
        # A CPU-only request has nothing to fall back from.
        return

    loaded = set(realized)

    if not strict:
        # Graceful mode: one realized accelerator is enough; warn only when none is.
        if wanted_accelerators.isdisjoint(loaded):
            logger.warning(
                "No requested accelerator (%s) was realized; ONNX Runtime fell back to %s. Inference will run on CPU.",
                sorted(wanted_accelerators),
                realized,
            )
        return

    missing = wanted_accelerators - loaded
    if missing:
        msg = (
            f"Requested execution provider(s) {sorted(missing)} were not realized; "
            f"ONNX Runtime fell back to {realized}."
        )
        raise RuntimeError(msg)