autogluon.timeseries 1.4.1b20250820__py3-none-any.whl → 1.4.1b20250901__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
Files changed (52)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +9 -9
  5. autogluon/timeseries/learner.py +14 -14
  6. autogluon/timeseries/metrics/__init__.py +5 -5
  7. autogluon/timeseries/metrics/abstract.py +11 -12
  8. autogluon/timeseries/models/__init__.py +2 -0
  9. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +39 -41
  10. autogluon/timeseries/models/abstract/tunable.py +6 -6
  11. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +30 -30
  12. autogluon/timeseries/models/autogluon_tabular/per_step.py +12 -12
  13. autogluon/timeseries/models/chronos/model.py +10 -10
  14. autogluon/timeseries/models/chronos/pipeline/base.py +8 -8
  15. autogluon/timeseries/models/chronos/pipeline/chronos.py +12 -12
  16. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +12 -12
  17. autogluon/timeseries/models/chronos/pipeline/utils.py +12 -12
  18. autogluon/timeseries/models/ensemble/abstract.py +19 -19
  19. autogluon/timeseries/models/ensemble/basic.py +8 -8
  20. autogluon/timeseries/models/ensemble/greedy.py +13 -13
  21. autogluon/timeseries/models/gluonts/abstract.py +24 -24
  22. autogluon/timeseries/models/gluonts/dataset.py +2 -2
  23. autogluon/timeseries/models/gluonts/models.py +7 -7
  24. autogluon/timeseries/models/local/abstract_local_model.py +12 -12
  25. autogluon/timeseries/models/local/statsforecast.py +11 -11
  26. autogluon/timeseries/models/multi_window/multi_window_model.py +33 -22
  27. autogluon/timeseries/models/registry.py +3 -3
  28. autogluon/timeseries/predictor.py +37 -37
  29. autogluon/timeseries/regressor.py +13 -13
  30. autogluon/timeseries/splitter.py +6 -6
  31. autogluon/timeseries/trainer/__init__.py +3 -0
  32. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  33. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  34. autogluon/timeseries/{trainer.py → trainer/trainer.py} +72 -128
  35. autogluon/timeseries/transforms/covariate_scaler.py +3 -3
  36. autogluon/timeseries/transforms/target_scaler.py +7 -7
  37. autogluon/timeseries/utils/datetime/lags.py +2 -2
  38. autogluon/timeseries/utils/datetime/time_features.py +2 -2
  39. autogluon/timeseries/utils/features.py +32 -32
  40. autogluon/timeseries/version.py +1 -1
  41. {autogluon.timeseries-1.4.1b20250820.dist-info → autogluon.timeseries-1.4.1b20250901.dist-info}/METADATA +5 -5
  42. autogluon.timeseries-1.4.1b20250901.dist-info/RECORD +75 -0
  43. autogluon/timeseries/configs/presets_configs.py +0 -79
  44. autogluon/timeseries/models/presets.py +0 -280
  45. autogluon.timeseries-1.4.1b20250820.dist-info/RECORD +0 -72
  46. /autogluon.timeseries-1.4.1b20250820-py3.9-nspkg.pth → /autogluon.timeseries-1.4.1b20250901-py3.9-nspkg.pth +0 -0
  47. {autogluon.timeseries-1.4.1b20250820.dist-info → autogluon.timeseries-1.4.1b20250901.dist-info}/LICENSE +0 -0
  48. {autogluon.timeseries-1.4.1b20250820.dist-info → autogluon.timeseries-1.4.1b20250901.dist-info}/NOTICE +0 -0
  49. {autogluon.timeseries-1.4.1b20250820.dist-info → autogluon.timeseries-1.4.1b20250901.dist-info}/WHEEL +0 -0
  50. {autogluon.timeseries-1.4.1b20250820.dist-info → autogluon.timeseries-1.4.1b20250901.dist-info}/namespace_packages.txt +0 -0
  51. {autogluon.timeseries-1.4.1b20250820.dist-info → autogluon.timeseries-1.4.1b20250901.dist-info}/top_level.txt +0 -0
  52. {autogluon.timeseries-1.4.1b20250820.dist-info → autogluon.timeseries-1.4.1b20250901.dist-info}/zip-safe +0 -0
autogluon/timeseries/trainer/prediction_cache.py (new file)
@@ -0,0 +1,149 @@
+ import logging
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from typing import Any, Optional
+
+ from autogluon.common.utils.utils import hash_pandas_df
+ from autogluon.core.utils.loaders import load_pkl
+ from autogluon.core.utils.savers import save_pkl
+ from autogluon.timeseries import TimeSeriesDataFrame
+
+ logger = logging.getLogger(__name__)
+
+
+ class PredictionCache(ABC):
+     """A prediction cache is an abstract key-value store for time series predictions. The storage is keyed by
+     (data, known_covariates) pairs and stores (model_pred_dict, pred_time_dict) pair values. In this stored pair,
+     (model_pred_dict, pred_time_dict), both dictionaries are keyed by model names.
+     """
+
+     def __init__(self, root_path: str):
+         self.root_path = Path(root_path)
+
+     @abstractmethod
+     def get(
+         self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame]
+     ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
+         pass
+
+     @abstractmethod
+     def put(
+         self,
+         data: TimeSeriesDataFrame,
+         known_covariates: Optional[TimeSeriesDataFrame],
+         model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
+         pred_time_dict: dict[str, float],
+     ) -> None:
+         pass
+
+     @abstractmethod
+     def clear(self) -> None:
+         pass
+
+
+ def get_prediction_cache(use_cache: bool, root_path: str) -> PredictionCache:
+     if use_cache:
+         return FileBasedPredictionCache(root_path=root_path)
+     else:
+         return NoOpPredictionCache(root_path=root_path)
+
+
+ def compute_dataset_hash(data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame] = None) -> str:
+     """Compute a unique string that identifies the time series dataset."""
+     combined_hash = hash_pandas_df(data) + hash_pandas_df(known_covariates) + hash_pandas_df(data.static_features)
+     return combined_hash
+
+
+ class NoOpPredictionCache(PredictionCache):
+     """A dummy (no-op) prediction cache."""
+
+     def get(
+         self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame]
+     ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
+         return {}, {}
+
+     def put(
+         self,
+         data: TimeSeriesDataFrame,
+         known_covariates: Optional[TimeSeriesDataFrame],
+         model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
+         pred_time_dict: dict[str, float],
+     ) -> None:
+         pass
+
+     def clear(self) -> None:
+         pass
+
+
+ class FileBasedPredictionCache(PredictionCache):
+     """A file-backed cache of model predictions."""
+
+     _cached_predictions_filename = "cached_predictions.pkl"
+
+     @property
+     def path(self) -> Path:
+         return Path(self.root_path) / self._cached_predictions_filename
+
+     def get(
+         self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame]
+     ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
+         dataset_hash = compute_dataset_hash(data, known_covariates)
+         return self._get_cached_pred_dicts(dataset_hash)
+
+     def put(
+         self,
+         data: TimeSeriesDataFrame,
+         known_covariates: Optional[TimeSeriesDataFrame],
+         model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
+         pred_time_dict: dict[str, float],
+     ) -> None:
+         dataset_hash = compute_dataset_hash(data, known_covariates)
+         self._save_cached_pred_dicts(dataset_hash, model_pred_dict, pred_time_dict)
+
+     def clear(self) -> None:
+         if self.path.exists():
+             logger.debug(f"Removing existing cached predictions file {self.path}")
+             self.path.unlink()
+
+     def _load_cached_predictions(self) -> dict[str, dict[str, dict[str, Any]]]:
+         if self.path.exists():
+             try:
+                 cached_predictions = load_pkl.load(str(self.path))
+             except Exception:
+                 cached_predictions = {}
+         else:
+             cached_predictions = {}
+         return cached_predictions
+
+     def _get_cached_pred_dicts(
+         self, dataset_hash: str
+     ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
+         """Load cached predictions for given dataset_hash from disk, if possible.
+
+         If loading fails for any reason, empty dicts are returned.
+         """
+         cached_predictions = self._load_cached_predictions()
+         if dataset_hash in cached_predictions:
+             try:
+                 model_pred_dict = cached_predictions[dataset_hash]["model_pred_dict"]
+                 pred_time_dict = cached_predictions[dataset_hash]["pred_time_dict"]
+                 assert model_pred_dict.keys() == pred_time_dict.keys()
+                 return model_pred_dict, pred_time_dict
+             except Exception:
+                 logger.warning("Cached predictions are corrupted. Predictions will be made from scratch.")
+         return {}, {}
+
+     def _save_cached_pred_dicts(
+         self,
+         dataset_hash: str,
+         model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
+         pred_time_dict: dict[str, float],
+     ) -> None:
+         cached_predictions = self._load_cached_predictions()
+         # Do not save results for models that failed
+         cached_predictions[dataset_hash] = {
+             "model_pred_dict": {k: v for k, v in model_pred_dict.items() if v is not None},
+             "pred_time_dict": {k: v for k, v in pred_time_dict.items() if v is not None},
+         }
+         save_pkl.save(str(self.path), object=cached_predictions)
+         logger.debug(f"Cached predictions saved to {self.path}")
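To illustrate the new API, here is a minimal round-trip sketch. The toy dataset, cache directory, and model name are illustrative and not part of the diff; the calls mirror the signatures defined in the new module above.

    import pandas as pd

    from autogluon.timeseries import TimeSeriesDataFrame
    from autogluon.timeseries.trainer.prediction_cache import get_prediction_cache

    # Toy dataset: one item with four hourly observations.
    data = TimeSeriesDataFrame.from_data_frame(
        pd.DataFrame({
            "item_id": ["A"] * 4,
            "timestamp": pd.date_range("2025-01-01", periods=4, freq="h"),
            "target": [1.0, 2.0, 3.0, 4.0],
        })
    )

    cache = get_prediction_cache(use_cache=True, root_path="cache_demo")  # hypothetical directory
    preds, times = cache.get(data=data, known_covariates=None)
    assert (preds, times) == ({}, {})  # cold cache: both dicts are empty

    # Store predictions and marginal prediction times, keyed by model name.
    cache.put(data=data, known_covariates=None, model_pred_dict={"Naive": data}, pred_time_dict={"Naive": 0.01})
    preds, times = cache.get(data=data, known_covariates=None)
    assert times == {"Naive": 0.01}  # warm cache: same (data, known_covariates) key hits

    cache.clear()  # deletes cached_predictions.pkl under root_path

Note that cache identity is content-based: two datasets with equal values produce the same hash_pandas_df key, so the cache still hits after reloading identical data from disk.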
autogluon/timeseries/{trainer.py → trainer/trainer.py}
@@ -5,14 +5,14 @@ import time
  import traceback
  from collections import defaultdict
  from pathlib import Path
- from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
+ from typing import Any, Literal, Optional, Type, Union

  import networkx as nx
  import numpy as np
  import pandas as pd
  from tqdm import tqdm

- from autogluon.common.utils.utils import hash_pandas_df, seed_everything
+ from autogluon.common.utils.utils import seed_everything
  from autogluon.core.trainer.abstract_trainer import AbstractTrainer
  from autogluon.core.utils.exceptions import TimeLimitExceeded
  from autogluon.core.utils.loaders import load_pkl
@@ -22,7 +22,6 @@ from autogluon.timeseries.metrics import TimeSeriesScorer, check_get_evaluation_
  from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel, TimeSeriesModelBase
  from autogluon.timeseries.models.ensemble import AbstractTimeSeriesEnsembleModel, GreedyEnsemble
  from autogluon.timeseries.models.multi_window import MultiWindowBacktestingModel
- from autogluon.timeseries.models.presets import contains_searchspace, get_preset_models
  from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
  from autogluon.timeseries.utils.features import (
      ConstantReplacementFeatureImportanceTransform,
@@ -31,12 +30,13 @@ from autogluon.timeseries.utils.features import (
  )
  from autogluon.timeseries.utils.warning_filters import disable_tqdm, warning_filter

+ from .model_set_builder import TrainableModelSetBuilder, contains_searchspace
+ from .prediction_cache import PredictionCache, get_prediction_cache
+
  logger = logging.getLogger("autogluon.timeseries.trainer")


  class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
-     _cached_predictions_filename = "cached_predictions.pkl"
-
      max_rel_importance_score: float = 1e5
      eps_abs_importance_score: float = 1e-5
      max_ensemble_time_limit: float = 600.0
@@ -81,7 +81,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

          self.verbosity = verbosity

-         #: Dict of normal model -> FULL model. FULL models are produced by
+         #: dict of normal model -> FULL model. FULL models are produced by
          #: self.refit_single_full() and self.refit_full().
          self.model_refit_map = {}

@@ -91,12 +91,10 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          assert isinstance(val_splitter, AbstractWindowSplitter), "val_splitter must be of type AbstractWindowSplitter"
          self.val_splitter = val_splitter
          self.refit_every_n_windows = refit_every_n_windows
-         self.cache_predictions = cache_predictions
          self.hpo_results = {}

-         if self._cached_predictions_path.exists():
-             logger.debug(f"Removing existing cached predictions file {self._cached_predictions_path}")
-             self._cached_predictions_path.unlink()
+         self.prediction_cache: PredictionCache = get_prediction_cache(cache_predictions, self.path)
+         self.prediction_cache.clear()

      @property
      def path_pkl(self) -> str:
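The constructor change above replaces flag checks scattered through the trainer with polymorphism: get_prediction_cache returns a FileBasedPredictionCache when caching is enabled and a NoOpPredictionCache otherwise, and the unconditional clear() call reproduces the old behavior of removing a stale cache file at construction time. A short sketch of the dispatch (illustrative):

    from autogluon.timeseries.trainer.prediction_cache import (
        FileBasedPredictionCache,
        NoOpPredictionCache,
        get_prediction_cache,
    )

    # With caching disabled, every get() misses and put()/clear() do nothing,
    # so callers never need to branch on a boolean flag.
    assert isinstance(get_prediction_cache(False, "unused_dir"), NoOpPredictionCache)
    assert isinstance(get_prediction_cache(True, "unused_dir"), FileBasedPredictionCache)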
@@ -121,7 +119,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          else:
              return None

-     def load_data(self) -> Tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
+     def load_data(self) -> tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
          train_data = self.load_train_data()
          val_data = self.load_val_data()
          return train_data, val_data
@@ -136,7 +134,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

          self.models = models

-     def _get_model_oof_predictions(self, model_name: str) -> List[TimeSeriesDataFrame]:
+     def _get_model_oof_predictions(self, model_name: str) -> list[TimeSeriesDataFrame]:
          model_path = os.path.join(self.path, self.get_model_attribute(model=model_name, attribute="path"))
          model_type = self.get_model_attribute(model=model_name, attribute="type")
          return model_type.load_oof_predictions(path=model_path)
@@ -144,16 +142,16 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
      def _add_model(
          self,
          model: TimeSeriesModelBase,
-         base_models: Optional[List[str]] = None,
+         base_models: Optional[list[str]] = None,
      ):
          """Add a model to the model graph of the trainer. If the model is an ensemble, also add
          information about dependencies to the model graph (list of models specified via ``base_models``).

          Parameters
          ----------
-         model : TimeSeriesModelBase
+         model
              The model to be added to the model graph.
-         base_models : List[str], optional, default None
+         base_models
              If the model is an ensemble, the list of base model names that are included in the ensemble.
              Expected only when ``model`` is a ``AbstractTimeSeriesEnsembleModel``.

@@ -176,7 +174,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          for base_model in base_models:
              self.model_graph.add_edge(base_model, model.name)

-     def _get_model_levels(self) -> Dict[str, int]:
+     def _get_model_levels(self) -> dict[str, int]:
          """Get a dictionary mapping each model to their level in the model graph"""

          # get nodes without a parent
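_get_model_levels assigns each model a depth in the model graph: base models (nodes without a parent) sit at level 0, and an ensemble sits one level above its deepest dependency. A toy reconstruction of that computation with networkx (model names are hypothetical and the actual implementation body is not shown in this hunk):

    import networkx as nx

    graph = nx.DiGraph()
    graph.add_nodes_from(["Naive", "ETS", "WeightedEnsemble"])
    # Edge direction matches _add_model: base model -> dependent ensemble.
    graph.add_edges_from([("Naive", "WeightedEnsemble"), ("ETS", "WeightedEnsemble")])

    levels = {node: 0 for node in graph.nodes}
    for node in nx.topological_sort(graph):
        for parent in graph.predecessors(node):
            levels[node] = max(levels[node], levels[parent] + 1)

    print(levels)  # {'Naive': 0, 'ETS': 0, 'WeightedEnsemble': 1}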
@@ -197,7 +195,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

          return levels

-     def get_models_attribute_dict(self, attribute: str, models: Optional[List[str]] = None) -> Dict[str, Any]:
+     def get_models_attribute_dict(self, attribute: str, models: Optional[list[str]] = None) -> dict[str, Any]:
          """Get an attribute from the `model_graph` for each of the model names
          specified. If `models` is none, the attribute will be returned for all models"""
          results = {}
@@ -230,13 +228,13 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
              key=lambda mns: (mns[1], -mns[2]),  # (score, -level)
          )[0]

-     def get_model_names(self, level: Optional[int] = None) -> List[str]:
+     def get_model_names(self, level: Optional[int] = None) -> list[str]:
          """Get model names that are registered in the model graph"""
          if level is not None:
              return list(node for node, l in self._get_model_levels().items() if l == level)  # noqa: E741
          return list(self.model_graph.nodes)

-     def get_info(self, include_model_info: bool = False) -> Dict[str, Any]:
+     def get_info(self, include_model_info: bool = False) -> dict[str, Any]:
          num_models_trained = len(self.get_model_names())
          if self.model_best is not None:
              best_model = self.model_best
@@ -339,12 +337,13 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          model: AbstractTimeSeriesModel,
          val_data: Optional[TimeSeriesDataFrame] = None,
          time_limit: Optional[float] = None,
-     ) -> List[str]:
+     ) -> list[str]:
          """Fit and save the given model on given training and validation data and save the trained model.

          Returns
          -------
-         model_names_trained: the list of model names that were successfully trained
+         model_names_trained
+             the list of model names that were successfully trained
          """
          fit_start_time = time.time()
          model_names_trained = []
@@ -397,13 +396,13 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
      def _train_multi(
          self,
          train_data: TimeSeriesDataFrame,
-         hyperparameters: Union[str, Dict],
+         hyperparameters: Union[str, dict],
          val_data: Optional[TimeSeriesDataFrame] = None,
          hyperparameter_tune_kwargs: Optional[Union[str, dict]] = None,
-         excluded_model_types: Optional[List[str]] = None,
+         excluded_model_types: Optional[list[str]] = None,
          time_limit: Optional[float] = None,
          random_seed: Optional[int] = None,
-     ) -> List[str]:
+     ) -> list[str]:
          logger.info(f"\nStarting training. Start time is {time.strftime('%Y-%m-%d %H:%M:%S')}")

          time_start = time.time()
@@ -415,7 +414,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
              self.save_val_data(val_data)
              self.is_data_saved = True

-         models = self.construct_model_templates(
+         models = self.get_trainable_base_models(
              hyperparameters=hyperparameters,
              hyperparameter_tune=hyperparameter_tune_kwargs is not None,  # TODO: remove hyperparameter_tune
              freq=train_data.freq,
@@ -439,8 +438,6 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          num_base_models = len(models)
          model_names_trained = []
          for i, model in enumerate(models):
-             assert isinstance(model, AbstractTimeSeriesModel)
-
              if time_limit is None:
                  time_left = None
                  time_left_for_model = None
@@ -541,7 +538,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

      def _get_ensemble_oof_data(
          self, train_data: TimeSeriesDataFrame, val_data: Optional[TimeSeriesDataFrame]
-     ) -> List[TimeSeriesDataFrame]:
+     ) -> list[TimeSeriesDataFrame]:
          if val_data is None:
              return [val_fold for _, val_fold in self.val_splitter.split(train_data)]
          else:
@@ -558,13 +555,13 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

      def fit_ensemble(
          self,
-         data_per_window: List[TimeSeriesDataFrame],
-         model_names: List[str],
+         data_per_window: list[TimeSeriesDataFrame],
+         model_names: list[str],
          time_limit: Optional[float] = None,
      ) -> str:
          logger.info("Fitting simple weighted ensemble.")

-         predictions_per_window: Dict[str, List[TimeSeriesDataFrame]] = {}
+         predictions_per_window: dict[str, list[TimeSeriesDataFrame]] = {}
          base_model_scores = self.get_models_attribute_dict(attribute="val_score", models=self.get_model_names(0))

          for model_name in model_names:
@@ -614,7 +611,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          self,
          data: Optional[TimeSeriesDataFrame] = None,
          extra_info: bool = False,
-         extra_metrics: Optional[List[Union[str, TimeSeriesScorer]]] = None,
+         extra_metrics: Optional[list[Union[str, TimeSeriesScorer]]] = None,
          use_cache: bool = True,
      ) -> pd.DataFrame:
          logger.debug("Generating leaderboard for all models trained")
@@ -704,8 +701,8 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          return df[explicit_column_order]

      def persist(
-         self, model_names: Union[Literal["all", "best"], List[str]] = "all", with_ancestors: bool = False
-     ) -> List[str]:
+         self, model_names: Union[Literal["all", "best"], list[str]] = "all", with_ancestors: bool = False
+     ) -> list[str]:
          if model_names == "all":
              model_names = self.get_model_names()
          elif model_names == "best":
@@ -729,7 +726,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

          return model_names

-     def unpersist(self, model_names: Union[Literal["all"], List[str]] = "all") -> List[str]:
+     def unpersist(self, model_names: Union[Literal["all"], list[str]] = "all") -> list[str]:
          if model_names == "all":
              model_names = list(self.models.keys())
          if not isinstance(model_names, list):
@@ -826,9 +823,9 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          self,
          data: TimeSeriesDataFrame,
          model: Optional[Union[str, TimeSeriesModelBase]] = None,
-         metrics: Optional[Union[str, TimeSeriesScorer, List[Union[str, TimeSeriesScorer]]]] = None,
+         metrics: Optional[Union[str, TimeSeriesScorer, list[Union[str, TimeSeriesScorer]]]] = None,
          use_cache: bool = True,
-     ) -> Dict[str, float]:
+     ) -> dict[str, float]:
          past_data, known_covariates = data.get_model_inputs_for_scoring(
              prediction_length=self.prediction_length, known_covariates_names=self.covariate_metadata.known_covariates
          )
@@ -846,7 +843,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
      def get_feature_importance(
          self,
          data: TimeSeriesDataFrame,
-         features: List[str],
+         features: list[str],
          model: Optional[Union[str, TimeSeriesModelBase]] = None,
          metric: Optional[Union[str, TimeSeriesScorer]] = None,
          time_limit: Optional[float] = None,
@@ -996,7 +993,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          self,
          model: Union[str, TimeSeriesModelBase],
          data: TimeSeriesDataFrame,
-         model_pred_dict: Dict[str, Optional[TimeSeriesDataFrame]],
+         model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
          known_covariates: Optional[TimeSeriesDataFrame] = None,
      ) -> TimeSeriesDataFrame:
          """Generate predictions using the given model.
@@ -1012,8 +1009,8 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          self,
          model: Union[str, TimeSeriesModelBase],
          data: TimeSeriesDataFrame,
-         model_pred_dict: Dict[str, Optional[TimeSeriesDataFrame]],
-     ) -> Union[TimeSeriesDataFrame, Dict[str, Optional[TimeSeriesDataFrame]]]:
+         model_pred_dict: dict[str, Optional[TimeSeriesDataFrame]],
+     ) -> Union[TimeSeriesDataFrame, dict[str, Optional[TimeSeriesDataFrame]]]:
          """Get the first argument that should be passed to model.predict.

          This method assumes that model_pred_dict contains the predictions of all base models, if model is an ensemble.
@@ -1029,13 +1026,13 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

      def get_model_pred_dict(
          self,
-         model_names: List[str],
+         model_names: list[str],
          data: TimeSeriesDataFrame,
          known_covariates: Optional[TimeSeriesDataFrame] = None,
          raise_exception_if_failed: bool = True,
          use_cache: bool = True,
          random_seed: Optional[int] = None,
-     ) -> Tuple[Dict[str, Optional[TimeSeriesDataFrame]], Dict[str, float]]:
+     ) -> tuple[dict[str, Optional[TimeSeriesDataFrame]], dict[str, float]]:
          """Return a dictionary with predictions of all models for the given dataset.

          Parameters
@@ -1055,12 +1052,13 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          use_cache
              If False, will ignore the cache even if it's available.
          """
-         if self.cache_predictions and use_cache:
-             dataset_hash = self._compute_dataset_hash(data=data, known_covariates=known_covariates)
-             model_pred_dict, pred_time_dict_marginal = self._get_cached_pred_dicts(dataset_hash)
+         if use_cache:
+             model_pred_dict, pred_time_dict_marginal = self.prediction_cache.get(
+                 data=data, known_covariates=known_covariates
+             )
          else:
              model_pred_dict = {}
-             pred_time_dict_marginal: Dict[str, Any] = {}
+             pred_time_dict_marginal: dict[str, Any] = {}

          model_set = set()
          for model_name in model_names:
@@ -1093,9 +1091,11 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

          if len(failed_models) > 0 and raise_exception_if_failed:
              raise RuntimeError(f"Following models failed to predict: {failed_models}")
-         if self.cache_predictions and use_cache:
-             self._save_cached_pred_dicts(
-                 dataset_hash,  # type: ignore
+
+         if use_cache:
+             self.prediction_cache.put(
+                 data=data,
+                 known_covariates=known_covariates,
                  model_pred_dict=model_pred_dict,
                  pred_time_dict=pred_time_dict_marginal,
              )
@@ -1106,7 +1106,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

          return final_model_pred_dict, final_pred_time_dict_total

-     def _get_total_pred_time_from_marginal(self, pred_time_dict_marginal: Dict[str, float]) -> Dict[str, float]:
+     def _get_total_pred_time_from_marginal(self, pred_time_dict_marginal: dict[str, float]) -> dict[str, float]:
          pred_time_dict_total = defaultdict(float)
          for model_name in pred_time_dict_marginal.keys():
              for base_model in self.get_minimum_model_set(model_name):
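The marginal-to-total conversion works as follows: a model's total prediction time is the sum of the marginal prediction times of every model in its minimum model set, i.e. itself plus all models it depends on. A standalone sketch with hypothetical model names and timings:

    from collections import defaultdict

    minimum_model_set = {
        "Naive": ["Naive"],
        "ETS": ["ETS"],
        "WeightedEnsemble": ["WeightedEnsemble", "Naive", "ETS"],
    }
    pred_time_marginal = {"Naive": 0.5, "ETS": 2.0, "WeightedEnsemble": 0.25}

    # Mirrors _get_total_pred_time_from_marginal: accumulate the marginal
    # time of every dependency (including the model itself).
    pred_time_total: dict[str, float] = defaultdict(float)
    for model_name in pred_time_marginal:
        for base_model in minimum_model_set[model_name]:
            if pred_time_marginal.get(base_model) is not None:
                pred_time_total[model_name] += pred_time_marginal[base_model]

    print(dict(pred_time_total))
    # {'Naive': 0.5, 'ETS': 2.0, 'WeightedEnsemble': 2.75}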
@@ -1114,62 +1114,6 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
                  pred_time_dict_total[model_name] += pred_time_dict_marginal[base_model]
          return dict(pred_time_dict_total)

-     @property
-     def _cached_predictions_path(self) -> Path:
-         return Path(self.path) / self._cached_predictions_filename
-
-     @staticmethod
-     def _compute_dataset_hash(
-         data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame] = None
-     ) -> str:
-         """Compute a unique string that identifies the time series dataset."""
-         combined_hash = hash_pandas_df(data) + hash_pandas_df(known_covariates) + hash_pandas_df(data.static_features)
-         return combined_hash
-
-     def _load_cached_predictions(self) -> dict[str, dict[str, dict[str, Any]]]:
-         """Load cached predictions from disk. If loading fails, an empty dictionary is returned."""
-         if self._cached_predictions_path.exists():
-             try:
-                 cached_predictions = load_pkl.load(str(self._cached_predictions_path))
-             except Exception:
-                 cached_predictions = {}
-         else:
-             cached_predictions = {}
-         return cached_predictions
-
-     def _get_cached_pred_dicts(
-         self, dataset_hash: str
-     ) -> Tuple[Dict[str, Optional[TimeSeriesDataFrame]], Dict[str, float]]:
-         """Load cached predictions for given dataset_hash from disk, if possible.
-
-         If loading fails for any reason, empty dicts are returned.
-         """
-         cached_predictions = self._load_cached_predictions()
-         if dataset_hash in cached_predictions:
-             try:
-                 model_pred_dict = cached_predictions[dataset_hash]["model_pred_dict"]
-                 pred_time_dict = cached_predictions[dataset_hash]["pred_time_dict"]
-                 assert model_pred_dict.keys() == pred_time_dict.keys()
-                 return model_pred_dict, pred_time_dict
-             except Exception:
-                 logger.warning("Cached predictions are corrupted. Predictions will be made from scratch.")
-         return {}, {}
-
-     def _save_cached_pred_dicts(
-         self,
-         dataset_hash: str,
-         model_pred_dict: Dict[str, Optional[TimeSeriesDataFrame]],
-         pred_time_dict: Dict[str, float],
-     ) -> None:
-         cached_predictions = self._load_cached_predictions()
-         # Do not save results for models that failed
-         cached_predictions[dataset_hash] = {
-             "model_pred_dict": {k: v for k, v in model_pred_dict.items() if v is not None},
-             "pred_time_dict": {k: v for k, v in pred_time_dict.items() if v is not None},
-         }
-         save_pkl.save(str(self._cached_predictions_path), object=cached_predictions)
-         logger.debug(f"Cached predictions saved to {self._cached_predictions_path}")
-
      def _merge_refit_full_data(
          self, train_data: TimeSeriesDataFrame, val_data: Optional[TimeSeriesDataFrame]
      ) -> TimeSeriesDataFrame:
@@ -1183,8 +1127,8 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          self,
          train_data: Optional[TimeSeriesDataFrame] = None,
          val_data: Optional[TimeSeriesDataFrame] = None,
-         models: Optional[List[str]] = None,
-     ) -> List[str]:
+         models: Optional[list[str]] = None,
+     ) -> list[str]:
          train_data = train_data or self.load_train_data()
          val_data = val_data or self.load_val_data()
          refit_full_data = self._merge_refit_full_data(train_data, val_data)
@@ -1228,7 +1172,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          self.save()
          return models_trained_full

-     def refit_full(self, model: str = "all") -> Dict[str, str]:
+     def refit_full(self, model: str = "all") -> dict[str, str]:
          time_start = time.time()
          existing_models = self.get_model_names()
          if model == "all":
@@ -1260,38 +1204,38 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
          logger.info(f"Total runtime: {time.time() - time_start:.2f} s")
          return copy.deepcopy(self.model_refit_map)

-     def construct_model_templates(
+     def get_trainable_base_models(
          self,
-         hyperparameters: Union[str, Dict[str, Any]],
+         hyperparameters: Union[str, dict[str, Any]],
          *,
          multi_window: bool = False,
          freq: Optional[str] = None,
-         excluded_model_types: Optional[List[str]] = None,
+         excluded_model_types: Optional[list[str]] = None,
          hyperparameter_tune: bool = False,
-     ) -> List[TimeSeriesModelBase]:
-         return get_preset_models(
+     ) -> list[AbstractTimeSeriesModel]:
+         return TrainableModelSetBuilder(
+             freq=freq,
+             prediction_length=self.prediction_length,
              path=self.path,
              eval_metric=self.eval_metric,
-             prediction_length=self.prediction_length,
-             freq=freq,
-             hyperparameters=hyperparameters,
-             hyperparameter_tune=hyperparameter_tune,
              quantile_levels=self.quantile_levels,
-             all_assigned_names=self._get_banned_model_names(),
              target=self.target,
              covariate_metadata=self.covariate_metadata,
-             excluded_model_types=excluded_model_types,
-             # if skip_model_selection = True, we skip backtesting
              multi_window=multi_window and not self.skip_model_selection,
+         ).get_model_set(
+             hyperparameters=hyperparameters,
+             hyperparameter_tune=hyperparameter_tune,
+             excluded_model_types=excluded_model_types,
+             banned_model_names=self._get_banned_model_names(),
          )

      def fit(
          self,
          train_data: TimeSeriesDataFrame,
-         hyperparameters: Union[str, Dict[Any, Dict]],
+         hyperparameters: Union[str, dict[Any, dict]],
          val_data: Optional[TimeSeriesDataFrame] = None,
-         hyperparameter_tune_kwargs: Optional[Union[str, Dict]] = None,
-         excluded_model_types: Optional[List[str]] = None,
+         hyperparameter_tune_kwargs: Optional[Union[str, dict]] = None,
+         excluded_model_types: Optional[list[str]] = None,
          time_limit: Optional[float] = None,
          random_seed: Optional[int] = None,
      ):
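The rename above splits responsibilities: task-level settings (frequency, prediction length, metric, quantile levels, target, covariate metadata, backtesting mode) go to the TrainableModelSetBuilder constructor, while the per-call model configuration goes to get_model_set. A hedged sketch of the new call pattern, using only the keyword arguments visible in this hunk; all values are illustrative, CovariateMetadata() stands in for real dataset metadata, and passing a metric name string for eval_metric is an assumption:

    from autogluon.timeseries.trainer.model_set_builder import TrainableModelSetBuilder
    from autogluon.timeseries.utils.features import CovariateMetadata

    models = TrainableModelSetBuilder(
        freq="h",
        prediction_length=24,
        path="AutogluonModels/demo",  # hypothetical output directory
        eval_metric="WQL",            # assumed to accept a metric name here
        quantile_levels=[0.1, 0.5, 0.9],
        target="target",
        covariate_metadata=CovariateMetadata(),
        multi_window=True,
    ).get_model_set(
        hyperparameters={"Naive": {}, "ETS": {}},
        hyperparameter_tune=False,
        excluded_model_types=None,
        banned_model_names=[],
    )

This keeps the trainer-facing signature stable while model-construction logic moves out of the removed autogluon/timeseries/models/presets.py and into the new trainer/model_set_builder.py (files 44 and 32 in the list above).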
@@ -1301,13 +1245,13 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):

          Parameters
          ----------
-         train_data: TimeSeriesDataFrame
+         train_data
              Training data for fitting time series timeseries models.
-         hyperparameters: str or Dict
+         hyperparameters
              A dictionary mapping selected model names, model classes or model factory to hyperparameter
              settings. Model names should be present in `trainer.presets.DEFAULT_MODEL_NAMES`. Optionally,
              the user may provide one of "default", "light" and "very_light" to specify presets.
-         val_data: TimeSeriesDataFrame
+         val_data
              Optional validation data set to report validation scores on.
          hyperparameter_tune_kwargs
              Args for hyperparameter tuning
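For illustration, the hyperparameters argument can take either shape described by the docstring above (the model names and settings here are examples, not defaults):

    # Preset string: one of "default", "light", "very_light".
    hyperparameters = "very_light"

    # Or an explicit mapping from model name to hyperparameter settings;
    # an empty dict means "use that model's default configuration".
    hyperparameters = {
        "Naive": {},
        "ETS": {},
        "DeepAR": {"max_epochs": 5},  # hypothetical hyperparameter override
    }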
autogluon/timeseries/transforms/covariate_scaler.py
@@ -1,5 +1,5 @@
  import logging
- from typing import Dict, List, Literal, Optional, Protocol, overload, runtime_checkable
+ from typing import Literal, Optional, Protocol, overload, runtime_checkable

  import numpy as np
  import pandas as pd
@@ -53,7 +53,7 @@ class GlobalCovariateScaler(CovariateScaler):
          self.use_past_covariates = use_past_covariates
          self.use_static_features = use_static_features
          self.skew_threshold = skew_threshold
-         self._column_transformers: Optional[Dict[Literal["known", "past", "static"], ColumnTransformer]] = None
+         self._column_transformers: Optional[dict[Literal["known", "past", "static"], ColumnTransformer]] = None

      def is_fit(self) -> bool:
          return self._column_transformers is not None
@@ -117,7 +117,7 @@ class GlobalCovariateScaler(CovariateScaler):
          known_covariates[columns] = self._column_transformers["known"].transform(known_covariates[columns])
          return known_covariates

-     def _get_transformer_for_columns(self, df: pd.DataFrame, columns: List[str]) -> ColumnTransformer:
+     def _get_transformer_for_columns(self, df: pd.DataFrame, columns: list[str]) -> ColumnTransformer:
          """Passthrough bool features, use QuantileTransform for skewed features, and use StandardScaler for the rest.

          The preprocessing logic is similar to the TORCH_NN model from Tabular.
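The docstring fully determines the column routing, so a hedged reconstruction is straightforward. This sketch is an assumption built only from the docstring; the helper name and the default threshold value are not from the diff (the class does expose the threshold as skew_threshold):

    import numpy as np
    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import QuantileTransformer, StandardScaler

    def build_transformer(df: pd.DataFrame, columns: list[str], skew_threshold: float = 0.99) -> ColumnTransformer:
        # Route each column per the docstring: pass through bools,
        # quantile-transform skewed columns, standard-scale the rest.
        bool_cols = [c for c in columns if df[c].dtype == "bool"]
        numeric_cols = [c for c in columns if c not in bool_cols]
        skewed = [c for c in numeric_cols if abs(df[c].skew()) > skew_threshold]
        rest = [c for c in numeric_cols if c not in skewed]
        return ColumnTransformer([
            ("bool", "passthrough", bool_cols),
            ("skewed", QuantileTransformer(output_distribution="normal", n_quantiles=100), skewed),
            ("rest", StandardScaler(), rest),
        ])

    rng = np.random.default_rng(0)
    df = pd.DataFrame({
        "flag": [True, False] * 50,
        "x": rng.exponential(size=100),  # heavily skewed
        "y": rng.normal(size=100),       # roughly symmetric
    })
    print(build_transformer(df, ["flag", "x", "y"]).fit_transform(df).shape)  # (100, 3)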