openstef 3.4.77__py3-none-any.whl → 3.4.79__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstef/data_classes/prediction_job.py +4 -0
- openstef/feature_engineering/holiday_features.py +34 -26
- openstef/model/regressors/median.py +5 -3
- openstef/model/serializer.py +23 -7
- openstef/pipeline/create_forecast.py +1 -2
- openstef/pipeline/train_model.py +4 -1
- {openstef-3.4.77.dist-info → openstef-3.4.79.dist-info}/METADATA +1 -1
- {openstef-3.4.77.dist-info → openstef-3.4.79.dist-info}/RECORD +11 -11
- {openstef-3.4.77.dist-info → openstef-3.4.79.dist-info}/WHEEL +0 -0
- {openstef-3.4.77.dist-info → openstef-3.4.79.dist-info}/licenses/LICENSE +0 -0
- {openstef-3.4.77.dist-info → openstef-3.4.79.dist-info}/top_level.txt +0 -0
@@ -140,6 +140,10 @@ class PredictionJobDataClass(BaseModel):
|
|
140
140
|
data_prep_class: Optional[DataPrepDataClass] = Field(
|
141
141
|
None, description="The import string for the custom data prep class"
|
142
142
|
)
|
143
|
+
model_run_id: Optional[str] = Field(
|
144
|
+
None,
|
145
|
+
description="The specific model run number that should be used for the forecast. If not set, the latest model run will be used.",
|
146
|
+
)
|
143
147
|
|
144
148
|
fallback_strategy: Optional[FallbackStrategy] = Field(
|
145
149
|
FallbackStrategy.EXTREME_DAY,
|
@@ -1,8 +1,8 @@
|
|
1
1
|
# SPDX-FileCopyrightText: 2017-2023 Contributors to the OpenSTEF project <korte.termijn.prognoses@alliander.com> # noqa E501>
|
2
2
|
#
|
3
3
|
# SPDX-License-Identifier: MPL-2.0
|
4
|
-
"""This module contains all holiday related features."""
|
5
4
|
from datetime import datetime, timedelta
|
5
|
+
import collections
|
6
6
|
|
7
7
|
import holidays
|
8
8
|
import numpy as np
|
@@ -26,7 +26,6 @@ def generate_holiday_feature_functions(
|
|
26
26
|
2022-12-24 - 2023-01-08 is the 'Kerstvakantie'
|
27
27
|
2022-10-15 - 2022-10-23 is the 'HerfstvakantieNoord'
|
28
28
|
|
29
|
-
|
30
29
|
The holidays are based on a manually generated csv file.
|
31
30
|
The information is collected using:
|
32
31
|
https://www.schoolvakanties-nederland.nl/ and the python holiday function
|
@@ -44,7 +43,6 @@ def generate_holiday_feature_functions(
|
|
44
43
|
- Pinksteren
|
45
44
|
- Kerst
|
46
45
|
|
47
|
-
|
48
46
|
The 'Brugdagen' are updated until Dec 2020. (Generated using agenda)
|
49
47
|
|
50
48
|
Args:
|
@@ -83,23 +81,34 @@ def generate_holiday_feature_functions(
|
|
83
81
|
)
|
84
82
|
}
|
85
83
|
)
|
84
|
+
|
86
85
|
# Define empty list to keep track of bridgedays
|
87
86
|
bridge_days = []
|
88
|
-
|
87
|
+
|
88
|
+
# Group holiday dates by name
|
89
|
+
holiday_dates_by_name = collections.defaultdict(list)
|
89
90
|
for date, holiday_name in sorted(country_holidays.items()):
|
90
|
-
|
91
|
-
def make_holiday_func(requested_date):
|
92
|
-
return lambda x: np.isin(x.index.date, np.array([requested_date]))
|
91
|
+
holiday_dates_by_name[holiday_name].append(date)
|
93
92
|
|
94
|
-
|
93
|
+
# Create one function per holiday name that checks all dates for that holiday
|
94
|
+
for holiday_name, dates in holiday_dates_by_name.items():
|
95
|
+
# Use a default argument to capture the dates at definition time
|
95
96
|
holiday_functions.update(
|
96
|
-
{
|
97
|
+
{
|
98
|
+
"is_"
|
99
|
+
+ holiday_name.replace(
|
100
|
+
" ", "_"
|
101
|
+
).lower(): lambda x, dates_local=dates: np.isin(
|
102
|
+
x.index.date, np.array(dates_local)
|
103
|
+
)
|
104
|
+
}
|
97
105
|
)
|
98
106
|
|
99
|
-
# Check for bridge
|
100
|
-
|
101
|
-
|
102
|
-
|
107
|
+
# Check for bridge days for each date of this holiday
|
108
|
+
for date in dates:
|
109
|
+
holiday_functions, bridge_days = check_for_bridge_day(
|
110
|
+
date, holiday_name, country_code, years, holiday_functions, bridge_days
|
111
|
+
)
|
103
112
|
|
104
113
|
# Add feature function that includes all bridgedays
|
105
114
|
holiday_functions.update(
|
@@ -108,7 +117,7 @@ def generate_holiday_feature_functions(
|
|
108
117
|
|
109
118
|
# Add school holidays if country is NL
|
110
119
|
if country_code == "NL":
|
111
|
-
#
|
120
|
+
# Manually generated csv including all dutch schoolholidays for different regions
|
112
121
|
df_holidays = pd.read_csv(path_to_school_holidays_csv, index_col=None)
|
113
122
|
df_holidays["datum"] = pd.to_datetime(df_holidays.datum).apply(
|
114
123
|
lambda x: x.date()
|
@@ -125,19 +134,17 @@ def generate_holiday_feature_functions(
|
|
125
134
|
|
126
135
|
# Loop over list of holidays names
|
127
136
|
for holiday_name in list(set(df_holidays.name)):
|
128
|
-
#
|
129
|
-
def make_holiday_func(holidayname=holiday_name):
|
130
|
-
return lambda x: np.isin(
|
131
|
-
x.index.date,
|
132
|
-
df_holidays.datum[df_holidays.name == holidayname].values,
|
133
|
-
)
|
134
|
-
|
135
|
-
# Create lag function for each holiday
|
137
|
+
# Use the holidayname as a default argument to capture it at definition time
|
136
138
|
holiday_functions.update(
|
137
139
|
{
|
138
140
|
"is_"
|
139
|
-
+ holiday_name.replace(
|
140
|
-
|
141
|
+
+ holiday_name.replace(
|
142
|
+
" ", "_"
|
143
|
+
).lower(): lambda x, holiday_name_local=holiday_name: np.isin(
|
144
|
+
x.index.date,
|
145
|
+
df_holidays.datum[
|
146
|
+
df_holidays.name == holiday_name_local
|
147
|
+
].values,
|
141
148
|
)
|
142
149
|
}
|
143
150
|
)
|
@@ -178,9 +185,10 @@ def check_for_bridge_day(
|
|
178
185
|
if date in country_holidays:
|
179
186
|
return holiday_functions, bridge_days
|
180
187
|
|
181
|
-
# Define function
|
188
|
+
# Define function explicitly to mitigate 'late binding' problem
|
189
|
+
# Use a default argument to capture the date at definition time
|
182
190
|
def make_holiday_func(requested_date):
|
183
|
-
return lambda x: np.isin(x.index.date, np.array([
|
191
|
+
return lambda x, dt=requested_date: np.isin(x.index.date, np.array([dt]))
|
184
192
|
|
185
193
|
# Looking forward: If day after tomorrow is a national holiday or
|
186
194
|
# a saturday check if tomorrow is not a national holiday
|
@@ -304,9 +304,11 @@ class MedianRegressor(OpenstfRegressor, RegressorMixin):
|
|
304
304
|
|
305
305
|
Which lag features are used is determined by the feature engineering step.
|
306
306
|
"""
|
307
|
-
|
308
|
-
|
309
|
-
|
307
|
+
(
|
308
|
+
feature_names,
|
309
|
+
frequency,
|
310
|
+
feature_to_lags_in_min,
|
311
|
+
) = self._extract_and_validate_lags(x)
|
310
312
|
|
311
313
|
self.feature_names_ = list(feature_names)
|
312
314
|
self.frequency_ = frequency
|
openstef/model/serializer.py
CHANGED
@@ -18,6 +18,7 @@ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repo
|
|
18
18
|
from xgboost import XGBModel # Temporary for backward compatibility
|
19
19
|
|
20
20
|
from openstef.data_classes.model_specifications import ModelSpecificationDataClass
|
21
|
+
from openstef.data_classes.prediction_job import PredictionJobDataClass
|
21
22
|
from openstef.logging.logger_factory import get_logger
|
22
23
|
from openstef.metrics.reporter import Report
|
23
24
|
from openstef.model.regressors.regressor import OpenstfRegressor
|
@@ -143,20 +144,30 @@ class MLflowSerializer:
|
|
143
144
|
def load_model(
|
144
145
|
self,
|
145
146
|
experiment_name: str,
|
147
|
+
model_run_id: Optional[str] = None,
|
146
148
|
) -> tuple[OpenstfRegressor, ModelSpecificationDataClass]:
|
147
|
-
"""Load sklearn
|
149
|
+
""" Load an sklearn-compatible model from MLflow.
|
148
150
|
|
151
|
+
This method retrieves a trained model and its specifications from MLflow
|
152
|
+
based on the provided experiment name. It supports loading
|
153
|
+
a specific model run if a run number is provided.
|
154
|
+
|
149
155
|
Args:
|
150
|
-
|
156
|
+
experiment_name (str): Name of the experiment, often the id of the prediction job.
|
157
|
+
model_run_id (Optional[str]): The specific model run number that should be used for the forecast.
|
158
|
+
|
159
|
+
Returns:
|
160
|
+
tuple[OpenstfRegressor, ModelSpecificationDataClass]: A tuple containing
|
161
|
+
the loaded model and its specifications.
|
151
162
|
|
152
|
-
|
153
|
-
|
163
|
+
LookupError: If the model is not found in MLflow or if an error occurs
|
164
|
+
during the loading process.
|
154
165
|
|
155
166
|
"""
|
156
167
|
try:
|
157
168
|
models_df = self._find_models(
|
158
|
-
self.experiment_name_prefix + experiment_name, max_results=1
|
159
|
-
|
169
|
+
self.experiment_name_prefix + experiment_name, max_results=1, model_run_id=model_run_id)
|
170
|
+
# return the latest finished run of the model
|
160
171
|
if not models_df.empty:
|
161
172
|
latest_run = models_df.iloc[0] # Use .iloc[0] to only get latest run
|
162
173
|
else:
|
@@ -172,7 +183,7 @@ class MLflowSerializer:
|
|
172
183
|
) # Path without file:///
|
173
184
|
self.logger.info("Model successfully loaded with MLflow")
|
174
185
|
return loaded_model, model_specs
|
175
|
-
except (AttributeError, MlflowException, OSError) as exception:
|
186
|
+
except (AttributeError, MlflowException, OSError) as exception:
|
176
187
|
raise LookupError("Model not found. First train a model!") from exception
|
177
188
|
|
178
189
|
def get_model_age(
|
@@ -205,8 +216,13 @@ class MLflowSerializer:
|
|
205
216
|
experiment_name: str,
|
206
217
|
max_results: Optional[int] = 100,
|
207
218
|
filter_string: str = "attribute.status = 'FINISHED'",
|
219
|
+
model_run_id: Optional[int] = None,
|
208
220
|
) -> pd.DataFrame:
|
209
221
|
"""Finds trained models for specific experiment_name sorted by age in descending order."""
|
222
|
+
|
223
|
+
if model_run_id is not None:
|
224
|
+
filter_string += f" AND attributes.run_id = '{model_run_id}'"
|
225
|
+
|
210
226
|
models_df = mlflow.search_runs(
|
211
227
|
experiment_names=[experiment_name],
|
212
228
|
max_results=max_results,
|
@@ -52,11 +52,10 @@ def create_forecast_pipeline(
|
|
52
52
|
# Use the alternative forecast model if it's specified in the pj
|
53
53
|
if pj.alternative_forecast_model_pid:
|
54
54
|
prediction_model_pid = pj.alternative_forecast_model_pid
|
55
|
-
|
56
55
|
# Load most recent model for the given pid
|
57
56
|
model, model_specs = MLflowSerializer(
|
58
57
|
mlflow_tracking_uri=mlflow_tracking_uri
|
59
|
-
).load_model(experiment_name=str(prediction_model_pid))
|
58
|
+
).load_model(experiment_name=str(prediction_model_pid), model_run_id=pj.get("model_run_id"))
|
60
59
|
return create_forecast_pipeline_core(pj, input_data, model, model_specs)
|
61
60
|
|
62
61
|
|
openstef/pipeline/train_model.py
CHANGED
@@ -52,6 +52,7 @@ def train_model_pipeline(
|
|
52
52
|
check_old_model_age: Check if training should be skipped because the model is too young
|
53
53
|
mlflow_tracking_uri: Tracking URI for MLFlow
|
54
54
|
artifact_folder: Path where artifacts, such as trained models, are stored
|
55
|
+
ignore_existing_models: If True, a new model is trained as if no old model exists.
|
55
56
|
|
56
57
|
Returns:
|
57
58
|
If pj.save_train_forecasts is False, None is returned
|
@@ -168,6 +169,7 @@ def train_model_pipeline_core(
|
|
168
169
|
input_data: Input data
|
169
170
|
old_model: Old model to compare to. Defaults to None.
|
170
171
|
horizons: Horizons to train on in hours, relevant for feature engineering.
|
172
|
+
ignore_existing_models: If True, all existing models, including hyperparameters, are ignored and default values are used.
|
171
173
|
|
172
174
|
Raises:
|
173
175
|
InputDataInsufficientError: when input data is insufficient.
|
@@ -319,8 +321,9 @@ def train_pipeline_step_load_model(
|
|
319
321
|
old_model: Optional[OpenstfRegressor]
|
320
322
|
|
321
323
|
if not ignore_existing_models:
|
324
|
+
model_run_id = pj.get("model_run_id", None)
|
322
325
|
try:
|
323
|
-
old_model, model_specs = serializer.load_model(
|
326
|
+
old_model, model_specs = serializer.load_model(str(pj.id), model_run_id=model_run_id)
|
324
327
|
old_model_age = old_model.age # Age attribute is openstef specific
|
325
328
|
return old_model, model_specs, old_model_age
|
326
329
|
except (AttributeError, FileNotFoundError, LookupError):
|
@@ -17,7 +17,7 @@ openstef/data/dazls_model_3.4.24/dazls_stored_3.4.24_model_card.md.license,sha25
|
|
17
17
|
openstef/data_classes/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
18
18
|
openstef/data_classes/data_prep.py,sha256=sANgFjfwmSWhLCfmLjfqXQnczuvVZfk2765jZd7LwuE,3691
|
19
19
|
openstef/data_classes/model_specifications.py,sha256=PZeBLfH_MrP9-QorL1r0Hklp0befE8Nw05vNhTX9Y20,1338
|
20
|
-
openstef/data_classes/prediction_job.py,sha256=
|
20
|
+
openstef/data_classes/prediction_job.py,sha256=794joix2ynvCYvm-MbiA5eagT46CArr_n_K5UrVoFBs,7166
|
21
21
|
openstef/data_classes/split_function.py,sha256=K8y1dsQC5exeIDh37f7UwJ11tV71_uVSNbnKmwXpnOM,3435
|
22
22
|
openstef/feature_engineering/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
23
23
|
openstef/feature_engineering/apply_features.py,sha256=pro4eUmOFexX_9g9kJtDcbrQ1hWKzXjVpiJBmmBi89o,5326
|
@@ -27,7 +27,7 @@ openstef/feature_engineering/data_preparation.py,sha256=TXAPTtSmBRC_LZP7o5Jlmj7J
|
|
27
27
|
openstef/feature_engineering/feature_adder.py,sha256=aSqDl_gUrB3H2TD3cNvU5JniY_KOb4u4a2A6J7zB2BQ,6835
|
28
28
|
openstef/feature_engineering/feature_applicator.py,sha256=bU1Pu5V1fxMCQCwh6HG66nmctBjrNa7gHUYqOqPmLTU,7501
|
29
29
|
openstef/feature_engineering/general.py,sha256=PdvnDqkze31FggUuWHQ1ysroh_uDOa1hZ7NftMYH2_U,4130
|
30
|
-
openstef/feature_engineering/holiday_features.py,sha256=
|
30
|
+
openstef/feature_engineering/holiday_features.py,sha256=g3VBj9oU3wmp82iKcknX41S_7Z4tGIjlvgbZOcFqQaw,8572
|
31
31
|
openstef/feature_engineering/lag_features.py,sha256=Dr6qS8UhdgEHPZZSe-w6ibtjl_lcbcQohhqdZN9fqEU,5652
|
32
32
|
openstef/feature_engineering/missing_values_transformer.py,sha256=U8pdA61k8CRosO3yR2IsCy5C4Ka3c8BWCimDLIB4LCQ,5010
|
33
33
|
openstef/feature_engineering/rolling_features.py,sha256=V-UulqWKuSksFQAASyVSQim1stEA4TmtHNULCrrdgjo,2160
|
@@ -49,7 +49,7 @@ openstef/model/fallback.py,sha256=x60GVyl1c5DpebzkjJEMToZpMTD1c4FrhM-tBN9uizk,31
|
|
49
49
|
openstef/model/model_creator.py,sha256=fnhcVGUHskbuAys5kjlJ4GXKxbi9Eq5eAA19ex11Vv0,6658
|
50
50
|
openstef/model/objective.py,sha256=0PZUbPzuyaYlpWEH_qPavO6ll7zwqTTUTfIrUzzFMbs,15585
|
51
51
|
openstef/model/objective_creator.py,sha256=3jJgcmY1sm-Yoe3SfjKrJukrsqtYyloUFaPbBWqswhQ,2208
|
52
|
-
openstef/model/serializer.py,sha256=
|
52
|
+
openstef/model/serializer.py,sha256=8vESYq2TmtEzEViBR7qbJ3rjm68LZkbiET2cUPGvFMs,17925
|
53
53
|
openstef/model/standard_deviation_generator.py,sha256=OorRvX2wRScU7f4SIBoiT24yJeeM50sETP3xC6m5IG4,2865
|
54
54
|
openstef/model/metamodels/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
55
55
|
openstef/model/metamodels/feature_clipper.py,sha256=DNsyYdjUT7ZNimJJIyTvv1nmwTwDUk5fX9EDgV9FbUQ,2862
|
@@ -64,7 +64,7 @@ openstef/model/regressors/gblinear_quantile.py,sha256=PKQL_TAXa3Kw9oZrKC6Uvo_n2N
|
|
64
64
|
openstef/model/regressors/lgbm.py,sha256=zCdn1euEdSFxYJzH8XqQFFnb6R4JVUnmineKjX_Gy-g,800
|
65
65
|
openstef/model/regressors/linear.py,sha256=uOvZMLGZH_9nXfmS5honCMfyVeyGXP1Cza9A_BdXlVw,3665
|
66
66
|
openstef/model/regressors/linear_quantile.py,sha256=zIpGo9deMeTZdwFWoZ3FstX74mYdlAhfg-YOsPRFl0k,10534
|
67
|
-
openstef/model/regressors/median.py,sha256=
|
67
|
+
openstef/model/regressors/median.py,sha256=f_yZWuJXAUbGbHAIMqpIAFSaUi0GnEe55DgFWGo7S5U,14157
|
68
68
|
openstef/model/regressors/regressor.py,sha256=0um575rTEkzYb1E5IAOuTlsZDhmb7eI5byu5e062NRs,3469
|
69
69
|
openstef/model/regressors/xgb.py,sha256=uhV9Wm90aOkjByTm-O2xpt2kpANRxAqQvv5mA0H1uBc,1294
|
70
70
|
openstef/model/regressors/xgb_multioutput_quantile.py,sha256=xWzA7tymC_o-F1OS3I7vUKf9zP6RR1ZglEeY4NAgjU0,9146
|
@@ -77,10 +77,10 @@ openstef/monitoring/teams.py,sha256=klN7Ge-0VktJbZ_I-K8MJIc3LWgdNy0MGL8b2TdoUR8,
|
|
77
77
|
openstef/pipeline/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
78
78
|
openstef/pipeline/create_basecase_forecast.py,sha256=7IShIjEmjkzpNzWzQVKmYQvy0q_uwCGO-E0mSRmGdhw,4397
|
79
79
|
openstef/pipeline/create_component_forecast.py,sha256=40fYKajdj4F9K7fzmL3euyvwTr0v-oO_5cXpya87A0c,5839
|
80
|
-
openstef/pipeline/create_forecast.py,sha256=
|
80
|
+
openstef/pipeline/create_forecast.py,sha256=z18MrnMW6f85mLjH9XKLniuCQ9oziWCqfgA5YdEgROM,5676
|
81
81
|
openstef/pipeline/optimize_hyperparameters.py,sha256=w5LpZhW3KVklCJzaogNzyHfpMJfNqeRAnvyV4vi35wg,10953
|
82
82
|
openstef/pipeline/train_create_forecast_backtest.py,sha256=hBJPxfDkbrmFSSGZrRH1vTiIVqJP-SWe0ibVpHT_8Qg,6048
|
83
|
-
openstef/pipeline/train_model.py,sha256=
|
83
|
+
openstef/pipeline/train_model.py,sha256=4mtNXosLxxLNDtyIBd58youAHx5zWIW7PoSeZdtDoXY,20234
|
84
84
|
openstef/pipeline/utils.py,sha256=23mB31p19FoGWelLJzxNmqlzGwEr3fCDBEA37V2kpYY,2167
|
85
85
|
openstef/plotting/__init__.py,sha256=KQjXzyafCt1bE7XDrSeV4TDUIO7MkwN_Br4ASOcNI2g,163
|
86
86
|
openstef/plotting/load_forecast_plotter.py,sha256=GWHVmUB2YosNj7TnSrMnxYAfM2Z1mNg5oRV9A_lJmQY,8129
|
@@ -104,8 +104,8 @@ openstef/tasks/utils/predictionjobloop.py,sha256=Ysy3zF5lzPMz_asYDKeF5m0qgVT3tCt
|
|
104
104
|
openstef/tasks/utils/taskcontext.py,sha256=O-LZ_wHEl5vbT8oB7EYtOeMkvk6EqCnI1-KiyER7Eu4,5407
|
105
105
|
openstef/validation/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
106
106
|
openstef/validation/validation.py,sha256=r6UqkdH5TMjsGfn8Ta07K1jkqmrVmwcPGfyQvMmZyO4,11459
|
107
|
-
openstef-3.4.
|
108
|
-
openstef-3.4.
|
109
|
-
openstef-3.4.
|
110
|
-
openstef-3.4.
|
111
|
-
openstef-3.4.
|
107
|
+
openstef-3.4.79.dist-info/licenses/LICENSE,sha256=7Pm2fWFFHHUG5lDHed1vl5CjzxObIXQglnYsEdtjo_k,14907
|
108
|
+
openstef-3.4.79.dist-info/METADATA,sha256=zfFVPR_RhCyKZ50LSCxuA46CI8L8d2tIJH02ryc9bUk,8834
|
109
|
+
openstef-3.4.79.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
110
|
+
openstef-3.4.79.dist-info/top_level.txt,sha256=kD0H4PqrQoncZ957FvqwfBxa89kTrun4Z_RAPs_HhLs,9
|
111
|
+
openstef-3.4.79.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|