openstef-models 4.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. openstef_models/__init__.py +13 -0
  2. openstef_models/explainability/__init__.py +18 -0
  3. openstef_models/explainability/mixins.py +126 -0
  4. openstef_models/explainability/plotters/__init__.py +17 -0
  5. openstef_models/explainability/plotters/contributions_plotter.py +228 -0
  6. openstef_models/explainability/plotters/feature_importance_plotter.py +59 -0
  7. openstef_models/integrations/__init__.py +12 -0
  8. openstef_models/integrations/joblib/__init__.py +15 -0
  9. openstef_models/integrations/joblib/joblib_model_serializer.py +68 -0
  10. openstef_models/integrations/mlflow/__init__.py +26 -0
  11. openstef_models/integrations/mlflow/mlflow_storage.py +332 -0
  12. openstef_models/integrations/mlflow/mlflow_storage_callback.py +383 -0
  13. openstef_models/integrations/optuna/__init__.py +23 -0
  14. openstef_models/integrations/optuna/tuner.py +354 -0
  15. openstef_models/mixins/__init__.py +14 -0
  16. openstef_models/mixins/callbacks.py +93 -0
  17. openstef_models/mixins/model_serializer.py +72 -0
  18. openstef_models/models/__init__.py +18 -0
  19. openstef_models/models/component_splitting/__init__.py +17 -0
  20. openstef_models/models/component_splitting/component_splitter.py +65 -0
  21. openstef_models/models/component_splitting/constant_component_splitter.py +152 -0
  22. openstef_models/models/component_splitting/linear_component_splitter.py +212 -0
  23. openstef_models/models/component_splitting/linear_component_splitter_model/linear_component_splitter_model.z +0 -0
  24. openstef_models/models/component_splitting/linear_component_splitter_model/linear_component_splitter_model.z.license +3 -0
  25. openstef_models/models/component_splitting_model.py +119 -0
  26. openstef_models/models/forecasting/__init__.py +9 -0
  27. openstef_models/models/forecasting/base_case_forecaster.py +177 -0
  28. openstef_models/models/forecasting/constant_quantile_forecaster.py +127 -0
  29. openstef_models/models/forecasting/flatliner_forecaster.py +111 -0
  30. openstef_models/models/forecasting/forecaster.py +140 -0
  31. openstef_models/models/forecasting/gblinear_forecaster.py +335 -0
  32. openstef_models/models/forecasting/lgbm_forecaster.py +342 -0
  33. openstef_models/models/forecasting/lgbmlinear_forecaster.py +344 -0
  34. openstef_models/models/forecasting/median_forecaster.py +320 -0
  35. openstef_models/models/forecasting/xgboost_forecaster.py +414 -0
  36. openstef_models/models/forecasting_model.py +609 -0
  37. openstef_models/presets/__init__.py +19 -0
  38. openstef_models/presets/forecasting_workflow.py +590 -0
  39. openstef_models/testing.py +98 -0
  40. openstef_models/transforms/__init__.py +26 -0
  41. openstef_models/transforms/energy_domain/__init__.py +14 -0
  42. openstef_models/transforms/energy_domain/wind_power_feature_adder.py +135 -0
  43. openstef_models/transforms/general/__init__.py +37 -0
  44. openstef_models/transforms/general/dimensionality_reducer.py +138 -0
  45. openstef_models/transforms/general/empty_feature_remover.py +120 -0
  46. openstef_models/transforms/general/flagger.py +97 -0
  47. openstef_models/transforms/general/imputer.py +278 -0
  48. openstef_models/transforms/general/nan_dropper.py +89 -0
  49. openstef_models/transforms/general/outlier_handler.py +178 -0
  50. openstef_models/transforms/general/sample_weighter.py +309 -0
  51. openstef_models/transforms/general/scaler.py +122 -0
  52. openstef_models/transforms/general/selector.py +83 -0
  53. openstef_models/transforms/general/shifter.py +130 -0
  54. openstef_models/transforms/postprocessing/__init__.py +16 -0
  55. openstef_models/transforms/postprocessing/confidence_interval_applicator.py +232 -0
  56. openstef_models/transforms/postprocessing/isotonic_quantile_calibrator.py +229 -0
  57. openstef_models/transforms/postprocessing/quantile_sorter.py +78 -0
  58. openstef_models/transforms/time_domain/__init__.py +29 -0
  59. openstef_models/transforms/time_domain/cyclic_features_adder.py +171 -0
  60. openstef_models/transforms/time_domain/datetime_features_adder.py +111 -0
  61. openstef_models/transforms/time_domain/holiday_features_adder.py +168 -0
  62. openstef_models/transforms/time_domain/lags_adder.py +388 -0
  63. openstef_models/transforms/time_domain/rolling_aggregates_adder.py +170 -0
  64. openstef_models/transforms/time_domain/versioned_lags_adder.py +139 -0
  65. openstef_models/transforms/validation/__init__.py +16 -0
  66. openstef_models/transforms/validation/completeness_checker.py +123 -0
  67. openstef_models/transforms/validation/flatline_checker.py +152 -0
  68. openstef_models/transforms/validation/input_consistency_checker.py +67 -0
  69. openstef_models/transforms/weather_domain/__init__.py +20 -0
  70. openstef_models/transforms/weather_domain/atmosphere_derived_features_adder.py +205 -0
  71. openstef_models/transforms/weather_domain/daylight_feature_adder.py +74 -0
  72. openstef_models/transforms/weather_domain/radiation_derived_features_adder.py +165 -0
  73. openstef_models/utils/__init__.py +9 -0
  74. openstef_models/utils/data_split.py +346 -0
  75. openstef_models/utils/evaluation_functions.py +30 -0
  76. openstef_models/utils/feature_selection.py +225 -0
  77. openstef_models/utils/loss_functions.py +256 -0
  78. openstef_models/utils/multi_quantile_regressor.py +157 -0
  79. openstef_models/utils/xgboost.py +45 -0
  80. openstef_models/workflows/__init__.py +16 -0
  81. openstef_models/workflows/callbacks/__init__.py +13 -0
  82. openstef_models/workflows/callbacks/data_save.py +108 -0
  83. openstef_models/workflows/callbacks/model_performance_callback.py +92 -0
  84. openstef_models/workflows/custom_component_split_workflow.py +154 -0
  85. openstef_models/workflows/custom_forecasting_workflow.py +206 -0
  86. openstef_models-4.0.0.dist-info/METADATA +44 -0
  87. openstef_models-4.0.0.dist-info/RECORD +88 -0
  88. openstef_models-4.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,13 @@
