autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +84 -0
- autogluon/timeseries/dataset/ts_dataframe.py +339 -186
- autogluon/timeseries/learner.py +192 -60
- autogluon/timeseries/metrics/__init__.py +55 -11
- autogluon/timeseries/metrics/abstract.py +96 -25
- autogluon/timeseries/metrics/point.py +186 -39
- autogluon/timeseries/metrics/quantile.py +47 -20
- autogluon/timeseries/metrics/utils.py +6 -6
- autogluon/timeseries/models/__init__.py +13 -7
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
- autogluon/timeseries/models/abstract/model_trial.py +10 -10
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
- autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
- autogluon/timeseries/models/chronos/__init__.py +4 -0
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +738 -0
- autogluon/timeseries/models/chronos/utils.py +369 -0
- autogluon/timeseries/models/ensemble/__init__.py +35 -2
- autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
- autogluon/timeseries/models/gluonts/__init__.py +3 -1
- autogluon/timeseries/models/gluonts/abstract.py +583 -0
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
- autogluon/timeseries/models/local/__init__.py +1 -10
- autogluon/timeseries/models/local/abstract_local_model.py +150 -97
- autogluon/timeseries/models/local/naive.py +31 -23
- autogluon/timeseries/models/local/npts.py +6 -2
- autogluon/timeseries/models/local/statsforecast.py +99 -112
- autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +826 -305
- autogluon/timeseries/regressor.py +253 -0
- autogluon/timeseries/splitter.py +10 -31
- autogluon/timeseries/trainer/__init__.py +2 -3
- autogluon/timeseries/trainer/ensemble_composer.py +439 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/trainer/trainer.py +1298 -0
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -0
- autogluon/timeseries/transforms/covariate_scaler.py +164 -0
- autogluon/timeseries/transforms/target_scaler.py +149 -0
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/base.py +38 -20
- autogluon/timeseries/utils/datetime/lags.py +18 -16
- autogluon/timeseries/utils/datetime/seasonality.py +14 -14
- autogluon/timeseries/utils/datetime/time_features.py +17 -14
- autogluon/timeseries/utils/features.py +317 -53
- autogluon/timeseries/utils/forecast.py +31 -17
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +44 -6
- autogluon/timeseries/version.py +2 -1
- autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
- autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -11
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -325
- autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
- autogluon/timeseries/trainer/auto_trainer.py +0 -74
- autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
- autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
|
@@ -1,1144 +0,0 @@
|
|
|
1
|
-
import copy
|
|
2
|
-
import logging
|
|
3
|
-
import os
|
|
4
|
-
import time
|
|
5
|
-
import traceback
|
|
6
|
-
from collections import defaultdict
|
|
7
|
-
from pathlib import Path
|
|
8
|
-
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|
9
|
-
|
|
10
|
-
import networkx as nx
|
|
11
|
-
import numpy as np
|
|
12
|
-
import pandas as pd
|
|
13
|
-
from tqdm import tqdm
|
|
14
|
-
|
|
15
|
-
from autogluon.common.utils.utils import hash_pandas_df, seed_everything
|
|
16
|
-
from autogluon.core.models import AbstractModel
|
|
17
|
-
from autogluon.core.utils.exceptions import TimeLimitExceeded
|
|
18
|
-
from autogluon.core.utils.loaders import load_pkl
|
|
19
|
-
from autogluon.core.utils.savers import save_json, save_pkl
|
|
20
|
-
from autogluon.timeseries import TimeSeriesDataFrame
|
|
21
|
-
from autogluon.timeseries.metrics import TimeSeriesScorer, check_get_evaluation_metric
|
|
22
|
-
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
|
|
23
|
-
from autogluon.timeseries.models.ensemble import AbstractTimeSeriesEnsembleModel, TimeSeriesGreedyEnsemble
|
|
24
|
-
from autogluon.timeseries.models.presets import contains_searchspace
|
|
25
|
-
from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
|
|
26
|
-
from autogluon.timeseries.utils.features import CovariateMetadata
|
|
27
|
-
from autogluon.timeseries.utils.warning_filters import disable_tqdm
|
|
28
|
-
|
|
29
|
-
logger = logging.getLogger("autogluon.timeseries.trainer")
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
# TODO: This class is meant to be moved to `core`, where it will likely
|
|
33
|
-
# TODO: be renamed `AbstractTrainer` and the current `AbstractTrainer`
|
|
34
|
-
# TODO: will inherit from this class.
|
|
35
|
-
# TODO: add documentation for abstract methods
|
|
36
|
-
class SimpleAbstractTrainer:
|
|
37
|
-
trainer_file_name = "trainer.pkl"
|
|
38
|
-
trainer_info_name = "info.pkl"
|
|
39
|
-
trainer_info_json_name = "info.json"
|
|
40
|
-
|
|
41
|
-
def __init__(self, path: str, low_memory: bool, save_data: bool, *args, **kwargs):
|
|
42
|
-
self.path = path
|
|
43
|
-
self.reset_paths = False
|
|
44
|
-
|
|
45
|
-
self.low_memory = low_memory
|
|
46
|
-
self.save_data = save_data
|
|
47
|
-
|
|
48
|
-
self.models = {}
|
|
49
|
-
self.model_graph = nx.DiGraph()
|
|
50
|
-
self.model_best = None
|
|
51
|
-
|
|
52
|
-
self._extra_banned_names = set()
|
|
53
|
-
|
|
54
|
-
def get_model_names(self, **kwargs) -> List[str]:
|
|
55
|
-
"""Get all model names that are registered in the model graph"""
|
|
56
|
-
return list(self.model_graph.nodes)
|
|
57
|
-
|
|
58
|
-
def _get_banned_model_names(self) -> List[str]:
|
|
59
|
-
"""Gets all model names which would cause model files to be overwritten if a new model
|
|
60
|
-
was trained with the name
|
|
61
|
-
"""
|
|
62
|
-
return self.get_model_names() + list(self._extra_banned_names)
|
|
63
|
-
|
|
64
|
-
def get_models_attribute_dict(self, attribute: str, models: List[str] = None) -> Dict[str, Any]:
|
|
65
|
-
"""Get an attribute from the `model_graph` for each of the model names
|
|
66
|
-
specified. If `models` is none, the attribute will be returned for all models"""
|
|
67
|
-
results = {}
|
|
68
|
-
if models is None:
|
|
69
|
-
models = self.get_model_names()
|
|
70
|
-
for model in models:
|
|
71
|
-
results[model] = self.model_graph.nodes[model][attribute]
|
|
72
|
-
return results
|
|
73
|
-
|
|
74
|
-
def get_model_best(self) -> str:
|
|
75
|
-
"""Return the name of the best model by model performance on the validation set."""
|
|
76
|
-
models = self.get_model_names()
|
|
77
|
-
if not models:
|
|
78
|
-
raise ValueError("Trainer has no fit models that can predict.")
|
|
79
|
-
model_performances = self.get_models_attribute_dict(attribute="val_score")
|
|
80
|
-
performances_list = [(m, model_performances[m]) for m in models if model_performances[m] is not None]
|
|
81
|
-
|
|
82
|
-
if not performances_list:
|
|
83
|
-
raise ValueError("No fitted models have validation scores computed.")
|
|
84
|
-
|
|
85
|
-
return max(performances_list, key=lambda i: i[1])[0]
|
|
86
|
-
|
|
87
|
-
def get_model_attribute(self, model: Union[str, AbstractModel], attribute: str):
|
|
88
|
-
"""Get a member attribute for given model from the `model_graph`."""
|
|
89
|
-
if not isinstance(model, str):
|
|
90
|
-
model = model.name
|
|
91
|
-
if attribute == "path":
|
|
92
|
-
return os.path.join(*self.model_graph.nodes[model][attribute])
|
|
93
|
-
return self.model_graph.nodes[model][attribute]
|
|
94
|
-
|
|
95
|
-
def set_model_attribute(self, model: Union[str, AbstractModel], attribute: str, val):
|
|
96
|
-
"""Set a member attribute for given model in the `model_graph`."""
|
|
97
|
-
if not isinstance(model, str):
|
|
98
|
-
model = model.name
|
|
99
|
-
self.model_graph.nodes[model][attribute] = val
|
|
100
|
-
|
|
101
|
-
@property
|
|
102
|
-
def path_root(self) -> str:
|
|
103
|
-
return os.path.dirname(self.path)
|
|
104
|
-
|
|
105
|
-
@property
|
|
106
|
-
def path_utils(self) -> str:
|
|
107
|
-
return os.path.join(self.path_root, "utils")
|
|
108
|
-
|
|
109
|
-
@property
|
|
110
|
-
def path_data(self) -> str:
|
|
111
|
-
return os.path.join(self.path_utils, "data")
|
|
112
|
-
|
|
113
|
-
@property
|
|
114
|
-
def path_pkl(self) -> str:
|
|
115
|
-
return os.path.join(self.path, self.trainer_file_name)
|
|
116
|
-
|
|
117
|
-
def set_contexts(self, path_context: str) -> None:
|
|
118
|
-
self.path = self.create_contexts(path_context)
|
|
119
|
-
|
|
120
|
-
def create_contexts(self, path_context: str) -> str:
|
|
121
|
-
path = path_context
|
|
122
|
-
|
|
123
|
-
return path
|
|
124
|
-
|
|
125
|
-
def save(self) -> None:
|
|
126
|
-
# todo: remove / revise low_memory logic
|
|
127
|
-
models = self.models
|
|
128
|
-
if self.low_memory:
|
|
129
|
-
self.models = {}
|
|
130
|
-
try:
|
|
131
|
-
save_pkl.save(path=self.path_pkl, object=self)
|
|
132
|
-
except: # noqa
|
|
133
|
-
self.models = {}
|
|
134
|
-
save_pkl.save(path=self.path_pkl, object=self)
|
|
135
|
-
if not self.models:
|
|
136
|
-
self.models = models
|
|
137
|
-
|
|
138
|
-
@classmethod
|
|
139
|
-
def load(cls, path: str, reset_paths: bool = False) -> "SimpleAbstractTrainer":
|
|
140
|
-
load_path = os.path.join(path, cls.trainer_file_name)
|
|
141
|
-
if not reset_paths:
|
|
142
|
-
return load_pkl.load(path=load_path)
|
|
143
|
-
else:
|
|
144
|
-
obj = load_pkl.load(path=load_path)
|
|
145
|
-
obj.set_contexts(path)
|
|
146
|
-
obj.reset_paths = reset_paths
|
|
147
|
-
return obj
|
|
148
|
-
|
|
149
|
-
def save_model(self, model: AbstractModel, **kwargs) -> None: # noqa: F841
|
|
150
|
-
model.save()
|
|
151
|
-
if not self.low_memory:
|
|
152
|
-
self.models[model.name] = model
|
|
153
|
-
|
|
154
|
-
def load_model(
|
|
155
|
-
self,
|
|
156
|
-
model_name: Union[str, AbstractModel],
|
|
157
|
-
path: Optional[str] = None,
|
|
158
|
-
model_type: Optional[Type[AbstractModel]] = None,
|
|
159
|
-
) -> AbstractTimeSeriesModel:
|
|
160
|
-
if isinstance(model_name, AbstractModel):
|
|
161
|
-
return model_name
|
|
162
|
-
if model_name in self.models.keys():
|
|
163
|
-
return self.models[model_name]
|
|
164
|
-
|
|
165
|
-
if path is None:
|
|
166
|
-
path = self.get_model_attribute(model=model_name, attribute="path")
|
|
167
|
-
if model_type is None:
|
|
168
|
-
model_type = self.get_model_attribute(model=model_name, attribute="type")
|
|
169
|
-
return model_type.load(path=os.path.join(self.path, path), reset_paths=self.reset_paths)
|
|
170
|
-
|
|
171
|
-
def construct_model_templates(self, hyperparameters: Union[str, Dict[str, Any]], **kwargs):
|
|
172
|
-
raise NotImplementedError
|
|
173
|
-
|
|
174
|
-
# FIXME: Copy pasted from Tabular
|
|
175
|
-
def get_minimum_model_set(self, model: Union[str, AbstractTimeSeriesModel], include_self: bool = True) -> list:
|
|
176
|
-
"""Gets the minimum set of models that the provided model depends on, including itself.
|
|
177
|
-
Returns a list of model names"""
|
|
178
|
-
if not isinstance(model, str):
|
|
179
|
-
model = model.name
|
|
180
|
-
minimum_model_set = list(nx.bfs_tree(self.model_graph, model, reverse=True))
|
|
181
|
-
if not include_self:
|
|
182
|
-
minimum_model_set = [m for m in minimum_model_set if m != model]
|
|
183
|
-
return minimum_model_set
|
|
184
|
-
|
|
185
|
-
def get_models_info(self, models: List[str] = None) -> Dict[str, Any]:
|
|
186
|
-
if models is None:
|
|
187
|
-
models = self.get_model_names()
|
|
188
|
-
model_info_dict = dict()
|
|
189
|
-
for model in models:
|
|
190
|
-
if isinstance(model, str):
|
|
191
|
-
if model in self.models.keys():
|
|
192
|
-
model = self.models[model]
|
|
193
|
-
if isinstance(model, str):
|
|
194
|
-
model_type = self.get_model_attribute(model=model, attribute="type")
|
|
195
|
-
model_path = os.path.join(self.path, self.get_model_attribute(model=model, attribute="path"))
|
|
196
|
-
model_info_dict[model] = model_type.load_info(path=model_path)
|
|
197
|
-
else:
|
|
198
|
-
model_info_dict[model.name] = model.get_info()
|
|
199
|
-
return model_info_dict
|
|
200
|
-
|
|
201
|
-
@classmethod
|
|
202
|
-
def load_info(cls, path, reset_paths=False, load_model_if_required=True) -> Dict[str, Any]:
|
|
203
|
-
load_path = os.path.join(path, cls.trainer_info_name)
|
|
204
|
-
try:
|
|
205
|
-
return load_pkl.load(path=load_path)
|
|
206
|
-
except: # noqa
|
|
207
|
-
if load_model_if_required:
|
|
208
|
-
trainer = cls.load(path=path, reset_paths=reset_paths)
|
|
209
|
-
return trainer.get_info()
|
|
210
|
-
else:
|
|
211
|
-
raise
|
|
212
|
-
|
|
213
|
-
def save_info(self, include_model_info: bool = False):
|
|
214
|
-
info = self.get_info(include_model_info=include_model_info)
|
|
215
|
-
|
|
216
|
-
save_pkl.save(path=os.path.join(self.path, self.trainer_info_name), object=info)
|
|
217
|
-
save_json.save(path=os.path.join(self.path, self.trainer_info_json_name), obj=info)
|
|
218
|
-
return info
|
|
219
|
-
|
|
220
|
-
def get_info(self, include_model_info: bool = False) -> Dict[str, Any]:
|
|
221
|
-
num_models_trained = len(self.get_model_names())
|
|
222
|
-
if self.model_best is not None:
|
|
223
|
-
best_model = self.model_best
|
|
224
|
-
else:
|
|
225
|
-
try:
|
|
226
|
-
best_model = self.get_model_best()
|
|
227
|
-
except AssertionError:
|
|
228
|
-
best_model = None
|
|
229
|
-
if best_model is not None:
|
|
230
|
-
best_model_score_val = self.get_model_attribute(model=best_model, attribute="val_score")
|
|
231
|
-
else:
|
|
232
|
-
best_model_score_val = None
|
|
233
|
-
|
|
234
|
-
info = {
|
|
235
|
-
"best_model": best_model,
|
|
236
|
-
"best_model_score_val": best_model_score_val,
|
|
237
|
-
"num_models_trained": num_models_trained,
|
|
238
|
-
}
|
|
239
|
-
|
|
240
|
-
if include_model_info:
|
|
241
|
-
info["model_info"] = self.get_models_info()
|
|
242
|
-
|
|
243
|
-
return info
|
|
244
|
-
|
|
245
|
-
def predict(self, *args, **kwargs):
|
|
246
|
-
raise NotImplementedError
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
class AbstractTimeSeriesTrainer(SimpleAbstractTrainer):
|
|
250
|
-
_cached_predictions_filename = "cached_predictions.pkl"
|
|
251
|
-
|
|
252
|
-
def __init__(
|
|
253
|
-
self,
|
|
254
|
-
path: str,
|
|
255
|
-
prediction_length: Optional[int] = 1,
|
|
256
|
-
eval_metric: Union[str, TimeSeriesScorer, None] = None,
|
|
257
|
-
eval_metric_seasonal_period: Optional[int] = None,
|
|
258
|
-
save_data: bool = True,
|
|
259
|
-
enable_ensemble: bool = True,
|
|
260
|
-
verbosity: int = 2,
|
|
261
|
-
val_splitter: Optional[AbstractWindowSplitter] = None,
|
|
262
|
-
refit_every_n_windows: Optional[int] = 1,
|
|
263
|
-
cache_predictions: bool = True,
|
|
264
|
-
**kwargs,
|
|
265
|
-
):
|
|
266
|
-
super().__init__(path=path, save_data=save_data, low_memory=True, **kwargs)
|
|
267
|
-
|
|
268
|
-
self.prediction_length = prediction_length
|
|
269
|
-
self.quantile_levels = kwargs.get("quantile_levels", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
|
|
270
|
-
self.target = kwargs.get("target", "target")
|
|
271
|
-
self.metadata = kwargs.get("metadata", CovariateMetadata())
|
|
272
|
-
self.is_data_saved = False
|
|
273
|
-
self.enable_ensemble = enable_ensemble
|
|
274
|
-
self.ensemble_model_type = TimeSeriesGreedyEnsemble
|
|
275
|
-
|
|
276
|
-
self.verbosity = verbosity
|
|
277
|
-
|
|
278
|
-
# Dict of normal model -> FULL model. FULL models are produced by
|
|
279
|
-
# self.refit_single_full() and self.refit_full().
|
|
280
|
-
self.model_refit_map = {}
|
|
281
|
-
|
|
282
|
-
self.eval_metric: TimeSeriesScorer = check_get_evaluation_metric(eval_metric)
|
|
283
|
-
self.eval_metric_seasonal_period = eval_metric_seasonal_period
|
|
284
|
-
if val_splitter is None:
|
|
285
|
-
val_splitter = ExpandingWindowSplitter(prediction_length=self.prediction_length)
|
|
286
|
-
assert isinstance(val_splitter, AbstractWindowSplitter), "val_splitter must be of type AbstractWindowSplitter"
|
|
287
|
-
self.val_splitter = val_splitter
|
|
288
|
-
self.refit_every_n_windows = refit_every_n_windows
|
|
289
|
-
self.cache_predictions = cache_predictions
|
|
290
|
-
self.hpo_results = {}
|
|
291
|
-
|
|
292
|
-
def save_train_data(self, data: TimeSeriesDataFrame, verbose: bool = True) -> None:
|
|
293
|
-
path = os.path.join(self.path_data, "train.pkl")
|
|
294
|
-
save_pkl.save(path=path, object=data, verbose=verbose)
|
|
295
|
-
|
|
296
|
-
def save_val_data(self, data: TimeSeriesDataFrame, verbose: bool = True) -> None:
|
|
297
|
-
path = os.path.join(self.path_data, "val.pkl")
|
|
298
|
-
save_pkl.save(path=path, object=data, verbose=verbose)
|
|
299
|
-
|
|
300
|
-
def load_train_data(self) -> TimeSeriesDataFrame:
|
|
301
|
-
path = os.path.join(self.path_data, "train.pkl")
|
|
302
|
-
return load_pkl.load(path=path)
|
|
303
|
-
|
|
304
|
-
def load_val_data(self) -> Optional[TimeSeriesDataFrame]:
|
|
305
|
-
path = os.path.join(self.path_data, "val.pkl")
|
|
306
|
-
if os.path.exists(path):
|
|
307
|
-
return load_pkl.load(path=path)
|
|
308
|
-
else:
|
|
309
|
-
return None
|
|
310
|
-
|
|
311
|
-
def load_data(self) -> Tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
|
|
312
|
-
train_data = self.load_train_data()
|
|
313
|
-
val_data = self.load_val_data()
|
|
314
|
-
return train_data, val_data
|
|
315
|
-
|
|
316
|
-
def save(self) -> None:
|
|
317
|
-
models = self.models
|
|
318
|
-
self.models = {}
|
|
319
|
-
|
|
320
|
-
save_pkl.save(path=self.path_pkl, object=self)
|
|
321
|
-
for model in self.models.values():
|
|
322
|
-
model.save()
|
|
323
|
-
|
|
324
|
-
self.models = models
|
|
325
|
-
|
|
326
|
-
def _get_model_oof_predictions(self, model_name: str) -> List[TimeSeriesDataFrame]:
|
|
327
|
-
model_path = os.path.join(self.path, self.get_model_attribute(model=model_name, attribute="path"))
|
|
328
|
-
model_type = self.get_model_attribute(model=model_name, attribute="type")
|
|
329
|
-
return model_type.load_oof_predictions(path=model_path)
|
|
330
|
-
|
|
331
|
-
def _add_model(
|
|
332
|
-
self,
|
|
333
|
-
model: AbstractTimeSeriesModel,
|
|
334
|
-
base_models: List[str] = None,
|
|
335
|
-
):
|
|
336
|
-
"""Add a model to the model graph of the trainer. If the model is an ensemble, also add
|
|
337
|
-
information about dependencies to the model graph (list of models specified via ``base_models``).
|
|
338
|
-
|
|
339
|
-
Parameters
|
|
340
|
-
----------
|
|
341
|
-
model : AbstractTimeSeriesModel
|
|
342
|
-
The model to be added to the model graph.
|
|
343
|
-
base_models : List[str], optional, default None
|
|
344
|
-
If the model is an ensemble, the list of base model names that are included in the ensemble.
|
|
345
|
-
Expected only when ``model`` is a ``AbstractTimeSeriesEnsembleModel``.
|
|
346
|
-
|
|
347
|
-
Raises
|
|
348
|
-
------
|
|
349
|
-
AssertionError
|
|
350
|
-
If ``base_models`` are provided and ``model`` is not a ``AbstractTimeSeriesEnsembleModel``.
|
|
351
|
-
"""
|
|
352
|
-
node_attrs = dict(
|
|
353
|
-
path=os.path.relpath(model.path, self.path).split(os.sep),
|
|
354
|
-
type=type(model),
|
|
355
|
-
fit_time=model.fit_time,
|
|
356
|
-
predict_time=model.predict_time,
|
|
357
|
-
val_score=model.val_score,
|
|
358
|
-
)
|
|
359
|
-
self.model_graph.add_node(model.name, **node_attrs)
|
|
360
|
-
|
|
361
|
-
if base_models:
|
|
362
|
-
assert isinstance(model, AbstractTimeSeriesEnsembleModel)
|
|
363
|
-
for base_model in base_models:
|
|
364
|
-
self.model_graph.add_edge(base_model, model.name)
|
|
365
|
-
|
|
366
|
-
def _get_model_levels(self) -> Dict[str, int]:
|
|
367
|
-
"""Get a dictionary mapping each model to their level in the model graph"""
|
|
368
|
-
|
|
369
|
-
# get nodes without a parent
|
|
370
|
-
rootset = set(self.model_graph.nodes)
|
|
371
|
-
for e in self.model_graph.edges():
|
|
372
|
-
rootset.discard(e[1])
|
|
373
|
-
|
|
374
|
-
# get shortest paths
|
|
375
|
-
paths_from = defaultdict(dict)
|
|
376
|
-
for source_node, paths_to in nx.shortest_path_length(self.model_graph):
|
|
377
|
-
for dest_node in paths_to:
|
|
378
|
-
paths_from[dest_node][source_node] = paths_to[dest_node]
|
|
379
|
-
|
|
380
|
-
# determine levels
|
|
381
|
-
levels = {}
|
|
382
|
-
for n in paths_from:
|
|
383
|
-
levels[n] = max(paths_from[n].get(src, 0) for src in rootset)
|
|
384
|
-
|
|
385
|
-
return levels
|
|
386
|
-
|
|
387
|
-
def get_model_names(self, level: Optional[int] = None, **kwargs) -> List[str]:
|
|
388
|
-
"""Get model names that are registered in the model graph"""
|
|
389
|
-
if level is not None:
|
|
390
|
-
return list(node for node, l in self._get_model_levels().items() if l == level) # noqa: E741
|
|
391
|
-
return list(self.model_graph.nodes)
|
|
392
|
-
|
|
393
|
-
def _train_single(
|
|
394
|
-
self,
|
|
395
|
-
train_data: TimeSeriesDataFrame,
|
|
396
|
-
model: AbstractTimeSeriesModel,
|
|
397
|
-
val_data: Optional[TimeSeriesDataFrame] = None,
|
|
398
|
-
time_limit: Optional[float] = None,
|
|
399
|
-
) -> AbstractTimeSeriesModel:
|
|
400
|
-
"""Train the single model and return the model object that was fitted. This method
|
|
401
|
-
does not save the resulting model."""
|
|
402
|
-
model.fit(
|
|
403
|
-
train_data=train_data,
|
|
404
|
-
val_data=val_data,
|
|
405
|
-
time_limit=time_limit,
|
|
406
|
-
verbosity=self.verbosity,
|
|
407
|
-
val_splitter=self.val_splitter,
|
|
408
|
-
refit_every_n_windows=self.refit_every_n_windows,
|
|
409
|
-
)
|
|
410
|
-
return model
|
|
411
|
-
|
|
412
|
-
def tune_model_hyperparameters(
|
|
413
|
-
self,
|
|
414
|
-
model: AbstractTimeSeriesModel,
|
|
415
|
-
train_data: TimeSeriesDataFrame,
|
|
416
|
-
time_limit: Optional[float] = None,
|
|
417
|
-
val_data: Optional[TimeSeriesDataFrame] = None,
|
|
418
|
-
hyperparameter_tune_kwargs: Union[str, dict] = "auto",
|
|
419
|
-
):
|
|
420
|
-
default_num_trials = None
|
|
421
|
-
if time_limit is None and (
|
|
422
|
-
"num_samples" not in hyperparameter_tune_kwargs or isinstance(hyperparameter_tune_kwargs, str)
|
|
423
|
-
):
|
|
424
|
-
default_num_trials = 10
|
|
425
|
-
|
|
426
|
-
tuning_start_time = time.time()
|
|
427
|
-
with disable_tqdm():
|
|
428
|
-
hpo_models, _ = model.hyperparameter_tune(
|
|
429
|
-
train_data=train_data,
|
|
430
|
-
val_data=val_data,
|
|
431
|
-
hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
|
|
432
|
-
time_limit=time_limit,
|
|
433
|
-
default_num_trials=default_num_trials,
|
|
434
|
-
val_splitter=self.val_splitter,
|
|
435
|
-
refit_every_n_windows=self.refit_every_n_windows,
|
|
436
|
-
)
|
|
437
|
-
total_tuning_time = time.time() - tuning_start_time
|
|
438
|
-
|
|
439
|
-
self.hpo_results[model.name] = hpo_models
|
|
440
|
-
model_names_trained = []
|
|
441
|
-
# add each of the trained HPO configurations to the trained models
|
|
442
|
-
for model_hpo_name, model_info in hpo_models.items():
|
|
443
|
-
model_path = os.path.join(self.path, model_info["path"])
|
|
444
|
-
# Only load model configurations that didn't fail
|
|
445
|
-
if Path(model_path).exists():
|
|
446
|
-
model_hpo = self.load_model(model_hpo_name, path=model_path, model_type=type(model))
|
|
447
|
-
self._add_model(model_hpo)
|
|
448
|
-
model_names_trained.append(model_hpo.name)
|
|
449
|
-
|
|
450
|
-
logger.info(f"\tTrained {len(model_names_trained)} models while tuning {model.name}.")
|
|
451
|
-
|
|
452
|
-
if len(model_names_trained) > 0:
|
|
453
|
-
trained_model_results = [hpo_models[model_name] for model_name in model_names_trained]
|
|
454
|
-
best_model_result = max(trained_model_results, key=lambda x: x["val_score"])
|
|
455
|
-
|
|
456
|
-
logger.info(
|
|
457
|
-
f"\t{best_model_result['val_score']:<7.4f}".ljust(15)
|
|
458
|
-
+ f"= Validation score ({self.eval_metric.name_with_sign})"
|
|
459
|
-
)
|
|
460
|
-
logger.info(f"\t{total_tuning_time:<7.2f} s".ljust(15) + "= Total tuning time")
|
|
461
|
-
logger.debug(f"\tBest hyperparameter configuration: {best_model_result['hyperparameters']}")
|
|
462
|
-
|
|
463
|
-
return model_names_trained
|
|
464
|
-
|
|
465
|
-
def _train_and_save(
|
|
466
|
-
self,
|
|
467
|
-
train_data: TimeSeriesDataFrame,
|
|
468
|
-
model: AbstractTimeSeriesModel,
|
|
469
|
-
val_data: Optional[TimeSeriesDataFrame] = None,
|
|
470
|
-
time_limit: Optional[float] = None,
|
|
471
|
-
) -> List[str]:
|
|
472
|
-
"""Fit and save the given model on given training and validation data and save the trained model.
|
|
473
|
-
|
|
474
|
-
Returns
|
|
475
|
-
-------
|
|
476
|
-
model_names_trained: the list of model names that were successfully trained
|
|
477
|
-
"""
|
|
478
|
-
fit_start_time = time.time()
|
|
479
|
-
model_names_trained = []
|
|
480
|
-
try:
|
|
481
|
-
if time_limit is not None:
|
|
482
|
-
if time_limit <= 0:
|
|
483
|
-
logger.info(f"\tSkipping {model.name} due to lack of time remaining.")
|
|
484
|
-
return model_names_trained
|
|
485
|
-
|
|
486
|
-
model = self._train_single(train_data, model, val_data=val_data, time_limit=time_limit)
|
|
487
|
-
fit_end_time = time.time()
|
|
488
|
-
model.fit_time = model.fit_time or (fit_end_time - fit_start_time)
|
|
489
|
-
|
|
490
|
-
if val_data is not None:
|
|
491
|
-
model.score_and_cache_oof(val_data, store_val_score=True, store_predict_time=True)
|
|
492
|
-
|
|
493
|
-
self._log_scores_and_times(model.val_score, model.fit_time, model.predict_time)
|
|
494
|
-
|
|
495
|
-
self.save_model(model=model)
|
|
496
|
-
except TimeLimitExceeded:
|
|
497
|
-
logger.error(f"\tTime limit exceeded... Skipping {model.name}.")
|
|
498
|
-
except (Exception, MemoryError) as err:
|
|
499
|
-
logger.error(f"\tWarning: Exception caused {model.name} to fail during training... Skipping this model.")
|
|
500
|
-
logger.error(f"\t{err}")
|
|
501
|
-
logger.debug(traceback.format_exc())
|
|
502
|
-
else:
|
|
503
|
-
self._add_model(model=model) # noqa: F821
|
|
504
|
-
model_names_trained.append(model.name) # noqa: F821
|
|
505
|
-
finally:
|
|
506
|
-
del model
|
|
507
|
-
|
|
508
|
-
return model_names_trained
|
|
509
|
-
|
|
510
|
-
def _log_scores_and_times(
|
|
511
|
-
self,
|
|
512
|
-
val_score: Optional[float] = None,
|
|
513
|
-
fit_time: Optional[float] = None,
|
|
514
|
-
predict_time: Optional[float] = None,
|
|
515
|
-
):
|
|
516
|
-
if val_score is not None:
|
|
517
|
-
logger.info(f"\t{val_score:<7.4f}".ljust(15) + f"= Validation score ({self.eval_metric.name_with_sign})")
|
|
518
|
-
if fit_time is not None:
|
|
519
|
-
logger.info(f"\t{fit_time:<7.2f} s".ljust(15) + "= Training runtime")
|
|
520
|
-
if predict_time is not None:
|
|
521
|
-
logger.info(f"\t{predict_time:<7.2f} s".ljust(15) + "= Validation (prediction) runtime")
|
|
522
|
-
|
|
523
|
-
def _train_multi(
    self,
    train_data: TimeSeriesDataFrame,
    hyperparameters: Optional[Union[str, Dict]] = None,
    models: Optional[List[AbstractTimeSeriesModel]] = None,
    val_data: Optional[TimeSeriesDataFrame] = None,
    hyperparameter_tune_kwargs: Optional[Union[str, dict]] = None,
    excluded_model_types: Optional[List[str]] = None,
    time_limit: Optional[float] = None,
    random_seed: Optional[int] = None,
) -> List[str]:
    """Train all candidate models (and, if enabled, a weighted ensemble) within the time budget.

    Either ``models`` (already-constructed model templates) or ``hyperparameters``
    (used to construct templates) must be provided. Returns the names of all models
    that were successfully trained, including the ensemble if one was fit.
    """
    logger.info(f"\nStarting training. Start time is {time.strftime('%Y-%m-%d %H:%M:%S')}")

    time_start = time.time()
    if hyperparameters is not None:
        # Deep-copy so that downstream mutation does not affect the caller's dict.
        hyperparameters = copy.deepcopy(hyperparameters)
    else:
        if models is None:
            raise ValueError("Either models or hyperparameters should be provided")

    # Optionally persist training/validation data so refit_full can reload it later.
    if self.save_data and not self.is_data_saved:
        self.save_train_data(train_data)
        if val_data is not None:
            self.save_val_data(val_data)
        self.is_data_saved = True

    if models is None:
        models = self.construct_model_templates(
            hyperparameters=hyperparameters,
            hyperparameter_tune=hyperparameter_tune_kwargs is not None,  # TODO: remove hyperparameter_tune
            freq=train_data.freq,
            multi_window=self.val_splitter.num_val_windows > 0,
            excluded_model_types=excluded_model_types,
        )

    logger.info(f"Models that will be trained: {list(m.name for m in models)}")

    num_base_models = len(models)
    model_names_trained = []
    for i, model in enumerate(models):
        if time_limit is None:
            time_left = None
            time_left_for_model = None
        else:
            time_left = time_limit - (time.time() - time_start)
            if num_base_models > 1 and self.enable_ensemble:
                # Hold back up to 10 minutes for fitting the ensemble; the reserved share
                # shrinks as fewer models remain to be trained.
                time_reserved_for_ensemble = min(600.0, time_left / (num_base_models - i + 1))
                logger.debug(f"Reserving {time_reserved_for_ensemble:.1f}s for ensemble")
            else:
                time_reserved_for_ensemble = 0.0
            # Split the remaining (non-ensemble) time evenly across the models left to train.
            time_left_for_model = (time_left - time_reserved_for_ensemble) / (num_base_models - i)
            if time_left <= 0:
                logger.info(f"Stopping training due to lack of time remaining. Time left: {time_left:.1f} seconds")
                break

        if random_seed is not None:
            seed_everything(random_seed)

        # Models whose user params contain a search space get HPO; the rest a plain fit.
        if contains_searchspace(model.get_user_params()):
            fit_log_message = f"Hyperparameter tuning model {model.name}. "
            if time_left is not None:
                fit_log_message += (
                    f"Tuning model for up to {time_left_for_model:.1f}s of the {time_left:.1f}s remaining."
                )
            logger.info(fit_log_message)
            with tqdm.external_write_mode():
                model_names_trained += self.tune_model_hyperparameters(
                    model,
                    time_limit=time_left_for_model,
                    train_data=train_data,
                    val_data=val_data,
                    hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
                )
        else:
            fit_log_message = f"Training timeseries model {model.name}. "
            if time_left is not None:
                fit_log_message += (
                    f"Training for up to {time_left_for_model:.1f}s of the {time_left:.1f}s of remaining time."
                )
            logger.info(fit_log_message)
            model_names_trained += self._train_and_save(
                train_data, model=model, val_data=val_data, time_limit=time_left_for_model
            )

    if self.enable_ensemble:
        # Only level-0 (base) models are candidates for the weighted ensemble.
        models_available_for_ensemble = self.get_model_names(level=0)

        time_left_for_ensemble = None
        if time_limit is not None:
            time_left_for_ensemble = time_limit - (time.time() - time_start)

        if time_left_for_ensemble is not None and time_left_for_ensemble <= 0:
            logger.info(
                "Not fitting ensemble due to lack of time remaining. "
                f"Time left: {time_left_for_ensemble:.1f} seconds"
            )
        elif len(models_available_for_ensemble) <= 1:
            # An ensemble needs at least two successfully trained base models.
            logger.info(
                "Not fitting ensemble as "
                + (
                    "no models were successfully trained."
                    if not models_available_for_ensemble
                    else "only 1 model was trained."
                )
            )
        else:
            try:
                model_names_trained.append(
                    self.fit_ensemble(
                        data_per_window=self._get_ensemble_oof_data(train_data=train_data, val_data=val_data),
                        model_names=models_available_for_ensemble,
                        time_limit=time_left_for_ensemble,
                    )
                )
            except Exception as err:  # noqa
                # A failed ensemble fit is non-fatal: base models remain usable.
                logger.error(
                    "\tWarning: Exception caused ensemble to fail during training... Skipping this model."
                )
                logger.error(f"\t{err}")
                logger.debug(traceback.format_exc())

    logger.info(f"Training complete. Models trained: {model_names_trained}")
    logger.info(f"Total runtime: {time.time() - time_start:.2f} s")
    try:
        best_model = self.get_model_best()
        logger.info(f"Best model: {best_model}")
        logger.info(f"Best model score: {self.get_model_attribute(best_model, 'val_score'):.4f}")
    except ValueError as e:
        # get_model_best raises ValueError when no model was trained successfully.
        logger.error(str(e))

    return model_names_trained
|
|
654
|
-
|
|
655
|
-
def _get_ensemble_oof_data(
|
|
656
|
-
self, train_data: TimeSeriesDataFrame, val_data: Optional[TimeSeriesDataFrame]
|
|
657
|
-
) -> List[TimeSeriesDataFrame]:
|
|
658
|
-
if val_data is None:
|
|
659
|
-
return [val_fold for _, val_fold in self.val_splitter.split(train_data)]
|
|
660
|
-
else:
|
|
661
|
-
return [val_data]
|
|
662
|
-
|
|
663
|
-
def _get_ensemble_model_name(self) -> str:
|
|
664
|
-
"""Ensure we don't have name collisions in the ensemble model name"""
|
|
665
|
-
ensemble_name = "WeightedEnsemble"
|
|
666
|
-
increment = 1
|
|
667
|
-
while ensemble_name in self._get_banned_model_names():
|
|
668
|
-
increment += 1
|
|
669
|
-
ensemble_name = f"WeightedEnsemble_{increment}"
|
|
670
|
-
return ensemble_name
|
|
671
|
-
|
|
672
|
-
def fit_ensemble(
    self, data_per_window: List[TimeSeriesDataFrame], model_names: List[str], time_limit: Optional[float] = None
) -> str:
    """Fit a weighted ensemble on the out-of-fold predictions of the given base models.

    ``data_per_window`` holds one validation frame per window; each base model's OOF
    predictions must contain one entry per window. Registers and saves the fitted
    ensemble and returns its name.
    """
    logger.info("Fitting simple weighted ensemble.")

    # Gather the per-window OOF predictions of every candidate base model.
    model_preds: Dict[str, List[TimeSeriesDataFrame]] = {}
    for model_name in model_names:
        model_preds[model_name] = self._get_model_oof_predictions(model_name=model_name)

    time_start = time.time()
    ensemble = self.ensemble_model_type(
        name=self._get_ensemble_model_name(),
        eval_metric=self.eval_metric,
        eval_metric_seasonal_period=self.eval_metric_seasonal_period,
        target=self.target,
        prediction_length=self.prediction_length,
        path=self.path,
        freq=data_per_window[0].freq,
        quantile_levels=self.quantile_levels,
        metadata=self.metadata,
    )
    ensemble.fit_ensemble(model_preds, data_per_window=data_per_window, time_limit=time_limit)
    ensemble.fit_time = time.time() - time_start

    # The ensemble's prediction time is the sum of its members' stored predict times,
    # since predicting with the ensemble requires predicting with every member.
    predict_time = 0
    for m in ensemble.model_names:
        predict_time += self.get_model_attribute(model=m, attribute="predict_time")
    ensemble.predict_time = predict_time

    # Validation score = mean score across all validation windows.
    score_per_fold = []
    for window_idx, data in enumerate(data_per_window):
        predictions = ensemble.predict({n: model_preds[n][window_idx] for n in ensemble.model_names})
        score_per_fold.append(self._score_with_predictions(data, predictions))
    ensemble.val_score = np.mean(score_per_fold)

    self._log_scores_and_times(
        val_score=ensemble.val_score,
        fit_time=ensemble.fit_time,
        predict_time=ensemble.predict_time,
    )
    self._add_model(model=ensemble, base_models=ensemble.model_names)
    self.save_model(model=ensemble)
    return ensemble.name
|
|
715
|
-
|
|
716
|
-
def leaderboard(self, data: Optional[TimeSeriesDataFrame] = None, use_cache: bool = True) -> pd.DataFrame:
    """Build a leaderboard DataFrame of all trained models.

    Always reports validation score, marginal fit time and validation prediction
    time. If ``data`` is provided, additionally scores every model on it
    (``score_test``/``pred_time_test``) and sorts by the test score instead of the
    validation score.
    """
    logger.debug("Generating leaderboard for all models trained")

    model_names = self.get_model_names()
    if len(model_names) == 0:
        logger.warning("Warning: No models were trained during fit. Resulting leaderboard will be empty.")

    # Collect per-model attributes recorded during training.
    model_info = {}
    for ix, model_name in enumerate(model_names):
        model_info[model_name] = {
            "model": model_name,
            "fit_order": ix + 1,
            "score_val": self.get_model_attribute(model_name, "val_score"),
            "fit_time_marginal": self.get_model_attribute(model_name, "fit_time"),
            "pred_time_val": self.get_model_attribute(model_name, "predict_time"),
        }

    if data is not None:
        # Split off the forecast horizon so models only see the "past" portion.
        past_data, known_covariates = data.get_model_inputs_for_scoring(
            prediction_length=self.prediction_length, known_covariates_names=self.metadata.known_covariates_real
        )
        logger.info(
            "Additional data provided, testing on additional data. Resulting leaderboard "
            "will be sorted according to test score (`score_test`)."
        )
        model_predictions, pred_time_dict = self.get_model_pred_dict(
            model_names=model_names,
            data=past_data,
            known_covariates=known_covariates,
            record_pred_time=True,
            raise_exception_if_failed=False,
            use_cache=use_cache,
        )

        for model_name in model_names:
            model_preds = model_predictions[model_name]
            if model_preds is None:
                # Model failed at prediction time
                model_info[model_name]["score_test"] = float("nan")
                model_info[model_name]["pred_time_test"] = float("nan")
            else:
                model_info[model_name]["score_test"] = self._score_with_predictions(data, model_preds)
                model_info[model_name]["pred_time_test"] = pred_time_dict[model_name]

    explicit_column_order = [
        "model",
        "score_test",
        "score_val",
        "pred_time_test",
        "pred_time_val",
        "fit_time_marginal",
        "fit_order",
    ]

    # Build with the full column set first, then drop the test columns if unused.
    df = pd.DataFrame(model_info.values(), columns=explicit_column_order)
    if data is None:
        explicit_column_order.remove("score_test")
        explicit_column_order.remove("pred_time_test")
        sort_column = "score_val"
    else:
        sort_column = "score_test"

    # Tie-break by model name so the ordering is deterministic.
    df.sort_values(by=[sort_column, "model"], ascending=[False, False], inplace=True)
    df.reset_index(drop=True, inplace=True)

    return df[explicit_column_order]
|
|
782
|
-
|
|
783
|
-
def _get_model_for_prediction(self, model: Optional[Union[str, AbstractTimeSeriesModel]] = None) -> str:
    """Given an optional identifier or model object, return the name of the model with which to predict.

    If the model is not provided, this method will default to the best model according to the validation score.
    """
    if model is not None:
        # Normalize a concrete model (object or name) to its name.
        return model.name if isinstance(model, AbstractTimeSeriesModel) else model
    if self.model_best is None:
        # Resolve and cache the best model lazily on first use.
        self.model_best = self.get_model_best()
    logger.info(
        f"Model not specified in predict, will default to the model with the "
        f"best validation score: {self.model_best}",
    )
    return self.model_best
|
|
802
|
-
|
|
803
|
-
def predict(
    self,
    data: TimeSeriesDataFrame,
    known_covariates: Optional[TimeSeriesDataFrame] = None,
    model: Optional[Union[str, AbstractTimeSeriesModel]] = None,
    use_cache: bool = True,
    random_seed: Optional[int] = None,
    **kwargs,
) -> TimeSeriesDataFrame:
    """Predict with a single model (or the best model if none is given).

    Delegates to :meth:`get_model_pred_dict`, which transparently handles base-model
    dependencies of ensembles and prediction caching.
    """
    chosen_model = self._get_model_for_prediction(model)
    predictions_by_model = self.get_model_pred_dict(
        model_names=[chosen_model],
        data=data,
        known_covariates=known_covariates,
        use_cache=use_cache,
        random_seed=random_seed,
    )
    return predictions_by_model[chosen_model]
|
|
821
|
-
|
|
822
|
-
def _score_with_predictions(
|
|
823
|
-
self,
|
|
824
|
-
data: TimeSeriesDataFrame,
|
|
825
|
-
predictions: TimeSeriesDataFrame,
|
|
826
|
-
metric: Union[str, TimeSeriesScorer, None] = None,
|
|
827
|
-
) -> float:
|
|
828
|
-
"""Compute the score measuring how well the predictions align with the data."""
|
|
829
|
-
eval_metric = self.eval_metric if metric is None else check_get_evaluation_metric(metric)
|
|
830
|
-
return eval_metric.score(
|
|
831
|
-
data=data,
|
|
832
|
-
predictions=predictions,
|
|
833
|
-
prediction_length=self.prediction_length,
|
|
834
|
-
target=self.target,
|
|
835
|
-
seasonal_period=self.eval_metric_seasonal_period,
|
|
836
|
-
)
|
|
837
|
-
|
|
838
|
-
def score(
    self,
    data: TimeSeriesDataFrame,
    model: Optional[Union[str, AbstractTimeSeriesModel]] = None,
    metric: Union[str, TimeSeriesScorer, None] = None,
    use_cache: bool = True,
) -> float:
    """Score ``data`` with a single metric; thin wrapper around :meth:`evaluate`."""
    if metric is None:
        chosen_metric = self.eval_metric
    else:
        chosen_metric = check_get_evaluation_metric(metric)
    scores = self.evaluate(data=data, model=model, metrics=[chosen_metric], use_cache=use_cache)
    return scores[chosen_metric.name]
|
|
848
|
-
|
|
849
|
-
def evaluate(
    self,
    data: TimeSeriesDataFrame,
    model: Optional[Union[str, AbstractTimeSeriesModel]] = None,
    metrics: Optional[Union[str, TimeSeriesScorer, List[Union[str, TimeSeriesScorer]]]] = None,
    use_cache: bool = True,
) -> Dict[str, float]:
    """Predict on the held-out horizon of ``data`` and score with one or more metrics.

    Returns a mapping from metric name to score. ``metrics=None`` falls back to the
    trainer's default evaluation metric.
    """
    # Hold out the last prediction_length steps; the rest is model input.
    past_data, known_covariates = data.get_model_inputs_for_scoring(
        prediction_length=self.prediction_length, known_covariates_names=self.metadata.known_covariates_real
    )
    predictions = self.predict(data=past_data, known_covariates=known_covariates, model=model, use_cache=use_cache)

    # Normalize a single metric (or None) into a list.
    metrics_list = metrics if isinstance(metrics, list) else [metrics]
    scores_dict = {}
    for metric in metrics_list:
        scorer = self.eval_metric if metric is None else check_get_evaluation_metric(metric)
        scores_dict[scorer.name] = self._score_with_predictions(
            data=data, predictions=predictions, metric=scorer
        )
    return scores_dict
|
|
869
|
-
|
|
870
|
-
def _predict_model(
|
|
871
|
-
self,
|
|
872
|
-
model: Union[str, AbstractTimeSeriesModel],
|
|
873
|
-
data: TimeSeriesDataFrame,
|
|
874
|
-
model_pred_dict: Dict[str, TimeSeriesDataFrame],
|
|
875
|
-
known_covariates: Optional[TimeSeriesDataFrame] = None,
|
|
876
|
-
) -> TimeSeriesDataFrame:
|
|
877
|
-
"""Generate predictions using the given model.
|
|
878
|
-
|
|
879
|
-
This method assumes that model_pred_dict contains the predictions of all base models, if model is an ensemble.
|
|
880
|
-
"""
|
|
881
|
-
if isinstance(model, str):
|
|
882
|
-
model = self.load_model(model)
|
|
883
|
-
data = self._get_inputs_to_model(model=model, data=data, model_pred_dict=model_pred_dict)
|
|
884
|
-
return model.predict(data, known_covariates=known_covariates)
|
|
885
|
-
|
|
886
|
-
def _get_inputs_to_model(
|
|
887
|
-
self,
|
|
888
|
-
model: str,
|
|
889
|
-
data: TimeSeriesDataFrame,
|
|
890
|
-
model_pred_dict: Dict[str, TimeSeriesDataFrame],
|
|
891
|
-
) -> Union[TimeSeriesDataFrame, Dict[str, TimeSeriesDataFrame]]:
|
|
892
|
-
"""Get the first argument that should be passed to model.predict.
|
|
893
|
-
|
|
894
|
-
This method assumes that model_pred_dict contains the predictions of all base models, if model is an ensemble.
|
|
895
|
-
"""
|
|
896
|
-
model_set = self.get_minimum_model_set(model, include_self=False)
|
|
897
|
-
if model_set:
|
|
898
|
-
for m in model_set:
|
|
899
|
-
if m not in model_pred_dict:
|
|
900
|
-
raise AssertionError(f"Prediction for base model {m} not found in model_pred_dict")
|
|
901
|
-
return {m: model_pred_dict[m] for m in model_set}
|
|
902
|
-
else:
|
|
903
|
-
return data
|
|
904
|
-
|
|
905
|
-
def get_model_pred_dict(
    self,
    model_names: List[str],
    data: TimeSeriesDataFrame,
    known_covariates: Optional[TimeSeriesDataFrame] = None,
    record_pred_time: bool = False,
    raise_exception_if_failed: bool = True,
    use_cache: bool = True,
    random_seed: Optional[int] = None,
) -> Union[Dict[str, TimeSeriesDataFrame], Tuple[Dict[str, TimeSeriesDataFrame], Dict[str, float]]]:
    """Return a dictionary with predictions of all models for the given dataset.

    Parameters
    ----------
    model_names
        Names of the model for which the predictions should be produced.
    data
        Time series data to forecast with.
    known_covariates
        Future values of the known covariates.
    record_pred_time
        If True, will additionally return the total prediction times for all models (including the prediction time
        for base models). If False, will only return the model predictions.
    raise_exception_if_failed
        If True, the method will raise an exception if any model crashes during prediction.
        If False, error will be logged and predictions for failed models will contain None.
    use_cache
        If False, will ignore the cache even if it's available.
    """
    # TODO: Unify design of the method with Tabular
    # Seed the prediction run with any cached results for this exact dataset.
    if self.cache_predictions and use_cache:
        dataset_hash = self._compute_dataset_hash(data=data, known_covariates=known_covariates)
        model_pred_dict, pred_time_dict_marginal = self._get_cached_pred_dicts(dataset_hash)
    else:
        model_pred_dict = {}
        pred_time_dict_marginal = {}

    # Expand the requested models to include all of their base-model dependencies.
    model_set = set()
    for model_name in model_names:
        model_set.update(self.get_minimum_model_set(model_name))
    if len(model_set) > 1:
        # Predict base models before ensembles (lower level first).
        model_to_level = self._get_model_levels()
        model_set = sorted(model_set, key=model_to_level.get)
        logger.debug(f"Prediction order: {model_set}")

    failed_models = []
    for model_name in model_set:
        # Skip models whose predictions were already found in the cache.
        if model_name not in model_pred_dict:
            if random_seed is not None:
                seed_everything(random_seed)
            try:
                predict_start_time = time.time()
                model_pred_dict[model_name] = self._predict_model(
                    model=model_name,
                    data=data,
                    known_covariates=known_covariates,
                    model_pred_dict=model_pred_dict,
                )
                pred_time_dict_marginal[model_name] = time.time() - predict_start_time
            except Exception:
                # Record the failure as None so dependent ensembles can detect it.
                failed_models.append(model_name)
                logger.error(f"Model {model_name} failed to predict with the following exception:")
                logger.error(traceback.format_exc())
                model_pred_dict[model_name] = None
                pred_time_dict_marginal[model_name] = None

    if len(failed_models) > 0 and raise_exception_if_failed:
        raise RuntimeError(f"Following models failed to predict: {failed_models}")
    if self.cache_predictions and use_cache:
        self._save_cached_pred_dicts(
            dataset_hash, model_pred_dict=model_pred_dict, pred_time_dict=pred_time_dict_marginal
        )
    # Totals include the time spent predicting each model's base models.
    pred_time_dict_total = self._get_total_pred_time_from_marginal(pred_time_dict_marginal)

    # Only return entries for the models that were actually requested.
    final_model_pred_dict = {model_name: model_pred_dict[model_name] for model_name in model_names}
    final_pred_time_dict_total = {model_name: pred_time_dict_total[model_name] for model_name in model_names}
    if record_pred_time:
        return final_model_pred_dict, final_pred_time_dict_total
    else:
        return final_model_pred_dict
|
|
985
|
-
|
|
986
|
-
def _get_total_pred_time_from_marginal(self, pred_time_dict_marginal: Dict[str, float]) -> Dict[str, float]:
|
|
987
|
-
pred_time_dict_total = defaultdict(float)
|
|
988
|
-
for model_name in pred_time_dict_marginal.keys():
|
|
989
|
-
for base_model in self.get_minimum_model_set(model_name):
|
|
990
|
-
if pred_time_dict_marginal[base_model] is not None:
|
|
991
|
-
pred_time_dict_total[model_name] += pred_time_dict_marginal[base_model]
|
|
992
|
-
return dict(pred_time_dict_total)
|
|
993
|
-
|
|
994
|
-
@property
|
|
995
|
-
def _cached_predictions_path(self) -> Path:
|
|
996
|
-
return Path(self.path) / self._cached_predictions_filename
|
|
997
|
-
|
|
998
|
-
@staticmethod
def _compute_dataset_hash(
    data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame] = None
) -> str:
    """Compute a unique string that identifies the time series dataset."""
    # Concatenate the hashes of the data, the known covariates and the static features.
    hash_parts = [
        hash_pandas_df(data),
        hash_pandas_df(known_covariates),
        hash_pandas_df(data.static_features),
    ]
    return "".join(hash_parts)
