openstef 3.2.74__py3-none-any.whl → 3.2.76__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstef/pipeline/optimize_hyperparameters.py +3 -6
- openstef/pipeline/train_model.py +1 -1
- openstef/tasks/train_model.py +25 -2
- {openstef-3.2.74.dist-info → openstef-3.2.76.dist-info}/METADATA +1 -1
- {openstef-3.2.74.dist-info → openstef-3.2.76.dist-info}/RECORD +8 -8
- {openstef-3.2.74.dist-info → openstef-3.2.76.dist-info}/WHEEL +1 -1
- {openstef-3.2.74.dist-info → openstef-3.2.76.dist-info}/LICENSE +0 -0
- {openstef-3.2.74.dist-info → openstef-3.2.76.dist-info}/top_level.txt +0 -0
@@ -188,7 +188,7 @@ def optimize_hyperparameters_pipeline_core(
|
|
188
188
|
)
|
189
189
|
|
190
190
|
best_hyperparams = study.best_params
|
191
|
-
best_model
|
191
|
+
# The best_model could be accessed via study.user_attrs["best_model"]
|
192
192
|
|
193
193
|
logger.info(
|
194
194
|
f"Finished hyperparameter optimization, error objective {study.best_value} "
|
@@ -209,17 +209,14 @@ def optimize_hyperparameters_pipeline_core(
|
|
209
209
|
# Train a model using the regular train pipeline.
|
210
210
|
# The train/validation/test split used in hyperparam optimisation
|
211
211
|
# is less suitable for an operational model.
|
212
|
-
|
212
|
+
model, report, model_specs, _ = train_model_pipeline_core(
|
213
213
|
pj=pj, input_data=input_data, model_specs=model_specs
|
214
214
|
)
|
215
215
|
|
216
|
-
# Save model and report. Report is always saved to MLFlow and optionally to disk
|
217
|
-
report = objective.create_report(model=best_model)
|
218
|
-
|
219
216
|
trials = objective.get_trial_track()
|
220
217
|
best_trial_number = study.best_trial.number
|
221
218
|
|
222
|
-
return
|
219
|
+
return model, model_specs, report, trials, best_trial_number, study.best_params
|
223
220
|
|
224
221
|
|
225
222
|
def optuna_optimization(
|
openstef/pipeline/train_model.py
CHANGED
openstef/tasks/train_model.py
CHANGED
@@ -26,10 +26,16 @@ from openstef.data_classes.prediction_job import PredictionJobDataClass
|
|
26
26
|
|
27
27
|
from openstef.enums import MLModelType, PipelineType
|
28
28
|
from openstef.exceptions import SkipSaveTrainingForecasts
|
29
|
-
from openstef.pipeline.train_model import
|
29
|
+
from openstef.pipeline.train_model import (
|
30
|
+
train_model_pipeline,
|
31
|
+
train_pipeline_step_load_model,
|
32
|
+
MAXIMUM_MODEL_AGE,
|
33
|
+
)
|
30
34
|
from openstef.tasks.utils.predictionjobloop import PredictionJobLoop
|
31
35
|
from openstef.tasks.utils.taskcontext import TaskContext
|
32
36
|
|
37
|
+
from openstef.model.serializer import MLflowSerializer
|
38
|
+
|
33
39
|
TRAINING_PERIOD_DAYS: int = 120
|
34
40
|
DEFAULT_CHECK_MODEL_AGE: bool = True
|
35
41
|
|
@@ -84,13 +90,30 @@ def train_model_task(
|
|
84
90
|
|
85
91
|
context.perf_meter.checkpoint("Added metadata to PredictionJob")
|
86
92
|
|
93
|
+
# Check the model age before retrieving the input data to speed up train job.
|
94
|
+
# (The exact same model age check is also part of the "train_model_pipeline".)
|
95
|
+
|
96
|
+
# Initialize serializer
|
97
|
+
serializer = MLflowSerializer(mlflow_tracking_uri=mlflow_tracking_uri)
|
98
|
+
|
99
|
+
# Get old model and age
|
100
|
+
_, _, old_model_age = train_pipeline_step_load_model(pj, serializer)
|
101
|
+
|
102
|
+
# Check old model age and continue yes/no
|
103
|
+
if (old_model_age < MAXIMUM_MODEL_AGE) and check_old_model_age:
|
104
|
+
context.perf_meter.checkpoint(
|
105
|
+
f"Old model is younger than {MAXIMUM_MODEL_AGE} days, skip training"
|
106
|
+
)
|
107
|
+
if pj.save_train_forecasts:
|
108
|
+
raise SkipSaveTrainingForecasts
|
109
|
+
return
|
110
|
+
|
87
111
|
# Define start and end of the training input data
|
88
112
|
if datetime_end is None:
|
89
113
|
datetime_end = datetime.utcnow()
|
90
114
|
if datetime_start is None:
|
91
115
|
datetime_start = datetime_end - timedelta(days=TRAINING_PERIOD_DAYS)
|
92
116
|
|
93
|
-
# todo: See if we can check model age before getting the data
|
94
117
|
# Get training input data from database
|
95
118
|
input_data = context.database.get_model_input(
|
96
119
|
pid=pj["id"],
|
@@ -60,9 +60,9 @@ openstef/pipeline/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU
|
|
60
60
|
openstef/pipeline/create_basecase_forecast.py,sha256=BPxf2MSvJyfbNCQGCr1Rol5ShqCUVBSDs_0Fjei-eOI,4284
|
61
61
|
openstef/pipeline/create_component_forecast.py,sha256=HgByae6ruVhy6TuGIJEuPyLyx7g4zSvJfk6Dynlqjl4,5030
|
62
62
|
openstef/pipeline/create_forecast.py,sha256=2vK2cH_VeRcoDWPXR06zFmwQ043FPA9uPvg5_OyxUfU,5008
|
63
|
-
openstef/pipeline/optimize_hyperparameters.py,sha256=
|
63
|
+
openstef/pipeline/optimize_hyperparameters.py,sha256=KL80enVMUAVEUwGhVygxWh3BluoUVpnBlDgbBz7iseY,10700
|
64
64
|
openstef/pipeline/train_create_forecast_backtest.py,sha256=upuoiE01vjjxUu_sY0tANPqdOtpGKrQQ3azhVDnBJdc,5512
|
65
|
-
openstef/pipeline/train_model.py,sha256=
|
65
|
+
openstef/pipeline/train_model.py,sha256=SzKZSKT5diajR2L8eB1JQeasMaaeBll4H21N49wm_r4,18556
|
66
66
|
openstef/pipeline/utils.py,sha256=fkc-oNirJ-JiyuOAL08RFrnPYPwudWal_N-BO6Cw980,2086
|
67
67
|
openstef/postprocessing/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
68
68
|
openstef/postprocessing/postprocessing.py,sha256=nehd0tDpkdIaWFJggQ-fDizIKdfmqJ3IOGfk0sDnrzk,8409
|
@@ -78,15 +78,15 @@ openstef/tasks/create_wind_forecast.py,sha256=RhshkmNSyFWx4Y6yQn02GzHjWTREbN5A5G
|
|
78
78
|
openstef/tasks/optimize_hyperparameters.py,sha256=s-z8YQJF6Lf3DdYgKHEpAdlbFJ3a-0Gj0Ahsqj1DErc,4758
|
79
79
|
openstef/tasks/run_tracy.py,sha256=sU1Aw6litLHw9XT2uqjtbrGUCaD6XRN9asUqtWJjkCg,5037
|
80
80
|
openstef/tasks/split_forecast.py,sha256=ilIwmUAEBZz8ksquLLiAxk4IiDqbg4oxPs-_ftrKRm8,9118
|
81
|
-
openstef/tasks/train_model.py,sha256=
|
81
|
+
openstef/tasks/train_model.py,sha256=BGRimvLN7AjUx97dcLrzjZGDR8q4V-I6KgP2gb7FWN0,6452
|
82
82
|
openstef/tasks/utils/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
83
83
|
openstef/tasks/utils/dependencies.py,sha256=Jy9dtV_G7lTEa5Cdy--wvMxJuAb0adb3R0X4QDjVteM,3077
|
84
84
|
openstef/tasks/utils/predictionjobloop.py,sha256=u4WQjvqBM6z9T7VFUZ-9JqgdepNJO0ZSr3DURMBus9E,9581
|
85
85
|
openstef/tasks/utils/taskcontext.py,sha256=yI6TntOkZcW8JiNVuw4uJIigEBL0_iIrkPklF4ZeCX4,5401
|
86
86
|
openstef/validation/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
87
87
|
openstef/validation/validation.py,sha256=AYQJBXwbFhpq34bqEhybw0lTIJ8Td4vr2-AbWxGxm3M,16917
|
88
|
-
openstef-3.2.
|
89
|
-
openstef-3.2.
|
90
|
-
openstef-3.2.
|
91
|
-
openstef-3.2.
|
92
|
-
openstef-3.2.
|
88
|
+
openstef-3.2.76.dist-info/LICENSE,sha256=7Pm2fWFFHHUG5lDHed1vl5CjzxObIXQglnYsEdtjo_k,14907
|
89
|
+
openstef-3.2.76.dist-info/METADATA,sha256=AUF-4s37eesqmP6APrVgT_5z61lbrR--MDCDE-YxhwI,6934
|
90
|
+
openstef-3.2.76.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
|
91
|
+
openstef-3.2.76.dist-info/top_level.txt,sha256=kD0H4PqrQoncZ957FvqwfBxa89kTrun4Z_RAPs_HhLs,9
|
92
|
+
openstef-3.2.76.dist-info/RECORD,,
|
File without changes
|
File without changes
|