1
+ # SPDX-FileCopyrightText: 2017-2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+ """Core models for OpenSTEF."""
5
+
6
+ import logging
7
+
8
# Library-style logging setup: fetch the logger named after this package and,
# only when it has no handlers yet, attach a NullHandler. The package itself
# therefore installs no output handler; emitting records is left to whatever
# handlers the application configures.
# NOTE: despite its name, ``root_logger`` is the logger for ``__name__``
# (this package), not Python's global root logger.
root_logger = logging.getLogger(name=__name__)
if not root_logger.handlers:
    root_logger.addHandler(logging.NullHandler())

# Nothing is exported by ``from <package> import *``.
__all__ = []
@@ -0,0 +1,18 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Explainability utilities for OpenSTEF.
6
+
7
+ Tools for feature importance, attribution and model interpretation.
8
+ """
9
+
10
+ from .mixins import ContributionsMixin, ExplainableForecaster
11
+ from .plotters import ContributionsPlotter, FeatureImportancePlotter
12
+
13
+ __all__ = [
14
+ "ContributionsMixin",
15
+ "ContributionsPlotter",
16
+ "ExplainableForecaster",
17
+ "FeatureImportancePlotter",
18
+ ]
@@ -0,0 +1,126 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Mixins for adding explainability features to forecasting models.
6
+
7
+ Provides base classes that enable models to expose feature importance scores
8
+ and generate visualization plots.
9
+ """
10
+
11
+ from abc import ABC, abstractmethod
12
+ from typing import Any, Literal
13
+
14
+ import pandas as pd
15
+ import plotly.graph_objects as go
16
+
17
+ from openstef_core.datasets import ForecastInputDataset, TimeSeriesDataset
18
+ from openstef_core.types import Q, Quantile
19
+ from openstef_models.explainability.plotters.contributions_plotter import ContributionsPlotter
20
+ from openstef_models.explainability.plotters.feature_importance_plotter import FeatureImportancePlotter
21
+
22
+
23
class ExplainableForecaster(ABC):
    """Interface for forecasters that expose aggregate feature importances.

    Subclasses supply the scores through ``feature_importances``; the treemap
    visualization built on top of those scores is implemented here so that it
    is shared by every explainable forecaster.
    """

    @property
    @abstractmethod
    def feature_importances(self) -> pd.DataFrame:
        """Feature importance scores of this model.

        Returns:
            DataFrame indexed by feature name with one column per quantile.
            Values are normalized importance scores that sum to 1.0; each
            quantile describes the importance distribution across multiple
            model training runs or folds.

        Note:
            Implementations must use feature names as index and name the
            quantile columns 'quantile_PXX' (e.g., 'quantile_P50',
            'quantile_P95'). All quantile values must be between 0 and 1.
        """
        raise NotImplementedError

    def plot_feature_importances(self, quantile: Quantile = Q(0.5)) -> go.Figure:
        """Render the feature importances as an interactive treemap.

        Args:
            quantile: Quantile of the importance scores to display.
                Defaults to the median (0.5).

        Returns:
            Plotly Figure with one treemap tile per feature, produced by
            ``FeatureImportancePlotter.plot``.
        """
        plotter = FeatureImportancePlotter()
        importance_scores = self.feature_importances
        return plotter.plot(scores=importance_scores, quantile=quantile)