|
|
1005
|
-
|
|
1006
|
-
def _get_cached_pred_dicts(self, dataset_hash: str) -> Tuple[Dict[str, TimeSeriesDataFrame], Dict[str, float]]:
    """Load cached predictions for given dataset_hash from disk, if possible. Otherwise returns empty dicts."""
    if self._cached_predictions_path.exists():
        cached = load_pkl.load(str(self._cached_predictions_path))
        if dataset_hash in cached:
            entry = cached[dataset_hash]
            preds, times = entry["model_pred_dict"], entry["pred_time_dict"]
            # A valid cache entry stores predictions and timings for the same set of models.
            if preds.keys() == times.keys():
                logger.debug(f"Loaded cached predictions for models {list(preds.keys())}")
                return preds, times
            logger.warning(f"Found corrupted cached predictions in {self._cached_predictions_path}")
    logger.debug("Found no cached predictions")
    return {}, {}
|
|
1020
|
-
|
|
1021
|
-
def _save_cached_pred_dicts(
    self, dataset_hash: str, model_pred_dict: Dict[str, TimeSeriesDataFrame], pred_time_dict: Dict[str, float]
) -> None:
    """Store predictions and marginal prediction times for a dataset hash in the cache file."""
    # TODO: Save separate file for each dataset if _cached_predictions file grows large?
    if self._cached_predictions_path.exists():
        logger.debug("Extending existing cached predictions")
        cached_predictions = load_pkl.load(str(self._cached_predictions_path))
    else:
        cached_predictions = {}
    # Do not save results for models that failed
    successful_preds = {name: pred for name, pred in model_pred_dict.items() if pred is not None}
    successful_times = {name: t for name, t in pred_time_dict.items() if t is not None}
    cached_predictions[dataset_hash] = {
        "model_pred_dict": successful_preds,
        "pred_time_dict": successful_times,
    }
    save_pkl.save(str(self._cached_predictions_path), object=cached_predictions)
    logger.debug(f"Cached predictions saved to {self._cached_predictions_path}")
