openstef 3.4.77__py3-none-any.whl → 3.4.78__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstef/data_classes/prediction_job.py +4 -0
- openstef/model/serializer.py +23 -7
- openstef/pipeline/create_forecast.py +1 -2
- openstef/pipeline/train_model.py +4 -1
- {openstef-3.4.77.dist-info → openstef-3.4.78.dist-info}/METADATA +1 -1
- {openstef-3.4.77.dist-info → openstef-3.4.78.dist-info}/RECORD +9 -9
- {openstef-3.4.77.dist-info → openstef-3.4.78.dist-info}/WHEEL +0 -0
- {openstef-3.4.77.dist-info → openstef-3.4.78.dist-info}/licenses/LICENSE +0 -0
- {openstef-3.4.77.dist-info → openstef-3.4.78.dist-info}/top_level.txt +0 -0
@@ -140,6 +140,10 @@ class PredictionJobDataClass(BaseModel):
|
|
140
140
|
data_prep_class: Optional[DataPrepDataClass] = Field(
|
141
141
|
None, description="The import string for the custom data prep class"
|
142
142
|
)
|
143
|
+
model_run_id: Optional[str] = Field(
|
144
|
+
None,
|
145
|
+
description="The specific model run number that should be used for the forecast. If not set, the latest model run will be used.",
|
146
|
+
)
|
143
147
|
|
144
148
|
fallback_strategy: Optional[FallbackStrategy] = Field(
|
145
149
|
FallbackStrategy.EXTREME_DAY,
|
openstef/model/serializer.py
CHANGED
@@ -18,6 +18,7 @@ from mlflow.store.artifact.artifact_repository_registry import get_artifact_repo
|
|
18
18
|
from xgboost import XGBModel # Temporary for backward compatibility
|
19
19
|
|
20
20
|
from openstef.data_classes.model_specifications import ModelSpecificationDataClass
|
21
|
+
from openstef.data_classes.prediction_job import PredictionJobDataClass
|
21
22
|
from openstef.logging.logger_factory import get_logger
|
22
23
|
from openstef.metrics.reporter import Report
|
23
24
|
from openstef.model.regressors.regressor import OpenstfRegressor
|
@@ -143,20 +144,30 @@ class MLflowSerializer:
|
|
143
144
|
def load_model(
|
144
145
|
self,
|
145
146
|
experiment_name: str,
|
147
|
+
model_run_id: Optional[str] = None,
|
146
148
|
) -> tuple[OpenstfRegressor, ModelSpecificationDataClass]:
|
147
|
-
"""Load sklearn
|
149
|
+
""" Load an sklearn-compatible model from MLflow.
|
148
150
|
|
151
|
+
This method retrieves a trained model and its specifications from MLflow
|
152
|
+
based on the provided PredictionJobDataClass instance. It supports loading
|
153
|
+
a specific model run if a run number is provided.
|
154
|
+
|
149
155
|
Args:
|
150
|
-
|
156
|
+
experiment_name (str): Name of the experiment, often the id of the predition job.
|
157
|
+
model_run_id (Optional[str]): The specific model run number that should be used for the forecast.
|
158
|
+
|
159
|
+
Returns:
|
160
|
+
tuple[OpenstfRegressor, ModelSpecificationDataClass]: A tuple containing
|
161
|
+
the loaded model and its specifications.
|
151
162
|
|
152
|
-
|
153
|
-
|
163
|
+
LookupError: If the model is not found in MLflow or if an error occurs
|
164
|
+
during the loading process.
|
154
165
|
|
155
166
|
"""
|
156
167
|
try:
|
157
168
|
models_df = self._find_models(
|
158
|
-
self.experiment_name_prefix + experiment_name, max_results=1
|
159
|
-
|
169
|
+
self.experiment_name_prefix + experiment_name, max_results=1, model_run_id=model_run_id)
|
170
|
+
# return the latest finished run of the model
|
160
171
|
if not models_df.empty:
|
161
172
|
latest_run = models_df.iloc[0] # Use .iloc[0] to only get latest run
|
162
173
|
else:
|
@@ -172,7 +183,7 @@ class MLflowSerializer:
|
|
172
183
|
) # Path without file:///
|
173
184
|
self.logger.info("Model successfully loaded with MLflow")
|
174
185
|
return loaded_model, model_specs
|
175
|
-
except (AttributeError, MlflowException, OSError) as exception:
|
186
|
+
except (AttributeError, MlflowException, OSError) as exception:
|
176
187
|
raise LookupError("Model not found. First train a model!") from exception
|
177
188
|
|
178
189
|
def get_model_age(
|
@@ -205,8 +216,13 @@ class MLflowSerializer:
|
|
205
216
|
experiment_name: str,
|
206
217
|
max_results: Optional[int] = 100,
|
207
218
|
filter_string: str = "attribute.status = 'FINISHED'",
|
219
|
+
model_run_id: Optional[int] = None,
|
208
220
|
) -> pd.DataFrame:
|
209
221
|
"""Finds trained models for specific experiment_name sorted by age in descending order."""
|
222
|
+
|
223
|
+
if model_run_id is not None:
|
224
|
+
filter_string += f" AND attributes.run_id = '{model_run_id}'"
|
225
|
+
|
210
226
|
models_df = mlflow.search_runs(
|
211
227
|
experiment_names=[experiment_name],
|
212
228
|
max_results=max_results,
|
@@ -52,11 +52,10 @@ def create_forecast_pipeline(
|
|
52
52
|
# Use the alternative forecast model if it's specify in the pj
|
53
53
|
if pj.alternative_forecast_model_pid:
|
54
54
|
prediction_model_pid = pj.alternative_forecast_model_pid
|
55
|
-
|
56
55
|
# Load most recent model for the given pid
|
57
56
|
model, model_specs = MLflowSerializer(
|
58
57
|
mlflow_tracking_uri=mlflow_tracking_uri
|
59
|
-
).load_model(experiment_name=str(prediction_model_pid))
|
58
|
+
).load_model(experiment_name=str(prediction_model_pid), model_run_id=pj.get("model_run_id"))
|
60
59
|
return create_forecast_pipeline_core(pj, input_data, model, model_specs)
|
61
60
|
|
62
61
|
|
openstef/pipeline/train_model.py
CHANGED
@@ -52,6 +52,7 @@ def train_model_pipeline(
|
|
52
52
|
check_old_model_age: Check if training should be skipped because the model is too young
|
53
53
|
mlflow_tracking_uri: Tracking URI for MLFlow
|
54
54
|
artifact_folder: Path where artifacts, such as trained models, are stored
|
55
|
+
ignore_existing_models: If True, a new model is trained as if no old model exists.
|
55
56
|
|
56
57
|
Returns:
|
57
58
|
If pj.save_train_forecasts is False, None is returned
|
@@ -168,6 +169,7 @@ def train_model_pipeline_core(
|
|
168
169
|
input_data: Input data
|
169
170
|
old_model: Old model to compare to. Defaults to None.
|
170
171
|
horizons: Horizons to train on in hours, relevant for feature engineering.
|
172
|
+
ignore_existing_models: If True, all existing models, including, hyperparameters are ignored and defsault values are used.
|
171
173
|
|
172
174
|
Raises:
|
173
175
|
InputDataInsufficientError: when input data is insufficient.
|
@@ -319,8 +321,9 @@ def train_pipeline_step_load_model(
|
|
319
321
|
old_model: Optional[OpenstfRegressor]
|
320
322
|
|
321
323
|
if not ignore_existing_models:
|
324
|
+
model_run_id = pj.get("model_run_id", None)
|
322
325
|
try:
|
323
|
-
old_model, model_specs = serializer.load_model(
|
326
|
+
old_model, model_specs = serializer.load_model(str(pj.id), model_run_id=model_run_id)
|
324
327
|
old_model_age = old_model.age # Age attribute is openstef specific
|
325
328
|
return old_model, model_specs, old_model_age
|
326
329
|
except (AttributeError, FileNotFoundError, LookupError):
|
@@ -17,7 +17,7 @@ openstef/data/dazls_model_3.4.24/dazls_stored_3.4.24_model_card.md.license,sha25
|
|
17
17
|
openstef/data_classes/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
18
18
|
openstef/data_classes/data_prep.py,sha256=sANgFjfwmSWhLCfmLjfqXQnczuvVZfk2765jZd7LwuE,3691
|
19
19
|
openstef/data_classes/model_specifications.py,sha256=PZeBLfH_MrP9-QorL1r0Hklp0befE8Nw05vNhTX9Y20,1338
|
20
|
-
openstef/data_classes/prediction_job.py,sha256=
|
20
|
+
openstef/data_classes/prediction_job.py,sha256=794joix2ynvCYvm-MbiA5eagT46CArr_n_K5UrVoFBs,7166
|
21
21
|
openstef/data_classes/split_function.py,sha256=K8y1dsQC5exeIDh37f7UwJ11tV71_uVSNbnKmwXpnOM,3435
|
22
22
|
openstef/feature_engineering/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
23
23
|
openstef/feature_engineering/apply_features.py,sha256=pro4eUmOFexX_9g9kJtDcbrQ1hWKzXjVpiJBmmBi89o,5326
|
@@ -49,7 +49,7 @@ openstef/model/fallback.py,sha256=x60GVyl1c5DpebzkjJEMToZpMTD1c4FrhM-tBN9uizk,31
|
|
49
49
|
openstef/model/model_creator.py,sha256=fnhcVGUHskbuAys5kjlJ4GXKxbi9Eq5eAA19ex11Vv0,6658
|
50
50
|
openstef/model/objective.py,sha256=0PZUbPzuyaYlpWEH_qPavO6ll7zwqTTUTfIrUzzFMbs,15585
|
51
51
|
openstef/model/objective_creator.py,sha256=3jJgcmY1sm-Yoe3SfjKrJukrsqtYyloUFaPbBWqswhQ,2208
|
52
|
-
openstef/model/serializer.py,sha256=
|
52
|
+
openstef/model/serializer.py,sha256=8vESYq2TmtEzEViBR7qbJ3rjm68LZkbiET2cUPGvFMs,17925
|
53
53
|
openstef/model/standard_deviation_generator.py,sha256=OorRvX2wRScU7f4SIBoiT24yJeeM50sETP3xC6m5IG4,2865
|
54
54
|
openstef/model/metamodels/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
55
55
|
openstef/model/metamodels/feature_clipper.py,sha256=DNsyYdjUT7ZNimJJIyTvv1nmwTwDUk5fX9EDgV9FbUQ,2862
|
@@ -77,10 +77,10 @@ openstef/monitoring/teams.py,sha256=klN7Ge-0VktJbZ_I-K8MJIc3LWgdNy0MGL8b2TdoUR8,
|
|
77
77
|
openstef/pipeline/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
78
78
|
openstef/pipeline/create_basecase_forecast.py,sha256=7IShIjEmjkzpNzWzQVKmYQvy0q_uwCGO-E0mSRmGdhw,4397
|
79
79
|
openstef/pipeline/create_component_forecast.py,sha256=40fYKajdj4F9K7fzmL3euyvwTr0v-oO_5cXpya87A0c,5839
|
80
|
-
openstef/pipeline/create_forecast.py,sha256=
|
80
|
+
openstef/pipeline/create_forecast.py,sha256=z18MrnMW6f85mLjH9XKLniuCQ9oziWCqfgA5YdEgROM,5676
|
81
81
|
openstef/pipeline/optimize_hyperparameters.py,sha256=w5LpZhW3KVklCJzaogNzyHfpMJfNqeRAnvyV4vi35wg,10953
|
82
82
|
openstef/pipeline/train_create_forecast_backtest.py,sha256=hBJPxfDkbrmFSSGZrRH1vTiIVqJP-SWe0ibVpHT_8Qg,6048
|
83
|
-
openstef/pipeline/train_model.py,sha256=
|
83
|
+
openstef/pipeline/train_model.py,sha256=4mtNXosLxxLNDtyIBd58youAHx5zWIW7PoSeZdtDoXY,20234
|
84
84
|
openstef/pipeline/utils.py,sha256=23mB31p19FoGWelLJzxNmqlzGwEr3fCDBEA37V2kpYY,2167
|
85
85
|
openstef/plotting/__init__.py,sha256=KQjXzyafCt1bE7XDrSeV4TDUIO7MkwN_Br4ASOcNI2g,163
|
86
86
|
openstef/plotting/load_forecast_plotter.py,sha256=GWHVmUB2YosNj7TnSrMnxYAfM2Z1mNg5oRV9A_lJmQY,8129
|
@@ -104,8 +104,8 @@ openstef/tasks/utils/predictionjobloop.py,sha256=Ysy3zF5lzPMz_asYDKeF5m0qgVT3tCt
|
|
104
104
|
openstef/tasks/utils/taskcontext.py,sha256=O-LZ_wHEl5vbT8oB7EYtOeMkvk6EqCnI1-KiyER7Eu4,5407
|
105
105
|
openstef/validation/__init__.py,sha256=bIyGTSA4V5VoOLTwdaiJJAnozmpSzvQooVYlsf8H4eU,163
|
106
106
|
openstef/validation/validation.py,sha256=r6UqkdH5TMjsGfn8Ta07K1jkqmrVmwcPGfyQvMmZyO4,11459
|
107
|
-
openstef-3.4.
|
108
|
-
openstef-3.4.
|
109
|
-
openstef-3.4.
|
110
|
-
openstef-3.4.
|
111
|
-
openstef-3.4.
|
107
|
+
openstef-3.4.78.dist-info/licenses/LICENSE,sha256=7Pm2fWFFHHUG5lDHed1vl5CjzxObIXQglnYsEdtjo_k,14907
|
108
|
+
openstef-3.4.78.dist-info/METADATA,sha256=81sI3OOBkDqIOnQOVq9HCvFzCGnTEpAhg70_4j_2DxM,8834
|
109
|
+
openstef-3.4.78.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
110
|
+
openstef-3.4.78.dist-info/top_level.txt,sha256=kD0H4PqrQoncZ957FvqwfBxa89kTrun4Z_RAPs_HhLs,9
|
111
|
+
openstef-3.4.78.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|