62
+
63
+
64
class ContributionsMixin(ABC):
    """Interface for forecasters that can decompose predictions per sample.

    ``ExplainableForecaster`` reports aggregate feature importance; this mixin
    instead reports, for every individual sample, how much each feature
    contributed to the prediction.

    For tree-based models (XGBoost), this corresponds to SHAP TreeExplainer values.
    For linear models (GBLinear), this is the coefficient x feature value decomposition.
    For ensembles, this shows each base model's contribution weight.
    """

    @abstractmethod
    def predict_contributions(self, data: ForecastInputDataset) -> TimeSeriesDataset:
        """Compute per-sample feature contributions for the given input data.

        The result has one column per feature (or per model name for ensemble
        contributions) and shares the time index of the input. Each value is
        the additive contribution of that feature to the prediction at that
        timestep.

        Args:
            data: Preprocessed input data (same format as ``predict()`` takes).

        Returns:
            TimeSeriesDataset with feature contributions. Columns are features,
            rows are timesteps. A ``bias`` column may be included for the
            model intercept/base value.
        """

    def plot_contributions(
        self,
        data: ForecastInputDataset,
        kind: Literal["heatmap", "waterfall", "bar"] = "heatmap",
        **kwargs: Any,
    ) -> go.Figure:
        """Plot per-sample feature contributions.

        Contributions are computed with ``predict_contributions()`` first and
        then handed to the ``ContributionsPlotter`` method matching *kind*.

        Args:
            data: Preprocessed input data.
            kind: Chart type — ``"heatmap"``, ``"waterfall"``, or ``"bar"``.
            **kwargs: Forwarded to the corresponding plotter method
                (e.g. ``top_n``, ``timestep``).

        Returns:
            Plotly Figure.

        Raises:
            ValueError: If *kind* is not one of the supported chart types.
        """
        contributions = self.predict_contributions(data)

        # Dispatch table: chart kind -> plotter function.
        plotters = {
            "heatmap": ContributionsPlotter.plot_heatmap,
            "waterfall": ContributionsPlotter.plot_waterfall,
            "bar": ContributionsPlotter.plot_bar,
        }
        plot = plotters.get(kind)
        if plot is None:
            msg = f"Unknown plot kind {kind!r}. Choose from {list(plotters)}"
            raise ValueError(msg)

        return plot(contributions=contributions, **kwargs)