|
|
1037
|
-
|
|
1038
|
-
def _merge_refit_full_data(
|
|
1039
|
-
self, train_data: TimeSeriesDataFrame, val_data: Optional[TimeSeriesDataFrame]
|
|
1040
|
-
) -> TimeSeriesDataFrame:
|
|
1041
|
-
if val_data is None:
|
|
1042
|
-
return train_data
|
|
1043
|
-
else:
|
|
1044
|
-
# TODO: Implement merging of arbitrary tuning_data with train_data
|
|
1045
|
-
raise NotImplementedError("refit_full is not supported if custom val_data is provided.")
|
|
1046
|
-
|
|
1047
|
-
def refit_single_full(
    self,
    train_data: Optional[TimeSeriesDataFrame] = None,
    val_data: Optional[TimeSeriesDataFrame] = None,
    models: Optional[List[str]] = None,
) -> List[str]:
    """Refit the given models (default: all trained models) on train + validation data.

    Returns the names of the newly created ``_FULL`` models and records the mapping
    from original model name to refit model name in ``self.model_refit_map``.
    """
    # BUG FIX: the previous `train_data = train_data or self.load_train_data()` evaluated
    # the truth value of a (TimeSeries)DataFrame, which raises
    # ValueError("The truth value of a DataFrame is ambiguous") whenever a DataFrame is
    # actually passed. Compare against None explicitly instead.
    if train_data is None:
        train_data = self.load_train_data()
    if val_data is None:
        val_data = self.load_val_data()
    refit_full_data = self._merge_refit_full_data(train_data, val_data)

    if models is None:
        models = self.get_model_names()

    # Refit base models before ensembles so that remapped base models already exist.
    model_to_level = self._get_model_levels()
    models_sorted_by_level = sorted(models, key=model_to_level.get)

    model_refit_map = {}
    models_trained_full = []
    for model in models_sorted_by_level:
        model = self.load_model(model)
        model_name = model.name
        if model._get_tags()["can_refit_full"]:
            model_full = model.convert_to_refit_full_template()
            logger.info(f"Fitting model: {model_full.name}")
            models_trained = self._train_and_save(
                train_data=refit_full_data,
                val_data=None,
                model=model_full,
            )
        else:
            # Models that cannot be refit are cloned from the parent instead of retrained.
            model_full = model.convert_to_refit_full_via_copy()
            logger.info(f"Fitting model: {model_full.name} | Skipping fit via cloning parent ...")
            models_trained = [model_full.name]
            if isinstance(model_full, AbstractTimeSeriesEnsembleModel):
                # Point the cloned ensemble at the refit versions of its base models.
                model_full.remap_base_models(model_refit_map)
                self._add_model(model_full, base_models=model_full.model_names)
            else:
                self._add_model(model_full)
            self.save_model(model_full)

        # Only record the mapping when refit produced exactly one model.
        if len(models_trained) == 1:
            model_refit_map[model_name] = models_trained[0]
        models_trained_full += models_trained

    self.model_refit_map.update(model_refit_map)
    self.save()
    return models_trained_full
