openstef-3.4.65-py3-none-any.whl → openstef-3.4.66-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstef/pipeline/train_model.py +24 -14
- openstef/tasks/train_model.py +11 -4
- {openstef-3.4.65.dist-info → openstef-3.4.66.dist-info}/METADATA +1 -1
- {openstef-3.4.65.dist-info → openstef-3.4.66.dist-info}/RECORD +7 -7
- {openstef-3.4.65.dist-info → openstef-3.4.66.dist-info}/LICENSE +0 -0
- {openstef-3.4.65.dist-info → openstef-3.4.66.dist-info}/WHEEL +0 -0
- {openstef-3.4.65.dist-info → openstef-3.4.66.dist-info}/top_level.txt +0 -0
openstef/pipeline/train_model.py
CHANGED
@@ -3,7 +3,7 @@
|
|
3
3
|
# SPDX-License-Identifier: MPL-2.0
|
4
4
|
import logging
|
5
5
|
import os
|
6
|
-
from typing import Optional,
|
6
|
+
from typing import Optional, Tuple, Union
|
7
7
|
|
8
8
|
import pandas as pd
|
9
9
|
import structlog
|
@@ -46,6 +46,7 @@ def train_model_pipeline(
|
|
46
46
|
check_old_model_age: bool,
|
47
47
|
mlflow_tracking_uri: str,
|
48
48
|
artifact_folder: str,
|
49
|
+
ignore_existing_models: bool = False,
|
49
50
|
) -> Optional[tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]]:
|
50
51
|
"""Middle level pipeline that takes care of all persistent storage dependencies.
|
51
52
|
|
@@ -79,7 +80,7 @@ def train_model_pipeline(
|
|
79
80
|
|
80
81
|
# Get old model and age
|
81
82
|
old_model, model_specs, old_model_age = train_pipeline_step_load_model(
|
82
|
-
pj, serializer
|
83
|
+
pj, serializer, ignore_existing_models
|
83
84
|
)
|
84
85
|
|
85
86
|
# Check old model age and continue yes/no
|
@@ -106,6 +107,7 @@ def train_model_pipeline(
|
|
106
107
|
input_data,
|
107
108
|
old_model,
|
108
109
|
horizons=horizons,
|
110
|
+
ignore_existing_models=ignore_existing_models,
|
109
111
|
)
|
110
112
|
except OldModelHigherScoreError as OMHSE:
|
111
113
|
logger.error("Old model is better than new model", pid=pj["id"], exc_info=OMHSE)
|
@@ -155,6 +157,7 @@ def train_model_pipeline_core(
|
|
155
157
|
input_data: pd.DataFrame,
|
156
158
|
old_model: OpenstfRegressor = None,
|
157
159
|
horizons: list[float] = DEFAULT_TRAIN_HORIZONS_HOURS,
|
160
|
+
ignore_existing_models: bool = False,
|
158
161
|
) -> Tuple[
|
159
162
|
OpenstfRegressor,
|
160
163
|
Report,
|
@@ -203,7 +206,7 @@ def train_model_pipeline_core(
|
|
203
206
|
model_specs.feature_names = list(train_data.columns)
|
204
207
|
|
205
208
|
# Check if new model is better than old model
|
206
|
-
if old_model:
|
209
|
+
if old_model and not ignore_existing_models:
|
207
210
|
combined = pd.concat([train_data, validation_data])
|
208
211
|
# skip the forecast column added at the end of dataframes
|
209
212
|
if pj.save_train_forecasts:
|
@@ -220,6 +223,7 @@ def train_model_pipeline_core(
|
|
220
223
|
# Try to compare new model to old model.
|
221
224
|
# If this does not success, for example since the feature names of the
|
222
225
|
# old model differ from the new model, the new model is considered better
|
226
|
+
|
223
227
|
try:
|
224
228
|
score_old_model = old_model.score(x_data, y_data)
|
225
229
|
|
@@ -315,25 +319,31 @@ def train_pipeline_common(
|
|
315
319
|
|
316
320
|
|
317
321
|
def train_pipeline_step_load_model(
|
318
|
-
pj: PredictionJobDataClass,
|
322
|
+
pj: PredictionJobDataClass,
|
323
|
+
serializer: MLflowSerializer,
|
324
|
+
ignore_existing_models: bool = False,
|
319
325
|
) -> Tuple[OpenstfRegressor, ModelSpecificationDataClass, Union[int, float]]:
|
320
326
|
old_model: Optional[OpenstfRegressor]
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
327
|
+
|
328
|
+
if not ignore_existing_models:
|
329
|
+
try:
|
330
|
+
old_model, model_specs = serializer.load_model(experiment_name=str(pj.id))
|
331
|
+
old_model_age = old_model.age # Age attribute is openstef specific
|
332
|
+
return old_model, model_specs, old_model_age
|
333
|
+
except (AttributeError, FileNotFoundError, LookupError):
|
334
|
+
logger.warning("No old model found, training new model", pid=pj.id)
|
335
|
+
except Exception:
|
336
|
+
logger.exception(
|
337
|
+
"Old model could not be loaded, training new model", pid=pj.id
|
338
|
+
)
|
339
|
+
|
329
340
|
old_model = None
|
330
341
|
old_model_age = float("inf")
|
331
342
|
if pj["default_modelspecs"] is not None:
|
332
343
|
model_specs = pj["default_modelspecs"]
|
333
344
|
if model_specs.id != pj.id:
|
334
345
|
raise RuntimeError(
|
335
|
-
"The id of the prediction job and its default model_specs do not"
|
336
|
-
" match."
|
346
|
+
"The id of the prediction job and its default model_specs do not match."
|
337
347
|
)
|
338
348
|
else:
|
339
349
|
# create basic model_specs
|
openstef/tasks/train_model.py
CHANGED
@@ -19,8 +19,10 @@ Example:
|
|
19
19
|
$ python model_train.py
|
20
20
|
|
21
21
|
"""
|
22
|
-
|
22
|
+
|
23
|
+
from datetime import UTC, datetime, timedelta
|
23
24
|
from pathlib import Path
|
25
|
+
from typing import Optional
|
24
26
|
|
25
27
|
import pandas as pd
|
26
28
|
|
@@ -41,14 +43,16 @@ from openstef.tasks.utils.taskcontext import TaskContext
|
|
41
43
|
|
42
44
|
TRAINING_PERIOD_DAYS: int = 120
|
43
45
|
DEFAULT_CHECK_MODEL_AGE: bool = True
|
46
|
+
DEFAULT_IGNORE_EXISTING_MODELS: bool = False
|
44
47
|
|
45
48
|
|
46
49
|
def train_model_task(
|
47
50
|
pj: PredictionJobDataClass,
|
48
51
|
context: TaskContext,
|
49
52
|
check_old_model_age: bool = DEFAULT_CHECK_MODEL_AGE,
|
50
|
-
datetime_start: datetime = None,
|
51
|
-
datetime_end: datetime = None,
|
53
|
+
datetime_start: Optional[datetime] = None,
|
54
|
+
datetime_end: Optional[datetime] = None,
|
55
|
+
ignore_existing_models: bool = DEFAULT_IGNORE_EXISTING_MODELS,
|
52
56
|
) -> None:
|
53
57
|
"""Train model task.
|
54
58
|
|
@@ -104,7 +108,9 @@ def train_model_task(
|
|
104
108
|
serializer = MLflowSerializer(mlflow_tracking_uri=mlflow_tracking_uri)
|
105
109
|
|
106
110
|
# Get old model and age
|
107
|
-
_, _, old_model_age = train_pipeline_step_load_model(
|
111
|
+
_, _, old_model_age = train_pipeline_step_load_model(
|
112
|
+
pj, serializer, ignore_existing_models
|
113
|
+
)
|
108
114
|
|
109
115
|
# Check old model age and continue yes/no
|
110
116
|
if (old_model_age < MAXIMUM_MODEL_AGE) and check_old_model_age:
|
@@ -168,6 +174,7 @@ def train_model_task(
|
|
168
174
|
check_old_model_age=check_old_model_age,
|
169
175
|
mlflow_tracking_uri=mlflow_tracking_uri,
|
170
176
|
artifact_folder=artifact_folder,
|
177
|
+
ignore_existing_models=ignore_existing_models,
|
171
178
|
)
|
172
179
|
|
173
180
|
if data_sets:
|
@@ -73,7 +73,7 @@ openstef/pipeline/create_component_forecast.py,sha256=U2v_R-FSOXWVbWeknsJbkulN1Y
|
|
73
73
|
openstef/pipeline/create_forecast.py,sha256=uvp5mQqGSOx-ANY-9o5reiBYNNby0npm-0lt4w9EQ18,5763
|
74
74
|
openstef/pipeline/optimize_hyperparameters.py,sha256=uwXkzRA_fTSFt0yBuvvEoY5-4dMv42FPdS4hZocL-N8,11114
|
75
75
|
openstef/pipeline/train_create_forecast_backtest.py,sha256=hBJPxfDkbrmFSSGZrRH1vTiIVqJP-SWe0ibVpHT_8Qg,6048
|
76
|
-
openstef/pipeline/train_model.py,sha256=
|
76
|
+
openstef/pipeline/train_model.py,sha256=8tqJcfqjT9gsXoOSBJxf3i-N_3BPmxbUqt_Ygd7Oao0,20134
|
77
77
|
openstef/pipeline/utils.py,sha256=23mB31p19FoGWelLJzxNmqlzGwEr3fCDBEA37V2kpYY,2167
|
78
78
|
openstef/plotting/__init__.py,sha256=KQjXzyafCt1bE7XDrSeV4TDUIO7MkwN_Br4ASOcNI2g,163
|
79
79
|
openstef/plotting/load_forecast_plotter.py,sha256=n-dB2dQnqjWCvV3kBjnOZYQ03J-9jSIHVovJy3nGSnQ,8129
|
@@ -90,15 +90,15 @@ openstef/tasks/create_solar_forecast.py,sha256=HDrJrvTPCM8GS7EQwNr9uJNamf-nH2pu0
|
|
90
90
|
openstef/tasks/create_wind_forecast.py,sha256=RhshkmNSyFWx4Y6yQn02GzHjWTREbN5A5GAeWv0JpcE,2907
|
91
91
|
openstef/tasks/optimize_hyperparameters.py,sha256=3NT0KFgim8wAzWPJ0S-GULM3zoshyj63Ivp-g1_oPDw,4765
|
92
92
|
openstef/tasks/split_forecast.py,sha256=X1D3MnnMdAb9wzDWubAJwfMkWpNGdRUPDvPAbJApNhg,9277
|
93
|
-
openstef/tasks/train_model.py,sha256
|
93
|
+
openstef/tasks/train_model.py,sha256=-d1VewDAaZV2B_JAnwl02Y3hONq7cPZrpH6X87_IOKA,8772
|
94
94
|
openstef/tasks/utils/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
95
95
|
openstef/tasks/utils/dependencies.py,sha256=Jy9dtV_G7lTEa5Cdy--wvMxJuAb0adb3R0X4QDjVteM,3077
|
96
96
|
openstef/tasks/utils/predictionjobloop.py,sha256=Ysy3zF5lzPMz_asYDKeF5m0qgVT3tCtwSPihqMjnI5Q,9580
|
97
97
|
openstef/tasks/utils/taskcontext.py,sha256=L9K14ycwgVxbIVUjH2DIn_QWbnu-OfxcGtQ1K9T6sus,5630
|
98
98
|
openstef/validation/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
99
99
|
openstef/validation/validation.py,sha256=24GEzLyjVqaE2a-MppbFS-YQT5n739BxD7fH3LK5LEE,12133
|
100
|
-
openstef-3.4.
|
101
|
-
openstef-3.4.
|
102
|
-
openstef-3.4.
|
103
|
-
openstef-3.4.
|
104
|
-
openstef-3.4.
|
100
|
+
openstef-3.4.66.dist-info/LICENSE,sha256=7Pm2fWFFHHUG5lDHed1vl5CjzxObIXQglnYsEdtjo_k,14907
|
101
|
+
openstef-3.4.66.dist-info/METADATA,sha256=L8J4MBiz55-LU8iettxkpAP4Nj5UF4kR7wi8WBlFvtY,8816
|
102
|
+
openstef-3.4.66.dist-info/WHEEL,sha256=52BFRY2Up02UkjOa29eZOS2VxUrpPORXg1pkohGGUS8,91
|
103
|
+
openstef-3.4.66.dist-info/top_level.txt,sha256=kD0H4PqrQoncZ957FvqwfBxa89kTrun4Z_RAPs_HhLs,9
|
104
|
+
openstef-3.4.66.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|