@@ -0,0 +1,17 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Visualization tools for model explainability.
6
+
7
+ Provides plotters for creating interactive visualizations of feature importance
8
+ scores and other model explanation outputs.
9
+ """
10
+
11
+ from .contributions_plotter import ContributionsPlotter
12
+ from .feature_importance_plotter import FeatureImportancePlotter
13
+
14
+ __all__ = [
15
+ "ContributionsPlotter",
16
+ "FeatureImportancePlotter",
17
+ ]
@@ -0,0 +1,228 @@
1
+ # SPDX-FileCopyrightText: 2026 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Visualizations for per-sample feature contributions (SHAP values)."""
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ import plotly.graph_objects as go
12
+ from plotly.subplots import make_subplots # pyright: ignore[reportUnknownVariableType]
13
+
14
+ from openstef_core.datasets import TimeSeriesDataset # noqa: TC001 # runtime needed for pyright
15
+
16
+ if TYPE_CHECKING:
17
+ import pandas as pd
18
+
19
+
20
class ContributionsPlotter:
    """Visualizations for per-timestep feature contributions."""

    @staticmethod
    def _split_features(
        contributions: TimeSeriesDataset,
        target_column: str,
        bias_column: str,
    ) -> tuple[pd.DataFrame, pd.Series | None]:
        """Separate feature contribution columns from the target and bias columns.

        Args:
            contributions: Output of ``predict_contributions()``.
            target_column: Name of the target column to exclude.
            bias_column: Name of the bias column to exclude and return separately.

        Returns:
            Tuple of (feature contributions without target/bias columns,
            bias series or None when the bias column is absent).
        """
        data = contributions.data
        bias = data[bias_column] if bias_column in data.columns else None
        cols_to_drop = [c for c in [target_column, bias_column] if c in data.columns]
        return data.drop(columns=cols_to_drop), bias

    @staticmethod
    def plot_heatmap(
        contributions: TimeSeriesDataset,
        top_n: int = 10,
        target_column: str = "load",
        bias_column: str = "bias",
        *,
        show_prediction: bool = True,
    ) -> go.Figure:
        """Create an interactive heatmap of feature contributions over time.

        X-axis is the prediction datetime, Y-axis shows feature names ranked by mean absolute contribution
        (most important at top). Color ranges from blue (negative) through white (zero) to red (positive).
        When ``show_prediction`` is True a line plot of the model prediction (sum of contributions + bias)
        is shown above the heatmap.

        Args:
            contributions: Output of ``predict_contributions()``.
            top_n: Number of top features to show (ranked by mean absolute contribution).
            target_column: Name of the target column to exclude. Default "load".
            bias_column: Name of the bias column. Default "bias".
            show_prediction: If True, add a prediction line subplot above the heatmap. Default True.

        Returns:
            Plotly Figure with a diverging heatmap centered at zero (and optional prediction line).
        """
        df, bias = ContributionsPlotter._split_features(contributions, target_column, bias_column)
        ranked: list[str] = df.abs().mean().sort_values(ascending=False).head(top_n).index.tolist()

        # Most-important feature at top of Y-axis
        y_labels = ranked[::-1]

        heatmap = go.Heatmap(
            z=df[y_labels].T.values,
            x=df.index,
            y=y_labels,
            colorscale="RdBu_r",
            zmid=0,
            colorbar={"title": "Contribution"},
            showlegend=False,
        )

        if not show_prediction:
            return go.Figure(
                data=heatmap,
                layout={
                    "xaxis_title": "Time",
                    "yaxis_title": "Feature",
                    "margin": {"t": 30, "r": 10, "b": 40, "l": 120},
                },
            )

        # The prediction line sums ALL feature contributions (not only the
        # top_n shown in the heatmap) plus the bias, when a bias column exists.
        prediction = df.sum(axis=1)
        if bias is not None:
            prediction = prediction + bias

        fig = make_subplots(
            rows=2,
            cols=1,
            shared_xaxes=True,
            row_heights=[0.2, 0.8],
            vertical_spacing=0.03,
        )

        fig.add_trace(  # pyright: ignore[reportUnknownMemberType]
            go.Scatter(
                x=df.index,
                y=prediction,
                mode="lines",
                name="Prediction",
                line={"color": "black", "width": 1.5},
                showlegend=False,
            ),
            row=1,
            col=1,
        )
        fig.add_trace(heatmap, row=2, col=1)  # pyright: ignore[reportUnknownMemberType]

        fig.update_layout(  # pyright: ignore[reportUnknownMemberType]
            yaxis_title="Prediction",
            yaxis2_title="Feature",
            xaxis2_title="Time",
            margin={"t": 30, "r": 10, "b": 40, "l": 120},
        )
        return fig

    @staticmethod
    def plot_waterfall(
        contributions: TimeSeriesDataset,
        timestep: int = 0,
        top_n: int = 10,
        target_column: str = "load",
        bias_column: str = "bias",
    ) -> go.Figure:
        """Create a waterfall chart decomposing a single timestep's prediction.

        Shows how the bias (base value) is pushed up or down by each feature's
        contribution to arrive at the final prediction.

        Args:
            contributions: Output of ``predict_contributions()``.
            timestep: Row position (0-based) of the timestep to explain.
                Negative values count from the end, as with ``DataFrame.iloc``.
            top_n: Number of top features to show. Remaining features are
                aggregated into an "other" bar.
            target_column: Name of the target column to exclude. Default "load".
            bias_column: Name of the bias column used as base value. Default "bias".
                When the column is absent a base value of 0.0 is used.

        Returns:
            Plotly Figure with waterfall chart.

        Raises:
            IndexError: If *timestep* does not refer to an existing row.
        """
        df, bias = ContributionsPlotter._split_features(contributions, target_column, bias_column)

        # Fail early with an actionable message instead of pandas' generic
        # "single positional indexer is out-of-bounds".
        n_rows = len(df)
        if not -n_rows <= timestep < n_rows:
            msg = f"timestep {timestep} is out of range for contributions with {n_rows} rows"
            raise IndexError(msg)

        row = df.iloc[timestep]
        base_value = float(bias.iloc[timestep]) if bias is not None else 0.0

        # Rank by |contribution| for this specific timestep; everything after
        # the first top_n entries of the sorted index is folded into "other".
        abs_sorted = row.abs().sort_values(ascending=False)
        top = abs_sorted.head(top_n).index.tolist()
        remaining = abs_sorted.index[top_n:].tolist()

        names: list[str] = [bias_column, *top]
        values: list[float] = [base_value]
        values.extend(float(row[feat]) for feat in top)  # pyright: ignore[reportArgumentType]
        measures: list[str] = ["absolute"] + ["relative"] * len(top)

        if remaining:
            names.append(f"other ({len(remaining)})")
            values.append(float(row[remaining].sum()))
            measures.append("relative")

        names.append("Prediction")
        values.append(base_value + float(row.sum()))
        measures.append("total")

        timestamp = contributions.data.index[timestep]
        return go.Figure(
            go.Waterfall(
                x=names,
                y=values,
                measure=measures,
                connector={"line": {"color": "grey", "width": 0.5}},
                increasing={"marker": {"color": "#ff4136"}},
                decreasing={"marker": {"color": "#0074d9"}},
                totals={"marker": {"color": "#2ecc40"}},
                textposition="outside",
                text=[f"{v:+.4f}" if m == "relative" else f"{v:.4f}" for v, m in zip(values, measures, strict=True)],
            ),
            layout={
                "title": f"Contributions at {timestamp}",
                "yaxis_title": "Contribution",
                "margin": {"t": 50, "r": 10, "b": 40, "l": 60},
                "showlegend": False,
            },
        )

    @staticmethod
    def plot_bar(
        contributions: TimeSeriesDataset,
        top_n: int = 10,
        target_column: str = "load",
        bias_column: str = "bias",
    ) -> go.Figure:
        """Create a horizontal bar chart of mean absolute contributions per feature.

        Features are ranked from most to least important (top to bottom).

        Args:
            contributions: Output of ``predict_contributions()``.
            top_n: Number of top features to show.
            target_column: Name of the target column to exclude. Default "load".
            bias_column: Name of the bias column to exclude. Default "bias".

        Returns:
            Plotly Figure with horizontal bar chart.
        """
        df, _ = ContributionsPlotter._split_features(contributions, target_column, bias_column)
        mean_abs: pd.Series = df.abs().mean().sort_values(ascending=False).head(top_n)

        # Reverse for plotly (bottom-to-top rendering)
        mean_abs = mean_abs.iloc[::-1]

        return go.Figure(
            go.Bar(
                x=mean_abs.values,  # pyright: ignore[reportArgumentType]
                y=mean_abs.index.tolist(),
                orientation="h",
                marker_color="#1f77b4",
                hovertemplate="<b>%{y}</b><br>mean |SHAP|: %{x:.4f}<extra></extra>",
            ),
            layout={
                "xaxis_title": "mean |SHAP value|",
                "yaxis_title": "Feature",
                "margin": {"t": 30, "r": 10, "b": 40, "l": 120},
                "showlegend": False,
            },
        )