|
|
1094
|
-
|
|
1095
|
-
def refit_full(self, model: str = "all") -> Dict[str, str]:
    """Refit models on combined train + validation data.

    ``model`` may be "all", "best" (the best model and its dependencies), or a
    specific model name. Returns a copy of the name mapping from original models
    to their refit ``_FULL`` counterparts.
    """
    time_start = time.time()
    existing_models = self.get_model_names()
    if model == "all":
        model_names = existing_models
    elif model == "best":
        model_names = self.get_minimum_model_set(self.get_model_best())
    else:
        model_names = self.get_minimum_model_set(model)

    valid_model_set = []
    for name in model_names:
        already_refit = name in self.model_refit_map and self.model_refit_map[name] in existing_models
        if already_refit:
            logger.info(
                f"Model '{name}' already has a refit _FULL model: "
                f"'{self.model_refit_map[name]}', skipping refit..."
            )
        elif name in self.model_refit_map.values():
            # The model itself is a refit result; refitting it again makes no sense.
            logger.debug(f"Model '{name}' is a refit _FULL model, skipping refit...")
        else:
            valid_model_set.append(name)

    models_trained_full = self.refit_single_full(models=valid_model_set) if valid_model_set else []

    self.save()
    logger.info(f"Refit complete. Models trained: {models_trained_full}")
    logger.info(f"Total runtime: {time.time() - time_start:.2f} s")
    return copy.deepcopy(self.model_refit_map)
|
|
1126
|
-
|
|
1127
|
-
def construct_model_templates(
    self, hyperparameters: Union[str, Dict[str, Any]], multi_window: bool = False, **kwargs
) -> List[AbstractTimeSeriesModel]:
    """Construct a list of unfit models based on the hyperparameters dict.

    Abstract: concrete trainer subclasses must override this method.
    """
    raise NotImplementedError
|
|
1132
|
-
|
|
1133
|
-
def fit(
    self,
    train_data: TimeSeriesDataFrame,
    hyperparameters: Dict[str, Any],
    val_data: Optional[TimeSeriesDataFrame] = None,
    **kwargs,
) -> None:
    """Fit the trainer on the given data.

    Abstract: concrete trainer subclasses must override this method.
    """
    raise NotImplementedError
|
|
1141
|
-
|
|
1142
|
-
# TODO: def _filter_base_models_via_infer_limit
|
|
1143
|
-
|
|
1144
|
-
# TODO: persist and unpersist models
|