autogluon.timeseries 1.3.2b20250712__py3-none-any.whl → 1.4.1b20251116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +84 -0
- autogluon/timeseries/dataset/ts_dataframe.py +98 -72
- autogluon/timeseries/learner.py +19 -18
- autogluon/timeseries/metrics/__init__.py +5 -5
- autogluon/timeseries/metrics/abstract.py +17 -17
- autogluon/timeseries/metrics/point.py +1 -1
- autogluon/timeseries/metrics/quantile.py +2 -2
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +4 -0
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -75
- autogluon/timeseries/models/abstract/tunable.py +6 -6
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +72 -76
- autogluon/timeseries/models/autogluon_tabular/per_step.py +104 -46
- autogluon/timeseries/models/autogluon_tabular/transforms.py +9 -7
- autogluon/timeseries/models/chronos/model.py +115 -78
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +76 -44
- autogluon/timeseries/models/ensemble/__init__.py +29 -2
- autogluon/timeseries/models/ensemble/abstract.py +16 -52
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +247 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +50 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +10 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +87 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +133 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +141 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +41 -0
- autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +8 -18
- autogluon/timeseries/models/ensemble/{greedy.py → weighted/greedy.py} +13 -13
- autogluon/timeseries/models/gluonts/abstract.py +26 -26
- autogluon/timeseries/models/gluonts/dataset.py +4 -4
- autogluon/timeseries/models/gluonts/models.py +27 -12
- autogluon/timeseries/models/local/abstract_local_model.py +14 -14
- autogluon/timeseries/models/local/naive.py +4 -0
- autogluon/timeseries/models/local/npts.py +1 -0
- autogluon/timeseries/models/local/statsforecast.py +30 -14
- autogluon/timeseries/models/multi_window/multi_window_model.py +34 -23
- autogluon/timeseries/models/registry.py +65 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +94 -107
- autogluon/timeseries/regressor.py +31 -27
- autogluon/timeseries/splitter.py +7 -31
- autogluon/timeseries/trainer/__init__.py +3 -0
- autogluon/timeseries/trainer/ensemble_composer.py +250 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/{trainer.py → trainer/trainer.py} +182 -307
- autogluon/timeseries/trainer/utils.py +18 -0
- autogluon/timeseries/transforms/covariate_scaler.py +4 -4
- autogluon/timeseries/transforms/target_scaler.py +14 -14
- autogluon/timeseries/utils/datetime/lags.py +2 -2
- autogluon/timeseries/utils/datetime/time_features.py +2 -2
- autogluon/timeseries/utils/features.py +41 -37
- autogluon/timeseries/utils/forecast.py +5 -5
- autogluon/timeseries/utils/warning_filters.py +3 -1
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.4.1b20251116-py3.9-nspkg.pth +1 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/METADATA +32 -17
- autogluon_timeseries-1.4.1b20251116.dist-info/RECORD +96 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -79
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -530
- autogluon/timeseries/models/presets.py +0 -358
- autogluon.timeseries-1.3.2b20250712-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.3.2b20250712.dist-info/RECORD +0 -71
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
import re
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from typing import Any, Optional, Type, Union
|
|
6
|
+
|
|
7
|
+
from autogluon.common import space
|
|
8
|
+
from autogluon.core import constants
|
|
9
|
+
from autogluon.timeseries.configs import get_hyperparameter_presets
|
|
10
|
+
from autogluon.timeseries.metrics import TimeSeriesScorer
|
|
11
|
+
from autogluon.timeseries.models import ModelRegistry
|
|
12
|
+
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
|
|
13
|
+
from autogluon.timeseries.models.multi_window import MultiWindowBacktestingModel
|
|
14
|
+
from autogluon.timeseries.utils.features import CovariateMetadata
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Type aliases for the trainer's hyperparameter specification.
# A model may be referenced either by its registry name (str) or by its class.
ModelKey = Union[str, Type[AbstractTimeSeriesModel]]
# A single hyperparameter configuration for one model.
ModelHyperparameters = dict[str, Any]
# Canonical ("clean") form: each model key maps to a list of configurations,
# one model object being created per configuration.
TrainerHyperparameterSpec = dict[ModelKey, list[ModelHyperparameters]]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TrainableModelSetBuilder:
    """Responsible for building a list of model objects, in priority order, that will be trained by the
    Trainer.

    The builder resolves the user-provided ``hyperparameters`` specification (preset name, dict, or
    None) via ``HyperparameterBuilder``, instantiates one model object per configuration, resolves
    name collisions by appending ``_2``, ``_3``, ... suffixes, and optionally wraps each model in a
    ``MultiWindowBacktestingModel``.
    """

    # Only these keys are accepted inside a model's `ag_args` configuration block.
    VALID_AG_ARGS_KEYS = {
        "name",
        "name_prefix",
        "name_suffix",
    }

    def __init__(
        self,
        path: str,
        freq: Optional[str],
        prediction_length: int,
        eval_metric: TimeSeriesScorer,
        target: str,
        quantile_levels: list[float],
        covariate_metadata: CovariateMetadata,
        multi_window: bool,
    ):
        """Store the shared model-construction arguments.

        Parameters
        ----------
        path
            Root directory passed to every constructed model.
        freq
            Pandas frequency string of the training data (may be None).
        prediction_length
            Forecast horizon passed to every model.
        eval_metric
            Scorer instance shared by all models.
        target
            Name of the target column.
        quantile_levels
            Quantiles each model must predict.
        covariate_metadata
            Metadata describing static/known/past covariates.
        multi_window
            If True, each model is wrapped in a ``MultiWindowBacktestingModel``.
        """
        self.path = path
        self.freq = freq
        self.prediction_length = prediction_length
        self.eval_metric = eval_metric
        self.target = target
        self.quantile_levels = quantile_levels
        self.covariate_metadata = covariate_metadata
        self.multi_window = multi_window

    def get_model_set(
        self,
        hyperparameters: Union[str, dict, None],
        hyperparameter_tune: bool,
        excluded_model_types: Optional[list[str]],
        banned_model_names: Optional[list[str]] = None,
    ) -> list[AbstractTimeSeriesModel]:
        """Resolve hyperparameters and create the requested list of models"""
        models = []
        # Copy so that names we add below do not leak back into the caller's list.
        banned_model_names = [] if banned_model_names is None else banned_model_names.copy()

        # resolve and normalize hyperparameters
        model_hp_map: TrainerHyperparameterSpec = HyperparameterBuilder(
            hyperparameters=hyperparameters,
            hyperparameter_tune=hyperparameter_tune,
            excluded_model_types=excluded_model_types,
        ).get_hyperparameters()

        # Custom model classes supplied as dict keys must subclass the abstract base.
        for k in model_hp_map.keys():
            if isinstance(k, type) and not issubclass(k, AbstractTimeSeriesModel):
                raise ValueError(f"Custom model type {k} must inherit from `AbstractTimeSeriesModel`.")

        # Higher registry priority trains first.
        model_priority_list = sorted(
            model_hp_map.keys(), key=lambda x: ModelRegistry.get_model_priority(x), reverse=True
        )

        for model_key in model_priority_list:
            model_type = self._get_model_type(model_key)

            for model_hps in model_hp_map[model_key]:
                # NOTE: pop() mutates this config dict — `ag_args` is consumed here and
                # not forwarded to the model constructor.
                ag_args = model_hps.pop(constants.AG_ARGS, {})

                for key in ag_args:
                    if key not in self.VALID_AG_ARGS_KEYS:
                        raise ValueError(
                            f"Model {model_type} received unknown ag_args key: {key} (valid keys {self.VALID_AG_ARGS_KEYS})"
                        )
                model_name_base = self._get_model_name(ag_args, model_type)

                model_type_kwargs: dict[str, Any] = dict(
                    name=model_name_base,
                    hyperparameters=model_hps,
                    **self._get_default_model_init_kwargs(),
                )

                # add models while preventing name collisions
                model = model_type(**model_type_kwargs)
                # Drop "name" so the retry loop below can pass an explicit suffixed name.
                model_type_kwargs.pop("name", None)

                # First collision gets suffix "_2", then "_3", etc.
                increment = 1
                while model.name in banned_model_names:
                    increment += 1
                    model = model_type(name=f"{model_name_base}_{increment}", **model_type_kwargs)

                if self.multi_window:
                    # Wrapper reuses the resolved (collision-free) name of the inner model.
                    model = MultiWindowBacktestingModel(model_base=model, name=model.name, **model_type_kwargs)  # type: ignore

                banned_model_names.append(model.name)
                models.append(model)

        return models

    def _get_model_type(self, model: ModelKey) -> Type[AbstractTimeSeriesModel]:
        """Map a registry name or a class to the concrete model class."""
        if isinstance(model, str):
            model_type: Type[AbstractTimeSeriesModel] = ModelRegistry.get_model_class(model)
        elif isinstance(model, type):
            model_type = model
        else:
            raise ValueError(
                f"Keys of the `hyperparameters` dictionary must be strings or types, received {type(model)}."
            )

        return model_type

    def _get_default_model_init_kwargs(self) -> dict[str, Any]:
        """Keyword arguments shared by every model constructed by this builder."""
        return dict(
            path=self.path,
            freq=self.freq,
            prediction_length=self.prediction_length,
            eval_metric=self.eval_metric,
            target=self.target,
            quantile_levels=self.quantile_levels,
            covariate_metadata=self.covariate_metadata,
        )

    def _get_model_name(self, ag_args: dict[str, Any], model_type: Type[AbstractTimeSeriesModel]) -> str:
        """Resolve the base model name: explicit `ag_args["name"]` wins, otherwise
        prefix + class name (without the "Model" suffix) + suffix."""
        name = ag_args.get("name")
        if name is None:
            name_stem = re.sub(r"Model$", "", model_type.__name__)
            name_prefix = ag_args.get("name_prefix", "")
            name_suffix = ag_args.get("name_suffix", "")
            name = name_prefix + name_stem + name_suffix
        return name
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class HyperparameterBuilder:
    """Given user hyperparameter specifications, this class resolves them against presets, removes
    excluded model types and canonicalizes the hyperparameter specification.

    Parameters
    ----------
    hyperparameters
        Preset name (str), explicit specification (dict), or None for the default preset.
    hyperparameter_tune
        Whether HPO was requested; controls search-space validation.
    excluded_model_types
        Model names to drop from the resolved specification.
    """

    def __init__(
        self,
        hyperparameters: Union[str, dict, None],
        hyperparameter_tune: bool,
        excluded_model_types: Optional[list[str]],
    ):
        self.hyperparameters = hyperparameters
        self.hyperparameter_tune = hyperparameter_tune
        self.excluded_model_types = excluded_model_types

    def get_hyperparameters(self) -> TrainerHyperparameterSpec:
        """Resolve presets / user input into the canonical hyperparameter specification.

        Raises
        ------
        ValueError
            If an unknown preset name or an unsupported type is given.
        """
        hyperparameter_dict = {}
        hp_presets = get_hyperparameter_presets()

        # Deep-copy in every branch: downstream consumers mutate the per-model config
        # dicts (e.g. popping `ag_args`), which must never corrupt shared preset
        # definitions in case `get_hyperparameter_presets` returns cached structures.
        if self.hyperparameters is None:
            hyperparameter_dict = copy.deepcopy(hp_presets["default"])
        elif isinstance(self.hyperparameters, str):
            try:
                hyperparameter_dict = copy.deepcopy(hp_presets[self.hyperparameters])
            except KeyError:
                raise ValueError(f"{self.hyperparameters} is not a valid preset.")
        elif isinstance(self.hyperparameters, dict):
            hyperparameter_dict = copy.deepcopy(self.hyperparameters)
        else:
            raise ValueError(
                f"hyperparameters must be a dict, a string or None (received {type(self.hyperparameters)}). "
                f"Please see the documentation for TimeSeriesPredictor.fit"
            )

        return self._check_and_clean_hyperparameters(hyperparameter_dict)  # type: ignore

    def _check_and_clean_hyperparameters(
        self,
        hyperparameters: dict[ModelKey, Union[ModelHyperparameters, list[ModelHyperparameters]]],
    ) -> TrainerHyperparameterSpec:
        """Convert the hyperparameters dictionary to a unified format:
        - Remove 'Model' suffix from model names, if present
        - Make sure that each value in the hyperparameters dict is a list with model configurations
        - Checks if hyperparameters contain searchspaces
        """
        excluded_models = self._get_excluded_models()
        hyperparameters_clean = defaultdict(list)
        for model_name, model_hyperparameters in hyperparameters.items():
            # Handle model names ending with "Model", e.g., "DeepARModel" is mapped to "DeepAR"
            if isinstance(model_name, str):
                model_name = self._normalize_model_type_name(model_name)
                if model_name in excluded_models:
                    logger.info(
                        f"\tFound '{model_name}' model in `hyperparameters`, but '{model_name}' "
                        "is present in `excluded_model_types` and will be removed."
                    )
                    continue
            if not isisinstance_list(model_hyperparameters):
                model_hyperparameters = [model_hyperparameters]
            hyperparameters_clean[model_name].extend(model_hyperparameters)

        self._verify_searchspaces(hyperparameters_clean)

        return dict(hyperparameters_clean)

    def _get_excluded_models(self) -> set[str]:
        """Validate `excluded_model_types` and return the normalized set of excluded names."""
        excluded_models = set()
        if self.excluded_model_types is not None and len(self.excluded_model_types) > 0:
            if not isinstance(self.excluded_model_types, list):
                raise ValueError(f"`excluded_model_types` must be a list, received {type(self.excluded_model_types)}")
            logger.info(f"Excluded model types: {self.excluded_model_types}")
            for model in self.excluded_model_types:
                if not isinstance(model, str):
                    raise ValueError(f"Each entry in `excluded_model_types` must be a string, received {type(model)}")
                excluded_models.add(self._normalize_model_type_name(model))
        return excluded_models

    @staticmethod
    def _normalize_model_type_name(model_name: str) -> str:
        """Strip a trailing "Model" so "DeepARModel" and "DeepAR" refer to the same model."""
        return model_name.removesuffix("Model")

    def _verify_searchspaces(self, hyperparameters: dict[str, list[ModelHyperparameters]]):
        """Ensure search spaces are present iff hyperparameter tuning was requested.

        Raises
        ------
        ValueError
            If tuning is enabled but no config contains a search space, or tuning is
            disabled but some config does.
        """
        if self.hyperparameter_tune:
            for model, model_hps_list in hyperparameters.items():
                for model_hps in model_hps_list:
                    if contains_searchspace(model_hps):
                        return

            raise ValueError(
                "Hyperparameter tuning specified, but no model contains a hyperparameter search space. "
                "Please disable hyperparameter tuning with `hyperparameter_tune_kwargs=None` or provide a search space "
                "for at least one model."
            )
        else:
            for model, model_hps_list in hyperparameters.items():
                for model_hps in model_hps_list:
                    if contains_searchspace(model_hps):
                        raise ValueError(
                            f"Hyperparameter tuning not specified, so hyperparameters must have fixed values. "
                            f"However, for model {model} hyperparameters {model_hps} contain a search space."
                        )
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def contains_searchspace(model_hyperparameters: ModelHyperparameters) -> bool:
    """Return True if any hyperparameter value in the config is a search space object."""
    return any(isinstance(candidate, space.Space) for candidate in model_hyperparameters.values())
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
|
|
6
|
+
from autogluon.common.utils.utils import hash_pandas_df
|
|
7
|
+
from autogluon.core.utils.loaders import load_pkl
|
|
8
|
+
from autogluon.core.utils.savers import save_pkl
|
|
9
|
+
from autogluon.timeseries import TimeSeriesDataFrame
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class PredictionCache(ABC):
    """A prediction cache is an abstract key-value store for time series predictions. The storage is keyed by
    (data, known_covariates) pairs and stores (model_pred_dict, pred_time_dict) pair values. In this stored pair,
    (model_pred_dict, pred_time_dict), both dictionaries are keyed by model names.
    """

    def __init__(self, root_path: str):
        # Directory under which concrete implementations may persist their state.
        self.root_path = Path(root_path)

    @abstractmethod
    def get(
        self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame]
    ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
        """Return the cached (model_pred_dict, pred_time_dict) for this dataset pair.

        Implementations return empty dicts on a cache miss.
        """
        pass

    @abstractmethod
    def put(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame],
        model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
        pred_time_dict: dict[str, float],
    ) -> None:
        """Store predictions and prediction times for this dataset pair."""
        pass

    @abstractmethod
    def clear(self) -> None:
        """Remove all cached entries."""
        pass
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_prediction_cache(use_cache: bool, root_path: str) -> PredictionCache:
    """Factory: a file-backed cache when caching is enabled, otherwise a no-op cache."""
    cache_cls = FileBasedPredictionCache if use_cache else NoOpPredictionCache
    return cache_cls(root_path=root_path)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def compute_dataset_hash(data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame] = None) -> str:
    """Compute a unique string that identifies the time series dataset."""
    # The identity covers the main frame, the known covariates and the static features.
    components = [data, known_covariates, data.static_features]
    return "".join(hash_pandas_df(component) for component in components)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class NoOpPredictionCache(PredictionCache):
    """A dummy (no-op) prediction cache."""

    def get(
        self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame]
    ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
        # Every lookup is a cache miss.
        return {}, {}

    def put(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame],
        model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
        pred_time_dict: dict[str, float],
    ) -> None:
        # Intentionally discard everything.
        return None

    def clear(self) -> None:
        # Nothing is stored, so there is nothing to clear.
        return None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class FileBasedPredictionCache(PredictionCache):
    """A file-backed cache of model predictions.

    All entries are stored in a single pickle file under ``root_path``, keyed by a hash
    of the dataset (see ``compute_dataset_hash``).
    """

    # Name of the pickle file holding all cached entries.
    _cached_predictions_filename = "cached_predictions.pkl"

    @property
    def path(self) -> Path:
        """Full path of the pickle file backing this cache."""
        return Path(self.root_path) / self._cached_predictions_filename

    def get(
        self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame]
    ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
        """Return cached (model_pred_dict, pred_time_dict) for this dataset, or empty dicts."""
        dataset_hash = compute_dataset_hash(data, known_covariates)
        return self._get_cached_pred_dicts(dataset_hash)

    def put(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame],
        model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
        pred_time_dict: dict[str, float],
    ) -> None:
        """Persist predictions and prediction times for this dataset to disk."""
        dataset_hash = compute_dataset_hash(data, known_covariates)
        self._save_cached_pred_dicts(dataset_hash, model_pred_dict, pred_time_dict)

    def clear(self) -> None:
        """Delete the cache file, if it exists."""
        if self.path.exists():
            logger.debug(f"Removing existing cached predictions file {self.path}")
            self.path.unlink()

    def _load_cached_predictions(self) -> dict[str, dict[str, dict[str, Any]]]:
        """Load the full cache from disk; fall back to an empty cache on any failure."""
        if self.path.exists():
            try:
                cached_predictions = load_pkl.load(str(self.path))
            except Exception:
                # A corrupted / unreadable file is treated as an empty cache.
                cached_predictions = {}
        else:
            cached_predictions = {}
        return cached_predictions

    def _get_cached_pred_dicts(
        self, dataset_hash: str
    ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
        """Load cached predictions for given dataset_hash from disk, if possible.

        If loading fails for any reason, empty dicts are returned.
        """
        cached_predictions = self._load_cached_predictions()
        if dataset_hash in cached_predictions:
            try:
                model_pred_dict = cached_predictions[dataset_hash]["model_pred_dict"]
                pred_time_dict = cached_predictions[dataset_hash]["pred_time_dict"]
                # Sanity check: both dicts must describe exactly the same models.
                assert model_pred_dict.keys() == pred_time_dict.keys()
                return model_pred_dict, pred_time_dict
            except Exception:
                logger.warning("Cached predictions are corrupted. Predictions will be made from scratch.")
        return {}, {}

    def _save_cached_pred_dicts(
        self,
        dataset_hash: str,
        model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
        pred_time_dict: dict[str, float],
    ) -> None:
        """Store the given prediction and timing dicts under dataset_hash and write to disk."""
        cached_predictions = self._load_cached_predictions()
        # Do not save results for models that failed (prediction is None). The timing dict
        # is filtered by the *same* key set: a failed model still has a valid float timing
        # entry, and keeping it would desynchronize the two dicts and trip the key-equality
        # assertion in _get_cached_pred_dicts on every subsequent read.
        successful_pred_dict = {k: v for k, v in model_pred_dict.items() if v is not None}
        cached_predictions[dataset_hash] = {
            "model_pred_dict": successful_pred_dict,
            "pred_time_dict": {k: v for k, v in pred_time_dict.items() if k in successful_pred_dict},
        }
        save_pkl.save(str(self.path), object=cached_predictions)
        logger.debug(f"Cached predictions saved to {self.path}")
|