@@ -0,0 +1,59 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Interactive treemap visualization for feature importance scores.
6
+
7
+ Creates color-coded treemaps showing relative importance of features in
8
+ forecasting models.
9
+ """
10
+
11
+ import pandas as pd
12
+ import plotly.graph_objects as go
13
+
14
+ from openstef_core.base_model import BaseConfig
15
+ from openstef_core.datasets.validation import validate_required_columns
16
+ from openstef_core.types import Q, Quantile
17
+
18
+
19
class FeatureImportancePlotter(BaseConfig):
    """Builds treemap figures from feature importance scores."""

    @staticmethod
    def plot(scores: pd.DataFrame, quantile: Quantile = Q(0.5)) -> go.Figure:
        """Generate an interactive treemap of feature importances.

        Every feature becomes one tile under a common "Feature importance"
        parent; both tile size and tile color are driven by the feature's
        score in the selected quantile column.

        Args:
            scores: Feature importance scores with feature names as index and
                one column per quantile. The column that is read is the one
                named by ``quantile.format()``; it must be present. Values
                should be normalized to sum to 1.0.
            quantile: Which quantile column to visualize. Defaults to median (0.5).

        Returns:
            Plotly Figure containing the treemap. Tiles use the "greens"
            colorscale and the hover text shows the importance as a percentage.
        """
        column = quantile.format()
        validate_required_columns(scores, required_columns=[column])

        importance = scores[column]
        # All tiles share one parent node so the treemap has a single root.
        parents = pd.Series(data=["Feature importance"] * len(scores), index=scores.index)

        treemap = go.Treemap(
            labels=scores.index,
            parents=parents,
            values=importance,
            marker={"colors": importance, "colorscale": "greens"},
            hovertemplate="<b>%{label}</b><br>importance: %{value:.1%}<extra></extra>",
        )
        no_margin = {"t": 0, "r": 0, "b": 0, "l": 0}
        return go.Figure(treemap, layout={"margin": no_margin})
@@ -0,0 +1,12 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """Integration components for extending OpenSTEF functionality.
6
+
7
+ Contains implementations for callbacks and storage systems that hook into and
8
+ extend OpenSTEF functionality by integrating with external systems such as
9
+ monitoring tools, databases, cloud storage, and custom processing pipelines.
10
+ """
11
+
12
+ __all__ = ["joblib", "mlflow", "optuna"] # noqa: F822 # pyright: ignore[reportUnsupportedDunderAll] # Sub-packages with optional deps; not imported to avoid missing-extra errors at import time
@@ -0,0 +1,15 @@
1
+ """Joblib-based model storage integration.
2
+
3
+ Provides local file-based model persistence using joblib for serialization.
4
+ This integration allows storing and loading ForecastingModel instances on
5
+ the local filesystem, making it suitable for development, testing, and
6
+ single-machine deployments.
7
+ """
8
+
9
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
10
+ #
11
+ # SPDX-License-Identifier: MPL-2.0
12
+
13
+ from .joblib_model_serializer import JoblibModelSerializer
14
+
15
+ __all__ = ["JoblibModelSerializer"]
@@ -0,0 +1,68 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+ """Local model storage implementation using joblib serialization.
5
+
6
+ Provides file-based persistence for ForecastingModel instances using joblib's
7
+ pickle-based serialization. This storage backend is suitable for development,
8
+ testing, and single-machine deployments where models need to be persisted
9
+ to the local filesystem.
10
+ """
11
+
12
+ from typing import BinaryIO, ClassVar, override
13
+
14
+ from openstef_core.exceptions import MissingExtraError
15
+ from openstef_models.mixins.model_serializer import ModelSerializer
16
+
17
+ try:
18
+ import joblib
19
+ except ImportError as e:
20
+ raise MissingExtraError("joblib", package="openstef-models") from e
21
+
22
+
23
class JoblibModelSerializer(ModelSerializer):
    """Model serializer backed by joblib.

    Writes an object to, and reads an object back from, an already opened
    binary file object using ``joblib.dump`` / ``joblib.load``. The class
    attribute ``extension`` is ``"joblib"``.

    This serializer is suitable for development, testing, and single-machine
    deployments where simple file-based persistence is sufficient.

    Note:
        joblib.dump() and joblib.load() are based on the Python pickle serialization model,
        which means that arbitrary Python code can be executed when loading a serialized object
        with joblib.load().

        joblib.load() should therefore never be used to load objects from an untrusted source
        or otherwise you will introduce a security vulnerability in your program.

    NOTE(review): earlier documentation of this class described ``.pkl`` files
    in a storage directory and a ``ModelNotFoundError`` on load. Nothing in
    this class implements that; file naming and location are presumably
    handled by the caller that opens the file — confirm against the storage
    layer that uses this serializer.

    Example:
        Round-trip an object through an in-memory buffer

        >>> import io
        >>> serializer = JoblibModelSerializer()  # doctest: +SKIP
        >>> buffer = io.BytesIO()  # doctest: +SKIP
        >>> serializer.serialize(my_forecasting_model, buffer)  # doctest: +SKIP
        >>> _ = buffer.seek(0)  # doctest: +SKIP
        >>> loaded_model = serializer.deserialize(buffer)  # doctest: +SKIP
    """

    # File extension associated with this serializer (no leading dot).
    extension: ClassVar[str] = "joblib"

    @override
    def serialize(self, model: object, file: BinaryIO) -> None:
        # Dump ``model`` into the open binary ``file``; the caller owns the file handle.
        joblib.dump(model, file)  # type: ignore[reportUnknownMemberType]

    @override
    def deserialize(self, file: BinaryIO) -> object:
        # SECURITY: pickle-based load — only use on files from a trusted source.
        return joblib.load(file)  # type: ignore[reportUnknownMemberType]
66
+
67
+
68
+ __all__ = ["JoblibModelSerializer"]
@@ -0,0 +1,26 @@
1
+ # SPDX-FileCopyrightText: 2025 Contributors to the OpenSTEF project <openstef@lfenergy.org>
2
+ #
3
+ # SPDX-License-Identifier: MPL-2.0
4
+
5
+ """MLflow integration for model tracking and storage.
6
+
7
+ Provides integration with MLflow for model lifecycle management, experiment
8
+ tracking, and model registry functionality. This package enables OpenSTEF
9
+ models to be stored, versioned, and tracked using MLflow's
10
+ model registry.
11
+
12
+ Note:
13
+ This package requires MLflow to be installed as an optional dependency.
14
+ MLflow integration is particularly useful for production deployments
15
+ requiring model versioning, experiment tracking, and centralized storage.
16
+ """
17
+
18
+ from .mlflow_storage import MLFlowStorage
19
+ from .mlflow_storage_callback import (
20
+ MLFlowStorageCallback,
21
+ )
22
+
23
+ __all__ = [
24
+ "MLFlowStorage",
25
+ "MLFlowStorageCallback",
26
+ ]