autogluon.timeseries 1.2.1b20250224__py3-none-any.whl → 1.4.1b20251215__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (108)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +106 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +256 -141
  5. autogluon/timeseries/learner.py +86 -52
  6. autogluon/timeseries/metrics/__init__.py +42 -8
  7. autogluon/timeseries/metrics/abstract.py +89 -19
  8. autogluon/timeseries/metrics/point.py +142 -53
  9. autogluon/timeseries/metrics/quantile.py +46 -21
  10. autogluon/timeseries/metrics/utils.py +4 -4
  11. autogluon/timeseries/models/__init__.py +8 -2
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +361 -592
  14. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +282 -194
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +25 -18
  20. autogluon/timeseries/models/chronos/__init__.py +2 -1
  21. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  22. autogluon/timeseries/models/chronos/model.py +219 -138
  23. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +81 -50
  24. autogluon/timeseries/models/ensemble/__init__.py +37 -2
  25. autogluon/timeseries/models/ensemble/abstract.py +107 -0
  26. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  27. autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
  28. autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  34. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  35. autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
  36. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  37. autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
  38. autogluon/timeseries/models/ensemble/weighted/basic.py +91 -0
  39. autogluon/timeseries/models/ensemble/weighted/greedy.py +62 -0
  40. autogluon/timeseries/models/gluonts/__init__.py +1 -1
  41. autogluon/timeseries/models/gluonts/{abstract_gluonts.py → abstract.py} +148 -208
  42. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  43. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +38 -22
  44. autogluon/timeseries/models/local/__init__.py +0 -7
  45. autogluon/timeseries/models/local/abstract_local_model.py +71 -74
  46. autogluon/timeseries/models/local/naive.py +13 -9
  47. autogluon/timeseries/models/local/npts.py +9 -2
  48. autogluon/timeseries/models/local/statsforecast.py +52 -36
  49. autogluon/timeseries/models/multi_window/multi_window_model.py +65 -45
  50. autogluon/timeseries/models/registry.py +64 -0
  51. autogluon/timeseries/models/toto/__init__.py +3 -0
  52. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  62. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  63. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  64. autogluon/timeseries/models/toto/dataloader.py +108 -0
  65. autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
  66. autogluon/timeseries/models/toto/model.py +249 -0
  67. autogluon/timeseries/predictor.py +685 -297
  68. autogluon/timeseries/regressor.py +94 -44
  69. autogluon/timeseries/splitter.py +8 -32
  70. autogluon/timeseries/trainer/__init__.py +3 -0
  71. autogluon/timeseries/trainer/ensemble_composer.py +444 -0
  72. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  73. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  74. autogluon/timeseries/{trainer.py → trainer/trainer.py} +387 -390
  75. autogluon/timeseries/trainer/utils.py +17 -0
  76. autogluon/timeseries/transforms/__init__.py +2 -13
  77. autogluon/timeseries/transforms/covariate_scaler.py +34 -40
  78. autogluon/timeseries/transforms/target_scaler.py +37 -20
  79. autogluon/timeseries/utils/constants.py +10 -0
  80. autogluon/timeseries/utils/datetime/lags.py +3 -5
  81. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  82. autogluon/timeseries/utils/datetime/time_features.py +2 -2
  83. autogluon/timeseries/utils/features.py +70 -47
  84. autogluon/timeseries/utils/forecast.py +19 -14
  85. autogluon/timeseries/utils/timer.py +173 -0
  86. autogluon/timeseries/utils/warning_filters.py +4 -2
  87. autogluon/timeseries/version.py +1 -1
  88. autogluon.timeseries-1.4.1b20251215-py3.11-nspkg.pth +1 -0
  89. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/METADATA +49 -36
  90. autogluon_timeseries-1.4.1b20251215.dist-info/RECORD +103 -0
  91. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/WHEEL +1 -1
  92. autogluon/timeseries/configs/presets_configs.py +0 -79
  93. autogluon/timeseries/evaluator.py +0 -6
  94. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -11
  95. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  96. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -585
  97. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -518
  98. autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -78
  99. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  100. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  101. autogluon/timeseries/models/presets.py +0 -360
  102. autogluon.timeseries-1.2.1b20250224-py3.9-nspkg.pth +0 -1
  103. autogluon.timeseries-1.2.1b20250224.dist-info/RECORD +0 -68
  104. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/zip-safe +0 -0
@@ -5,38 +5,40 @@ import time
 import traceback
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import Any, Literal
 
 import networkx as nx
 import numpy as np
 import pandas as pd
 from tqdm import tqdm
 
-from autogluon.common.utils.utils import hash_pandas_df, seed_everything
+from autogluon.common.utils.utils import seed_everything
 from autogluon.core.trainer.abstract_trainer import AbstractTrainer
 from autogluon.core.utils.exceptions import TimeLimitExceeded
 from autogluon.core.utils.loaders import load_pkl
 from autogluon.core.utils.savers import save_pkl
 from autogluon.timeseries import TimeSeriesDataFrame
 from autogluon.timeseries.metrics import TimeSeriesScorer, check_get_evaluation_metric
-from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
-from autogluon.timeseries.models.ensemble import AbstractTimeSeriesEnsembleModel, TimeSeriesGreedyEnsemble
+from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel, TimeSeriesModelBase
+from autogluon.timeseries.models.ensemble import AbstractTimeSeriesEnsembleModel
 from autogluon.timeseries.models.multi_window import MultiWindowBacktestingModel
-from autogluon.timeseries.models.presets import contains_searchspace, get_preset_models
 from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
+from autogluon.timeseries.trainer.ensemble_composer import EnsembleComposer, validate_ensemble_hyperparameters
 from autogluon.timeseries.utils.features import (
     ConstantReplacementFeatureImportanceTransform,
     CovariateMetadata,
     PermutationFeatureImportanceTransform,
 )
-from autogluon.timeseries.utils.warning_filters import disable_tqdm, warning_filter
+from autogluon.timeseries.utils.warning_filters import disable_tqdm
 
-logger = logging.getLogger("autogluon.timeseries.trainer")
+from .model_set_builder import TrainableModelSetBuilder, contains_searchspace
+from .prediction_cache import PredictionCache, get_prediction_cache
+from .utils import log_scores_and_times
 
+logger = logging.getLogger("autogluon.timeseries.trainer")
 
-class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
-    _cached_predictions_filename = "cached_predictions.pkl"
 
+class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
     max_rel_importance_score: float = 1e5
     eps_abs_importance_score: float = 1e-5
     max_ensemble_time_limit: float = 600.0
@@ -45,16 +47,16 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         self,
         path: str,
         prediction_length: int = 1,
-        eval_metric: Union[str, TimeSeriesScorer, None] = None,
-        eval_metric_seasonal_period: Optional[int] = None,
+        eval_metric: str | TimeSeriesScorer | None = None,
         save_data: bool = True,
         skip_model_selection: bool = False,
         enable_ensemble: bool = True,
         verbosity: int = 2,
-        val_splitter: Optional[AbstractWindowSplitter] = None,
-        refit_every_n_windows: Optional[int] = 1,
+        num_val_windows: tuple[int, ...] = (1,),
+        val_step_size: int | None = None,
+        refit_every_n_windows: int | None = 1,
+        # TODO: Set cache_predictions=False by default once all models in default presets have a reasonable inference speed
         cache_predictions: bool = True,
-        ensemble_model_type: Optional[Type] = None,
         **kwargs,
     ):
         super().__init__(
@@ -66,38 +68,39 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         self.prediction_length = prediction_length
         self.quantile_levels = kwargs.get("quantile_levels", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
         self.target = kwargs.get("target", "target")
-        self.metadata = kwargs.get("metadata", CovariateMetadata())
+        self.covariate_metadata = kwargs.get("covariate_metadata", CovariateMetadata())
         self.is_data_saved = False
         self.skip_model_selection = skip_model_selection
         # Ensemble cannot be fit if val_scores are not computed
         self.enable_ensemble = enable_ensemble and not skip_model_selection
-        if ensemble_model_type is None:
-            ensemble_model_type = TimeSeriesGreedyEnsemble
-        else:
+        if kwargs.get("ensemble_model_type") is not None:
             logger.warning(
-                "Using a custom `ensemble_model_type` is experimental functionality that may break in future versions."
+                "Using a custom `ensemble_model_type` is no longer supported. Use the `ensemble_hyperparameters` "
+                "argument to `fit` instead."
             )
-        self.ensemble_model_type = ensemble_model_type
 
         self.verbosity = verbosity
 
-        #: Dict of normal model -> FULL model. FULL models are produced by
+        #: dict of normal model -> FULL model. FULL models are produced by
         #: self.refit_single_full() and self.refit_full().
         self.model_refit_map = {}
 
-        self.eval_metric: TimeSeriesScorer = check_get_evaluation_metric(eval_metric)
-        self.eval_metric_seasonal_period = eval_metric_seasonal_period
-        if val_splitter is None:
-            val_splitter = ExpandingWindowSplitter(prediction_length=self.prediction_length)
-        assert isinstance(val_splitter, AbstractWindowSplitter), "val_splitter must be of type AbstractWindowSplitter"
-        self.val_splitter = val_splitter
+        self.eval_metric = check_get_evaluation_metric(eval_metric, prediction_length=prediction_length)
+
+        self.num_val_windows = num_val_windows
+
+        # Validate num_val_windows
+        if len(self.num_val_windows) == 0:
+            raise ValueError("num_val_windows cannot be empty")
+        if not all(isinstance(w, int) and w > 0 for w in self.num_val_windows):
+            raise ValueError(f"num_val_windows must contain only positive integers, got {self.num_val_windows}")
+
+        self.val_step_size = val_step_size
         self.refit_every_n_windows = refit_every_n_windows
-        self.cache_predictions = cache_predictions
         self.hpo_results = {}
 
-        if self._cached_predictions_path.exists():
-            logger.debug(f"Removing existing cached predictions file {self._cached_predictions_path}")
-            self._cached_predictions_path.unlink()
+        self.prediction_cache: PredictionCache = get_prediction_cache(cache_predictions, self.path)
+        self.prediction_cache.clear()
 
     @property
     def path_pkl(self) -> str:
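The constructor changes above drop `val_splitter`, `eval_metric_seasonal_period`, and `ensemble_model_type`: validation backtests are now described by `num_val_windows` (a tuple with one entry per ensemble layer) plus an optional `val_step_size`, and prediction caching is delegated to a `PredictionCache` object. A minimal construction sketch; the argument values and the import path through the new `trainer` package are illustrative assumptions, not defaults taken from the release:

    from autogluon.timeseries.trainer import TimeSeriesTrainer  # path assumed from trainer/__init__.py

    trainer = TimeSeriesTrainer(
        path="ag_models/",
        prediction_length=24,
        eval_metric="WQL",
        num_val_windows=(2, 1),  # validated: non-empty, positive integers only
        val_step_size=24,
        cache_predictions=True,  # backed by get_prediction_cache(...) instead of cached_predictions.pkl
    )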
@@ -115,14 +118,14 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         path = os.path.join(self.path_data, "train.pkl")
         return load_pkl.load(path=path)
 
-    def load_val_data(self) -> Optional[TimeSeriesDataFrame]:
+    def load_val_data(self) -> TimeSeriesDataFrame | None:
         path = os.path.join(self.path_data, "val.pkl")
         if os.path.exists(path):
             return load_pkl.load(path=path)
         else:
             return None
 
-    def load_data(self) -> Tuple[TimeSeriesDataFrame, Optional[TimeSeriesDataFrame]]:
+    def load_data(self) -> tuple[TimeSeriesDataFrame, TimeSeriesDataFrame | None]:
         train_data = self.load_train_data()
         val_data = self.load_val_data()
         return train_data, val_data
@@ -137,24 +140,24 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         self.models = models
 
-    def _get_model_oof_predictions(self, model_name: str) -> List[TimeSeriesDataFrame]:
+    def _get_model_oof_predictions(self, model_name: str) -> list[TimeSeriesDataFrame]:
         model_path = os.path.join(self.path, self.get_model_attribute(model=model_name, attribute="path"))
         model_type = self.get_model_attribute(model=model_name, attribute="type")
         return model_type.load_oof_predictions(path=model_path)
 
     def _add_model(
         self,
-        model: AbstractTimeSeriesModel,
-        base_models: Optional[List[str]] = None,
+        model: TimeSeriesModelBase,
+        base_models: list[str] | None = None,
     ):
         """Add a model to the model graph of the trainer. If the model is an ensemble, also add
         information about dependencies to the model graph (list of models specified via ``base_models``).
 
         Parameters
         ----------
-        model : AbstractTimeSeriesModel
+        model
             The model to be added to the model graph.
-        base_models : List[str], optional, default None
+        base_models
             If the model is an ensemble, the list of base model names that are included in the ensemble.
             Expected only when ``model`` is a ``AbstractTimeSeriesEnsembleModel``.
@@ -177,8 +180,8 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         for base_model in base_models:
             self.model_graph.add_edge(base_model, model.name)
 
-    def _get_model_levels(self) -> Dict[str, int]:
-        """Get a dictionary mapping each model to their level in the model graph"""
+    def _get_model_layers(self) -> dict[str, int]:
+        """Get a dictionary mapping each model to their layer in the model graph"""
 
         # get nodes without a parent
         rootset = set(self.model_graph.nodes)
@@ -191,14 +194,14 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             for dest_node in paths_to:
                 paths_from[dest_node][source_node] = paths_to[dest_node]
 
-        # determine levels
-        levels = {}
+        # determine layers
+        layers = {}
         for n in paths_from:
-            levels[n] = max(paths_from[n].get(src, 0) for src in rootset)
+            layers[n] = max(paths_from[n].get(src, 0) for src in rootset)
 
-        return levels
+        return layers
 
-    def get_models_attribute_dict(self, attribute: str, models: Optional[List[str]] = None) -> Dict[str, Any]:
+    def get_models_attribute_dict(self, attribute: str, models: list[str] | None = None) -> dict[str, Any]:
         """Get an attribute from the `model_graph` for each of the model names
         specified. If `models` is none, the attribute will be returned for all models"""
         results = {}
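The rename from `levels` to `layers` above does not change the computation: a model's layer is still the maximum path length reaching it from the roots of the model graph, so base models sit at layer 0 and stacked ensembles above them. A standalone illustration of that layering with networkx (the trainer's own code works on the `rootset`/`paths_from` structures shown above; the model names here are hypothetical):

    import networkx as nx

    graph = nx.DiGraph()
    graph.add_edges_from([
        ("DeepAR", "GreedyEnsemble"),         # base model -> layer-1 ensemble
        ("SeasonalNaive", "GreedyEnsemble"),
        ("GreedyEnsemble", "LinearStacker"),  # layer-1 ensemble -> layer-2 ensemble
    ])

    roots = [n for n in graph.nodes if graph.in_degree(n) == 0]
    layers = {
        node: max(
            (nx.shortest_path_length(graph, root, node)
             for root in roots if nx.has_path(graph, root, node)),
            default=0,
        )
        for node in graph.nodes
    }
    # DeepAR / SeasonalNaive -> 0, GreedyEnsemble -> 1, LinearStacker -> 2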
@@ -216,28 +219,28 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         if len(models) == 1:
             return models[0]
         model_performances = self.get_models_attribute_dict(attribute="val_score")
-        model_levels = self._get_model_levels()
-        model_name_score_level_list = [
-            (m, model_performances[m], model_levels.get(m, 0)) for m in models if model_performances[m] is not None
+        model_layers = self._get_model_layers()
+        model_name_score_layer_list = [
+            (m, model_performances[m], model_layers.get(m, 0)) for m in models if model_performances[m] is not None
         ]
 
-        if not model_name_score_level_list:
+        if not model_name_score_layer_list:
             raise ValueError("No fitted models have validation scores computed.")
 
         # rank models in terms of validation score. if two models have the same validation score,
-        # rank them by their level in the model graph (lower level models are preferred).
+        # rank them by their layer in the model graph (lower layer models are preferred).
         return max(
-            model_name_score_level_list,
-            key=lambda mns: (mns[1], -mns[2]),  # (score, -level)
+            model_name_score_layer_list,
+            key=lambda mns: (mns[1], -mns[2]),  # (score, -layer)
         )[0]
 
-    def get_model_names(self, level: Optional[int] = None) -> List[str]:
+    def get_model_names(self, layer: int | None = None) -> list[str]:
         """Get model names that are registered in the model graph"""
-        if level is not None:
-            return list(node for node, l in self._get_model_levels().items() if l == level)  # noqa: E741
+        if layer is not None:
+            return list(node for node, l in self._get_model_layers().items() if l == layer)  # noqa: E741
         return list(self.model_graph.nodes)
 
-    def get_info(self, include_model_info: bool = False) -> Dict[str, Any]:
+    def get_info(self, include_model_info: bool = False) -> dict[str, Any]:
         num_models_trained = len(self.get_model_names())
         if self.model_best is not None:
             best_model = self.model_best
@@ -262,32 +265,13 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         return info
 
-    def _train_single(
-        self,
-        train_data: TimeSeriesDataFrame,
-        model: AbstractTimeSeriesModel,
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        time_limit: Optional[float] = None,
-    ) -> AbstractTimeSeriesModel:
-        """Train the single model and return the model object that was fitted. This method
-        does not save the resulting model."""
-        model.fit(
-            train_data=train_data,
-            val_data=val_data,
-            time_limit=time_limit,
-            verbosity=self.verbosity,
-            val_splitter=self.val_splitter,
-            refit_every_n_windows=self.refit_every_n_windows,
-        )
-        return model
-
     def tune_model_hyperparameters(
         self,
         model: AbstractTimeSeriesModel,
         train_data: TimeSeriesDataFrame,
-        time_limit: Optional[float] = None,
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        hyperparameter_tune_kwargs: Union[str, dict] = "auto",
+        time_limit: float | None = None,
+        val_data: TimeSeriesDataFrame | None = None,
+        hyperparameter_tune_kwargs: str | dict = "auto",
     ):
         default_num_trials = None
         if time_limit is None and (
@@ -303,7 +287,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
             time_limit=time_limit,
             default_num_trials=default_num_trials,
-            val_splitter=self.val_splitter,
+            val_splitter=self._get_val_splitter(use_val_data=val_data is not None),
             refit_every_n_windows=self.refit_every_n_windows,
         )
         total_tuning_time = time.time() - tuning_start_time
@@ -313,11 +297,21 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         # add each of the trained HPO configurations to the trained models
         for model_hpo_name, model_info in hpo_models.items():
             model_path = os.path.join(self.path, model_info["path"])
+
             # Only load model configurations that didn't fail
-            if Path(model_path).exists():
-                model_hpo = self.load_model(model_hpo_name, path=model_path, model_type=type(model))
-                self._add_model(model_hpo)
-                model_names_trained.append(model_hpo.name)
+            if not Path(model_path).exists():
+                continue
+
+            model_hpo = self.load_model(model_hpo_name, path=model_path, model_type=type(model))
+
+            # override validation score to align evaluations on the final ensemble layer's window
+            if isinstance(model_hpo, MultiWindowBacktestingModel):
+                model_hpo.val_score = float(
+                    np.mean([info["val_score"] for info in model_hpo.info_per_val_window[-self.num_val_windows[-1] :]])
+                )
+
+            self._add_model(model_hpo)
+            model_names_trained.append(model_hpo.name)
 
         logger.info(f"\tTrained {len(model_names_trained)} models while tuning {model.name}.")
 
@@ -338,14 +332,15 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         self,
         train_data: TimeSeriesDataFrame,
         model: AbstractTimeSeriesModel,
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        time_limit: Optional[float] = None,
-    ) -> List[str]:
+        val_data: TimeSeriesDataFrame | None = None,
+        time_limit: float | None = None,
+    ) -> list[str]:
         """Fit and save the given model on given training and validation data and save the trained model.
 
         Returns
         -------
-        model_names_trained: the list of model names that were successfully trained
+        model_names_trained
+            the list of model names that were successfully trained
         """
         fit_start_time = time.time()
         model_names_trained = []
@@ -355,26 +350,46 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
                 logger.info(f"\tSkipping {model.name} due to lack of time remaining.")
                 return model_names_trained
 
-            model = self._train_single(train_data, model, val_data=val_data, time_limit=time_limit)
+            model.fit(
+                train_data=train_data,
+                val_data=None if isinstance(model, MultiWindowBacktestingModel) else val_data,
+                time_limit=time_limit,
+                verbosity=self.verbosity,
+                val_splitter=self._get_val_splitter(use_val_data=val_data is not None),
+                refit_every_n_windows=self.refit_every_n_windows,
+            )
+
             fit_end_time = time.time()
             model.fit_time = model.fit_time or (fit_end_time - fit_start_time)
 
             if time_limit is not None:
-                time_limit = fit_end_time - fit_start_time
-            if val_data is not None and not self.skip_model_selection:
+                time_limit = time_limit - (fit_end_time - fit_start_time)
+            if val_data is not None:
                 model.score_and_cache_oof(
                     val_data, store_val_score=True, store_predict_time=True, time_limit=time_limit
                 )
 
-            self._log_scores_and_times(model.val_score, model.fit_time, model.predict_time)
+            # by default, MultiWindowBacktestingModel computes validation score on all windows. However,
+            # when doing multi-layer stacking, the trainer only scores on the windows of the last layer.
+            # we override the val_score to align scores.
+            if isinstance(model, MultiWindowBacktestingModel):
+                model.val_score = float(
+                    np.mean([info["val_score"] for info in model.info_per_val_window[-self.num_val_windows[-1] :]])
+                )
+
+            log_scores_and_times(
+                val_score=model.val_score,
+                fit_time=model.fit_time,
+                predict_time=model.predict_time,
+                eval_metric_name=self.eval_metric.name_with_sign,
+            )
 
             self.save_model(model=model)
         except TimeLimitExceeded:
             logger.error(f"\tTime limit exceeded... Skipping {model.name}.")
-        except (Exception, MemoryError) as err:
+        except (Exception, MemoryError):
             logger.error(f"\tWarning: Exception caused {model.name} to fail during training... Skipping this model.")
-            logger.error(f"\t{err}")
-            logger.debug(traceback.format_exc())
+            logger.error(traceback.format_exc())
         else:
             self._add_model(model=model)  # noqa: F821
             model_names_trained.append(model.name)  # noqa: F821
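The override above averages only the trailing `num_val_windows[-1]` entries of `info_per_val_window`, so a model backtested on every window is ranked on the same windows that score the final ensemble layer. Worked through with hypothetical scores and `num_val_windows=(2, 1)`:

    import numpy as np

    num_val_windows = (2, 1)
    info_per_val_window = [
        {"val_score": -0.91},  # windows consumed by the layer-1 ensembles
        {"val_score": -0.85},
        {"val_score": -0.78},  # final window, used for model selection
    ]

    # mirrors the trainer: mean over the last num_val_windows[-1] windows only
    val_score = float(
        np.mean([info["val_score"] for info in info_per_val_window[-num_val_windows[-1]:]])
    )
    assert val_score == -0.78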
@@ -383,45 +398,75 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         return model_names_trained
 
-    def _log_scores_and_times(
-        self,
-        val_score: Optional[float] = None,
-        fit_time: Optional[float] = None,
-        predict_time: Optional[float] = None,
-    ):
-        if val_score is not None:
-            logger.info(f"\t{val_score:<7.4f}".ljust(15) + f"= Validation score ({self.eval_metric.name_with_sign})")
-        if fit_time is not None:
-            logger.info(f"\t{fit_time:<7.2f} s".ljust(15) + "= Training runtime")
-        if predict_time is not None:
-            logger.info(f"\t{predict_time:<7.2f} s".ljust(15) + "= Validation (prediction) runtime")
-
-    def _train_multi(
+    def fit(
         self,
         train_data: TimeSeriesDataFrame,
-        hyperparameters: Union[str, Dict],
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        hyperparameter_tune_kwargs: Optional[Union[str, dict]] = None,
-        excluded_model_types: Optional[List[str]] = None,
-        time_limit: Optional[float] = None,
-        random_seed: Optional[int] = None,
-    ) -> List[str]:
+        hyperparameters: str | dict[Any, dict],
+        val_data: TimeSeriesDataFrame | None = None,
+        ensemble_hyperparameters: dict | list[dict] | None = None,
+        hyperparameter_tune_kwargs: str | dict | None = None,
+        excluded_model_types: list[str] | None = None,
+        time_limit: float | None = None,
+        random_seed: int | None = None,
+    ):
+        """Fit a set of timeseries models specified by the `hyperparameters`
+        dictionary that maps model names to their specified hyperparameters.
+
+        Parameters
+        ----------
+        train_data
+            Training data for fitting time series timeseries models.
+        hyperparameters
+            A dictionary mapping selected model names, model classes or model factory to hyperparameter
+            settings. Model names should be present in `trainer.presets.DEFAULT_MODEL_NAMES`. Optionally,
+            the user may provide one of "default", "light" and "very_light" to specify presets.
+        val_data
+            Optional validation data set to report validation scores on.
+        ensemble_hyperparameters
+            A dictionary mapping ensemble names to their specified hyperparameters. Ensemble names
+            should be defined in the models.ensemble namespace. defaults to `{"GreedyEnsemble": {}}`
+            which only fits a greedy weighted ensemble with default hyperparameters. Providing an
+            empty dictionary disables ensemble training.
+        hyperparameter_tune_kwargs
+            Args for hyperparameter tuning
+        excluded_model_types
+            Names of models that should not be trained, even if listed in `hyperparameters`.
+        time_limit
+            Time limit for training
+        random_seed
+            Random seed that will be set to each model during training
+        """
         logger.info(f"\nStarting training. Start time is {time.strftime('%Y-%m-%d %H:%M:%S')}")
 
+        # Handle ensemble hyperparameters
+        if ensemble_hyperparameters is None:
+            ensemble_hyperparameters = [{"GreedyEnsemble": {}}]
+        if isinstance(ensemble_hyperparameters, dict):
+            ensemble_hyperparameters = [ensemble_hyperparameters]
+        validate_ensemble_hyperparameters(ensemble_hyperparameters)
+
         time_start = time.time()
         hyperparameters = copy.deepcopy(hyperparameters)
 
+        if val_data is not None:
+            if self.num_val_windows[-1] != 1:
+                raise ValueError(
+                    f"When val_data is provided, the last element of num_val_windows must be 1, "
+                    f"got {self.num_val_windows[-1]}"
+                )
+        multi_window = self._get_val_splitter(use_val_data=val_data is not None).num_val_windows > 0
+
         if self.save_data and not self.is_data_saved:
             self.save_train_data(train_data)
             if val_data is not None:
                 self.save_val_data(val_data)
             self.is_data_saved = True
 
-        models = self.construct_model_templates(
+        models = self.get_trainable_base_models(
             hyperparameters=hyperparameters,
             hyperparameter_tune=hyperparameter_tune_kwargs is not None,  # TODO: remove hyperparameter_tune
             freq=train_data.freq,
-            multi_window=self.val_splitter.num_val_windows > 0,
+            multi_window=multi_window,
            excluded_model_types=excluded_model_types,
         )
@@ -433,7 +478,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
                     "When `skip_model_selection=True`, only a single model must be provided via `hyperparameters` "
                     f"but {len(models)} models were given"
                 )
-            if contains_searchspace(models[0].get_user_params()):
+            if contains_searchspace(models[0].get_hyperparameters()):
                 raise ValueError(
                     "When `skip_model_selection=True`, model configuration should contain no search spaces."
                 )
@@ -461,7 +506,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             if random_seed is not None:
                 seed_everything(random_seed + i)
 
-            if contains_searchspace(model.get_user_params()):
+            if contains_searchspace(model.get_hyperparameters()):
                 fit_log_message = f"Hyperparameter tuning model {model.name}. "
                 if time_left is not None:
                     fit_log_message += (
@@ -490,42 +535,16 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
                     train_data, model=model, val_data=val_data, time_limit=time_left_for_model
                 )
 
-        if self.enable_ensemble:
-            models_available_for_ensemble = self.get_model_names(level=0)
-
-            time_left_for_ensemble = None
-            if time_limit is not None:
-                time_left_for_ensemble = time_limit - (time.time() - time_start)
-
-            if time_left_for_ensemble is not None and time_left_for_ensemble <= 0:
-                logger.info(
-                    "Not fitting ensemble due to lack of time remaining. "
-                    f"Time left: {time_left_for_ensemble:.1f} seconds"
-                )
-            elif len(models_available_for_ensemble) <= 1:
-                logger.info(
-                    "Not fitting ensemble as "
-                    + (
-                        "no models were successfully trained."
-                        if not models_available_for_ensemble
-                        else "only 1 model was trained."
-                    )
-                )
-            else:
-                try:
-                    model_names_trained.append(
-                        self.fit_ensemble(
-                            data_per_window=self._get_ensemble_oof_data(train_data=train_data, val_data=val_data),
-                            model_names=models_available_for_ensemble,
-                            time_limit=time_left_for_ensemble,
-                        )
-                    )
-                except Exception as err:  # noqa
-                    logger.error(
-                        "\tWarning: Exception caused ensemble to fail during training... Skipping this model."
-                    )
-                    logger.error(f"\t{err}")
-                    logger.debug(traceback.format_exc())
+        if self.enable_ensemble and ensemble_hyperparameters:
+            model_names = self.get_model_names(layer=0)
+            ensemble_names = self._fit_ensembles(
+                data_per_window=self._get_validation_windows(train_data, val_data),
+                predictions_per_window=self._get_base_model_predictions(model_names),
+                time_limit=None if time_limit is None else time_limit - (time.time() - time_start),
+                ensemble_hyperparameters=ensemble_hyperparameters,
+                num_windows_per_layer=self.num_val_windows,
+            )
+            model_names_trained.extend(ensemble_names)
 
         logger.info(f"Training complete. Models trained: {model_names_trained}")
         logger.info(f"Total runtime: {time.time() - time_start:.2f} s")
@@ -539,73 +558,64 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         return model_names_trained
 
-    def _get_ensemble_oof_data(
-        self, train_data: TimeSeriesDataFrame, val_data: Optional[TimeSeriesDataFrame]
-    ) -> List[TimeSeriesDataFrame]:
-        if val_data is None:
-            return [val_fold for _, val_fold in self.val_splitter.split(train_data)]
-        else:
-            return [val_data]
-
-    def _get_ensemble_model_name(self) -> str:
-        """Ensure we don't have name collisions in the ensemble model name"""
-        ensemble_name = "WeightedEnsemble"
-        increment = 1
-        while ensemble_name in self._get_banned_model_names():
-            increment += 1
-            ensemble_name = f"WeightedEnsemble_{increment}"
-        return ensemble_name
-
-    def fit_ensemble(
-        self, data_per_window: List[TimeSeriesDataFrame], model_names: List[str], time_limit: Optional[float] = None
-    ) -> str:
-        logger.info("Fitting simple weighted ensemble.")
-
-        model_preds: Dict[str, List[TimeSeriesDataFrame]] = {}
-        for model_name in model_names:
-            model_preds[model_name] = self._get_model_oof_predictions(model_name=model_name)
-
-        time_start = time.time()
-        ensemble = self.ensemble_model_type(
-            name=self._get_ensemble_model_name(),
+    def _fit_ensembles(
+        self,
+        *,
+        data_per_window: list[TimeSeriesDataFrame],
+        predictions_per_window: dict[str, list[TimeSeriesDataFrame]],
+        time_limit: float | None,
+        ensemble_hyperparameters: list[dict],
+        num_windows_per_layer: tuple[int, ...],
+    ) -> list[str]:
+        ensemble_composer = EnsembleComposer(
+            path=self.path,
+            prediction_length=self.prediction_length,
             eval_metric=self.eval_metric,
-            eval_metric_seasonal_period=self.eval_metric_seasonal_period,
             target=self.target,
-            prediction_length=self.prediction_length,
-            path=self.path,
-            freq=data_per_window[0].freq,
+            ensemble_hyperparameters=ensemble_hyperparameters,
+            num_windows_per_layer=num_windows_per_layer,
             quantile_levels=self.quantile_levels,
-            metadata=self.metadata,
+            model_graph=self.model_graph,
+        ).fit(
+            data_per_window=data_per_window,
+            predictions_per_window=predictions_per_window,
+            time_limit=time_limit,
         )
-        with warning_filter():
-            ensemble.fit_ensemble(model_preds, data_per_window=data_per_window, time_limit=time_limit)
-        ensemble.fit_time = time.time() - time_start
-
-        predict_time = 0
-        for m in ensemble.model_names:
-            predict_time += self.get_model_attribute(model=m, attribute="predict_time")
-        ensemble.predict_time = predict_time
-
-        score_per_fold = []
-        for window_idx, data in enumerate(data_per_window):
-            predictions = ensemble.predict({n: model_preds[n][window_idx] for n in ensemble.model_names})
-            score_per_fold.append(self._score_with_predictions(data, predictions))
-        ensemble.val_score = float(np.mean(score_per_fold, dtype=np.float64))
-
-        self._log_scores_and_times(
-            val_score=ensemble.val_score,
-            fit_time=ensemble.fit_time,
-            predict_time=ensemble.predict_time,
+
+        ensembles_trained = []
+        for _, model, base_models in ensemble_composer.iter_ensembles():
+            self._add_model(model=model, base_models=base_models)
+            self.save_model(model=model)
+            ensembles_trained.append(model.name)
+
+        return ensembles_trained
+
+    def _get_validation_windows(self, train_data: TimeSeriesDataFrame, val_data: TimeSeriesDataFrame | None):
+        train_splitter = self._get_val_splitter(use_val_data=val_data is not None)
+        return [val_fold for _, val_fold in train_splitter.split(train_data)] + (
+            [] if val_data is None else [val_data]
+        )
+
+    def _get_val_splitter(self, use_val_data: bool = False) -> AbstractWindowSplitter:
+        num_windows_from_train = sum(self.num_val_windows[:-1]) if use_val_data else sum(self.num_val_windows)
+        return ExpandingWindowSplitter(
+            prediction_length=self.prediction_length,
+            num_val_windows=num_windows_from_train,
+            val_step_size=self.val_step_size,
         )
-        self._add_model(model=ensemble, base_models=ensemble.model_names)
-        self.save_model(model=ensemble)
-        return ensemble.name
+
+    def _get_base_model_predictions(self, model_names: list[str]) -> dict[str, list[TimeSeriesDataFrame]]:
+        """Get base model predictions for ensemble training / inference."""
+        predictions_per_window = {}
+        for model_name in model_names:
+            predictions_per_window[model_name] = self._get_model_oof_predictions(model_name)
+        return predictions_per_window
 
     def leaderboard(
         self,
-        data: Optional[TimeSeriesDataFrame] = None,
+        data: TimeSeriesDataFrame | None = None,
         extra_info: bool = False,
-        extra_metrics: Optional[List[Union[str, TimeSeriesScorer]]] = None,
+        extra_metrics: list[str | TimeSeriesScorer] | None = None,
         use_cache: bool = True,
     ) -> pd.DataFrame:
         logger.debug("Generating leaderboard for all models trained")
@@ -628,14 +638,15 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             if isinstance(model, MultiWindowBacktestingModel):
                 model = model.most_recent_model
                 assert model is not None
-            model_info[model_name]["hyperparameters"] = model.params
+            model_info[model_name]["hyperparameters"] = model.get_hyperparameters()
 
         if extra_metrics is None:
             extra_metrics = []
 
         if data is not None:
             past_data, known_covariates = data.get_model_inputs_for_scoring(
-                prediction_length=self.prediction_length, known_covariates_names=self.metadata.known_covariates
+                prediction_length=self.prediction_length,
+                known_covariates_names=self.covariate_metadata.known_covariates,
             )
             logger.info(
                 "Additional data provided, testing on additional data. Resulting leaderboard "
@@ -694,8 +705,8 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         return df[explicit_column_order]
 
     def persist(
-        self, model_names: Union[Literal["all", "best"], List[str]] = "all", with_ancestors: bool = False
-    ) -> List[str]:
+        self, model_names: Literal["all", "best"] | list[str] = "all", with_ancestors: bool = False
+    ) -> list[str]:
         if model_names == "all":
             model_names = self.get_model_names()
         elif model_names == "best":
@@ -719,7 +730,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         return model_names
 
-    def unpersist(self, model_names: Union[Literal["all"], List[str]] = "all") -> List[str]:
+    def unpersist(self, model_names: Literal["all"] | list[str] = "all") -> list[str]:
         if model_names == "all":
             model_names = list(self.models.keys())
         if not isinstance(model_names, list):
@@ -731,9 +742,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             unpersisted_models.append(model)
         return unpersisted_models
 
-    def _get_model_for_prediction(
-        self, model: Optional[Union[str, AbstractTimeSeriesModel]] = None, verbose: bool = True
-    ) -> str:
+    def _get_model_for_prediction(self, model: str | TimeSeriesModelBase | None = None, verbose: bool = True) -> str:
         """Given an optional identifier or model object, return the name of the model with which to predict.
 
         If the model is not provided, this method will default to the best model according to the validation score.
@@ -749,18 +758,20 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             )
             return self.model_best
         else:
-            if isinstance(model, AbstractTimeSeriesModel):
+            if isinstance(model, TimeSeriesModelBase):
                 return model.name
             else:
+                if model not in self.get_model_names():
+                    raise KeyError(f"Model '{model}' not found. Available models: {self.get_model_names()}")
                 return model
 
     def predict(
         self,
         data: TimeSeriesDataFrame,
-        known_covariates: Optional[TimeSeriesDataFrame] = None,
-        model: Optional[Union[str, AbstractTimeSeriesModel]] = None,
+        known_covariates: TimeSeriesDataFrame | None = None,
+        model: str | TimeSeriesModelBase | None = None,
         use_cache: bool = True,
-        random_seed: Optional[int] = None,
+        random_seed: int | None = None,
     ) -> TimeSeriesDataFrame:
         model_name = self._get_model_for_prediction(model)
         model_pred_dict, _ = self.get_model_pred_dict(
@@ -775,49 +786,57 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             raise ValueError(f"Model {model_name} failed to predict. Please check the model's logs.")
         return predictions
 
+    def _get_eval_metric(self, metric: str | TimeSeriesScorer | None) -> TimeSeriesScorer:
+        if metric is None:
+            return self.eval_metric
+        else:
+            return check_get_evaluation_metric(
+                metric,
+                prediction_length=self.prediction_length,
+                seasonal_period=self.eval_metric.seasonal_period,
+                horizon_weight=self.eval_metric.horizon_weight,
+            )
+
     def _score_with_predictions(
         self,
         data: TimeSeriesDataFrame,
         predictions: TimeSeriesDataFrame,
-        metric: Union[str, TimeSeriesScorer, None] = None,
+        metric: str | TimeSeriesScorer | None = None,
     ) -> float:
         """Compute the score measuring how well the predictions align with the data."""
-        eval_metric = self.eval_metric if metric is None else check_get_evaluation_metric(metric)
-        return eval_metric.score(
+        return self._get_eval_metric(metric).score(
             data=data,
             predictions=predictions,
-            prediction_length=self.prediction_length,
             target=self.target,
-            seasonal_period=self.eval_metric_seasonal_period,
         )
 
     def score(
         self,
         data: TimeSeriesDataFrame,
-        model: Optional[Union[str, AbstractTimeSeriesModel]] = None,
-        metric: Union[str, TimeSeriesScorer, None] = None,
+        model: str | TimeSeriesModelBase | None = None,
+        metric: str | TimeSeriesScorer | None = None,
         use_cache: bool = True,
     ) -> float:
-        eval_metric = self.eval_metric if metric is None else check_get_evaluation_metric(metric)
+        eval_metric = self._get_eval_metric(metric)
         scores_dict = self.evaluate(data=data, model=model, metrics=[eval_metric], use_cache=use_cache)
         return scores_dict[eval_metric.name]
 
     def evaluate(
         self,
         data: TimeSeriesDataFrame,
-        model: Optional[Union[str, AbstractTimeSeriesModel]] = None,
-        metrics: Optional[Union[str, TimeSeriesScorer, List[Union[str, TimeSeriesScorer]]]] = None,
+        model: str | TimeSeriesModelBase | None = None,
+        metrics: str | TimeSeriesScorer | list[str | TimeSeriesScorer] | None = None,
         use_cache: bool = True,
-    ) -> Dict[str, float]:
+    ) -> dict[str, float]:
         past_data, known_covariates = data.get_model_inputs_for_scoring(
-            prediction_length=self.prediction_length, known_covariates_names=self.metadata.known_covariates
+            prediction_length=self.prediction_length, known_covariates_names=self.covariate_metadata.known_covariates
         )
         predictions = self.predict(data=past_data, known_covariates=known_covariates, model=model, use_cache=use_cache)
 
         metrics_ = [metrics] if not isinstance(metrics, list) else metrics
         scores_dict = {}
         for metric in metrics_:
-            eval_metric = self.eval_metric if metric is None else check_get_evaluation_metric(metric)
+            eval_metric = self._get_eval_metric(metric)
             scores_dict[eval_metric.name] = self._score_with_predictions(
                 data=data, predictions=predictions, metric=eval_metric
             )
@@ -826,20 +845,20 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
     def get_feature_importance(
         self,
         data: TimeSeriesDataFrame,
-        features: List[str],
-        model: Optional[Union[str, AbstractTimeSeriesModel]] = None,
-        metric: Optional[Union[str, TimeSeriesScorer]] = None,
-        time_limit: Optional[float] = None,
+        features: list[str],
+        model: str | TimeSeriesModelBase | None = None,
+        metric: str | TimeSeriesScorer | None = None,
+        time_limit: float | None = None,
         method: Literal["naive", "permutation"] = "permutation",
         subsample_size: int = 50,
-        num_iterations: Optional[int] = None,
-        random_seed: Optional[int] = None,
+        num_iterations: int | None = None,
+        random_seed: int | None = None,
         relative_scores: bool = False,
         include_confidence_band: bool = True,
         confidence_level: float = 0.99,
     ) -> pd.DataFrame:
         assert method in ["naive", "permutation"], f"Invalid feature importance method {method}."
-        metric = check_get_evaluation_metric(metric) if metric is not None else self.eval_metric
+        eval_metric = self._get_eval_metric(metric)
 
         logger.info("Computing feature importance")
@@ -871,7 +890,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
             )
 
         importance_transform = importance_transform_type(
-            covariate_metadata=self.metadata,
+            covariate_metadata=self.covariate_metadata,
             prediction_length=self.prediction_length,
             random_seed=random_seed,
         )
@@ -886,11 +905,13 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         for n in range(num_iterations):
             if subsample_size < data.num_items:
                 item_ids_sampled = data.item_ids.to_series().sample(subsample_size)  # noqa
-                data_sample: TimeSeriesDataFrame = data.query("item_id in @item_ids_sampled")  # type: ignore
+                data_sample: TimeSeriesDataFrame = data.query("item_id in @item_ids_sampled")
             else:
                 data_sample = data
 
-            base_score = self.evaluate(data=data_sample, model=model, metrics=metric, use_cache=False)[metric.name]
+            base_score = self.evaluate(data=data_sample, model=model, metrics=eval_metric, use_cache=False)[
+                eval_metric.name
+            ]
 
             for feature in features:
                 # override importance for unused features
@@ -898,9 +919,9 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
                     continue
                 else:
                     data_sample_replaced = importance_transform.transform(data_sample, feature_name=feature)
-                    score = self.evaluate(data=data_sample_replaced, model=model, metrics=metric, use_cache=False)[
-                        metric.name
-                    ]
+                    score = self.evaluate(
+                        data=data_sample_replaced, model=model, metrics=eval_metric, use_cache=False
+                    )[eval_metric.name]
 
                     importance = base_score - score
                     if relative_scores:
@@ -930,19 +951,85 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         return importance_df
 
-    def _model_uses_feature(self, model: Union[str, AbstractTimeSeriesModel], feature: str) -> bool:
+    def _model_uses_feature(self, model: str | TimeSeriesModelBase, feature: str) -> bool:
         """Check if the given model uses the given feature."""
         models_with_ancestors = set(self.get_minimum_model_set(model))
 
-        if feature in self.metadata.static_features:
+        if feature in self.covariate_metadata.static_features:
             return any(self.load_model(m).supports_static_features for m in models_with_ancestors)
-        elif feature in self.metadata.known_covariates:
+        elif feature in self.covariate_metadata.known_covariates:
             return any(self.load_model(m).supports_known_covariates for m in models_with_ancestors)
-        elif feature in self.metadata.past_covariates:
+        elif feature in self.covariate_metadata.past_covariates:
             return any(self.load_model(m).supports_past_covariates for m in models_with_ancestors)
 
         return False
 
+    def backtest_predictions(
+        self,
+        data: TimeSeriesDataFrame | None,
+        model_names: list[str],
+        num_val_windows: int | None = None,
+        val_step_size: int | None = None,
+        use_cache: bool = True,
+    ) -> dict[str, list[TimeSeriesDataFrame]]:
+        if data is None:
+            assert num_val_windows is None, "num_val_windows must be None when data is None"
+            assert val_step_size is None, "val_step_size must be None when data is None"
+            return {model_name: self._get_model_oof_predictions(model_name) for model_name in model_names}
+
+        if val_step_size is None:
+            val_step_size = self.prediction_length
+        if num_val_windows is None:
+            num_val_windows = 1
+
+        splitter = ExpandingWindowSplitter(
+            prediction_length=self.prediction_length,
+            num_val_windows=num_val_windows,
+            val_step_size=val_step_size,
+        )
+
+        result: dict[str, list[TimeSeriesDataFrame]] = {model_name: [] for model_name in model_names}
+        for past_data, full_data in splitter.split(data):
+            known_covariates = full_data.slice_by_timestep(-self.prediction_length, None)[
+                self.covariate_metadata.known_covariates
+            ]
+            pred_dict, _ = self.get_model_pred_dict(
+                model_names=model_names,
+                data=past_data,
+                known_covariates=known_covariates,
+                use_cache=use_cache,
+            )
+            for model_name in model_names:
+                result[model_name].append(pred_dict[model_name])  # type: ignore
+
+        return result
+
+    def backtest_targets(
+        self,
+        data: TimeSeriesDataFrame | None,
+        num_val_windows: int | None = None,
+        val_step_size: int | None = None,
+    ) -> list[TimeSeriesDataFrame]:
+        if data is None:
+            assert num_val_windows is None, "num_val_windows must be None when data is None"
+            assert val_step_size is None, "val_step_size must be None when data is None"
+            train_data = self.load_train_data()
+            val_data = self.load_val_data()
+            return self._get_validation_windows(train_data=train_data, val_data=val_data)
+
+        if val_step_size is None:
+            val_step_size = self.prediction_length
+        if num_val_windows is None:
+            num_val_windows = 1
+
+        splitter = ExpandingWindowSplitter(
+            prediction_length=self.prediction_length,
+            num_val_windows=num_val_windows,
+            val_step_size=val_step_size,
+        )
+
+        return [val_fold for _, val_fold in splitter.split(data)]
+
     def _add_ci_to_feature_importance(
         self, importance_df: pd.DataFrame, confidence_level: float = 0.99
     ) -> pd.DataFrame:
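The two new helpers above pair up: `backtest_predictions` returns one forecast per window per model and `backtest_targets` the matching ground-truth folds, with `data=None` falling back to the stored out-of-fold artifacts. A usage sketch under the assumption that `trainer` and `test_data` already exist:

    preds = trainer.backtest_predictions(
        data=test_data,
        model_names=["DeepAR", "GreedyEnsemble"],
        num_val_windows=2,  # val_step_size defaults to prediction_length, as in the method body
    )
    targets = trainer.backtest_targets(data=test_data, num_val_windows=2)

    # per-window scores for one model, reusing the trainer's own scorer
    for target_fold, pred in zip(targets, preds["DeepAR"]):
        print(trainer._score_with_predictions(target_fold, pred))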
@@ -972,10 +1059,10 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
     def _predict_model(
         self,
-        model: Union[str, AbstractTimeSeriesModel],
+        model: str | TimeSeriesModelBase,
         data: TimeSeriesDataFrame,
-        model_pred_dict: Dict[str, Optional[TimeSeriesDataFrame]],
-        known_covariates: Optional[TimeSeriesDataFrame] = None,
+        model_pred_dict: dict[str, TimeSeriesDataFrame | None],
+        known_covariates: TimeSeriesDataFrame | None = None,
     ) -> TimeSeriesDataFrame:
         """Generate predictions using the given model.
@@ -988,10 +1075,10 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
     def _get_inputs_to_model(
         self,
-        model: Union[str, AbstractTimeSeriesModel],
+        model: str | TimeSeriesModelBase,
         data: TimeSeriesDataFrame,
-        model_pred_dict: Dict[str, Optional[TimeSeriesDataFrame]],
-    ) -> Union[TimeSeriesDataFrame, Dict[str, Optional[TimeSeriesDataFrame]]]:
+        model_pred_dict: dict[str, TimeSeriesDataFrame | None],
+    ) -> TimeSeriesDataFrame | dict[str, TimeSeriesDataFrame | None]:
         """Get the first argument that should be passed to model.predict.
 
         This method assumes that model_pred_dict contains the predictions of all base models, if model is an ensemble.
@@ -1007,13 +1094,13 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
     def get_model_pred_dict(
         self,
-        model_names: List[str],
+        model_names: list[str],
         data: TimeSeriesDataFrame,
-        known_covariates: Optional[TimeSeriesDataFrame] = None,
+        known_covariates: TimeSeriesDataFrame | None = None,
         raise_exception_if_failed: bool = True,
         use_cache: bool = True,
-        random_seed: Optional[int] = None,
-    ) -> Tuple[Dict[str, Optional[TimeSeriesDataFrame]], Dict[str, float]]:
+        random_seed: int | None = None,
+    ) -> tuple[dict[str, TimeSeriesDataFrame | None], dict[str, float]]:
         """Return a dictionary with predictions of all models for the given dataset.
 
         Parameters
@@ -1033,20 +1120,20 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         use_cache
             If False, will ignore the cache even if it's available.
         """
-        # TODO: Unify design of the method with Tabular
-        if self.cache_predictions and use_cache:
-            dataset_hash = self._compute_dataset_hash(data=data, known_covariates=known_covariates)
-            model_pred_dict, pred_time_dict_marginal = self._get_cached_pred_dicts(dataset_hash)
+        if use_cache:
+            model_pred_dict, pred_time_dict_marginal = self.prediction_cache.get(
+                data=data, known_covariates=known_covariates
+            )
         else:
             model_pred_dict = {}
-            pred_time_dict_marginal: Dict[str, Any] = {}
+            pred_time_dict_marginal: dict[str, Any] = {}
 
         model_set = set()
         for model_name in model_names:
             model_set.update(self.get_minimum_model_set(model_name))
         if len(model_set) > 1:
-            model_to_level = self._get_model_levels()
-            model_set = sorted(model_set, key=model_to_level.get)  # type: ignore
+            model_to_layer = self._get_model_layers()
+            model_set = sorted(model_set, key=model_to_layer.get)  # type: ignore
             logger.debug(f"Prediction order: {model_set}")
 
         failed_models = []
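`_get_model_levels` is renamed to `_get_model_layers` here, but the intent is unchanged: sort the minimum model set so that every model's dependencies are predicted before the model itself. A self-contained sketch of layer assignment over a hypothetical dependency map:

```python
from functools import lru_cache

# Hypothetical dependency map: model name -> base models it consumes
dependencies: dict[str, list[str]] = {
    "SeasonalNaive": [],
    "DeepAR": [],
    "WeightedEnsemble": ["SeasonalNaive", "DeepAR"],
}

def get_model_layers(deps: dict[str, list[str]]) -> dict[str, int]:
    @lru_cache(maxsize=None)
    def layer(name: str) -> int:
        # Base models sit at layer 0; an ensemble sits one layer above its deepest input
        base = deps.get(name, [])
        return 0 if not base else 1 + max(layer(b) for b in base)

    return {name: layer(name) for name in deps}

print(sorted(dependencies, key=get_model_layers(dependencies).get))
# -> ['SeasonalNaive', 'DeepAR', 'WeightedEnsemble']
```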
@@ -1072,9 +1159,11 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         if len(failed_models) > 0 and raise_exception_if_failed:
             raise RuntimeError(f"Following models failed to predict: {failed_models}")
-        if self.cache_predictions and use_cache:
-            self._save_cached_pred_dicts(
-                dataset_hash,  # type: ignore
+
+        if use_cache:
+            self.prediction_cache.put(
+                data=data,
+                known_covariates=known_covariates,
                 model_pred_dict=model_pred_dict,
                 pred_time_dict=pred_time_dict_marginal,
             )
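Together with the previous hunk, this replaces the hash-and-pickle helpers (removed two hunks below) with a `prediction_cache` object. The diff shows only its call sites; the following is a minimal in-memory sketch of the `get`/`put` interface they imply (the shipped implementation also persists predictions to disk):

```python
import pandas as pd

class InMemoryPredictionCache:
    """Sketch of the interface implied by the call sites above; not the shipped class."""

    def __init__(self) -> None:
        self._store: dict[str, tuple[dict, dict]] = {}

    @staticmethod
    def _dataset_key(data, known_covariates=None) -> str:
        # Key the cache on dataset contents, mirroring the removed hash-based helpers
        parts = [data, known_covariates, getattr(data, "static_features", None)]
        return "/".join(
            "" if p is None else str(pd.util.hash_pandas_object(p, index=True).sum()) for p in parts
        )

    def get(self, data, known_covariates=None) -> tuple[dict, dict]:
        return self._store.get(self._dataset_key(data, known_covariates), ({}, {}))

    def put(self, data, known_covariates=None, *, model_pred_dict, pred_time_dict) -> None:
        # Do not cache results for models that failed to predict
        self._store[self._dataset_key(data, known_covariates)] = (
            {k: v for k, v in model_pred_dict.items() if v is not None},
            {k: v for k, v in pred_time_dict.items() if v is not None},
        )
```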
@@ -1085,7 +1174,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
         return final_model_pred_dict, final_pred_time_dict_total
 
-    def _get_total_pred_time_from_marginal(self, pred_time_dict_marginal: Dict[str, float]) -> Dict[str, float]:
+    def _get_total_pred_time_from_marginal(self, pred_time_dict_marginal: dict[str, float]) -> dict[str, float]:
         pred_time_dict_total = defaultdict(float)
         for model_name in pred_time_dict_marginal.keys():
             for base_model in self.get_minimum_model_set(model_name):
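Marginal time is a model's own inference cost; the total charged to a model also includes every base model in its minimum model set. A small worked example with hypothetical timings:

```python
from collections import defaultdict

pred_time_marginal = {"SeasonalNaive": 0.10, "DeepAR": 2.00, "WeightedEnsemble": 0.05}
minimum_model_set = {
    "SeasonalNaive": ["SeasonalNaive"],
    "DeepAR": ["DeepAR"],
    "WeightedEnsemble": ["WeightedEnsemble", "SeasonalNaive", "DeepAR"],
}

pred_time_total: dict[str, float] = defaultdict(float)
for model_name in pred_time_marginal:
    for base_model in minimum_model_set[model_name]:
        pred_time_total[model_name] += pred_time_marginal[base_model]

print(dict(pred_time_total))
# -> {'SeasonalNaive': 0.1, 'DeepAR': 2.0, 'WeightedEnsemble': 2.15}
```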
@@ -1093,57 +1182,8 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
                 pred_time_dict_total[model_name] += pred_time_dict_marginal[base_model]
         return dict(pred_time_dict_total)
 
-    @property
-    def _cached_predictions_path(self) -> Path:
-        return Path(self.path) / self._cached_predictions_filename
-
-    @staticmethod
-    def _compute_dataset_hash(
-        data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame] = None
-    ) -> str:
-        """Compute a unique string that identifies the time series dataset."""
-        combined_hash = hash_pandas_df(data) + hash_pandas_df(known_covariates) + hash_pandas_df(data.static_features)
-        return combined_hash
-
-    def _get_cached_pred_dicts(
-        self, dataset_hash: str
-    ) -> Tuple[Dict[str, Optional[TimeSeriesDataFrame]], Dict[str, float]]:
-        """Load cached predictions for given dataset_hash from disk, if possible. Otherwise returns empty dicts."""
-        if self._cached_predictions_path.exists():
-            cached_predictions = load_pkl.load(str(self._cached_predictions_path))
-            if dataset_hash in cached_predictions:
-                model_pred_dict = cached_predictions[dataset_hash]["model_pred_dict"]
-                pred_time_dict = cached_predictions[dataset_hash]["pred_time_dict"]
-                if model_pred_dict.keys() == pred_time_dict.keys():
-                    logger.debug(f"Loaded cached predictions for models {list(model_pred_dict.keys())}")
-                    return model_pred_dict, pred_time_dict
-                else:
-                    logger.warning(f"Found corrupted cached predictions in {self._cached_predictions_path}")
-        logger.debug("Found no cached predictions")
-        return {}, {}
-
-    def _save_cached_pred_dicts(
-        self,
-        dataset_hash: str,
-        model_pred_dict: Dict[str, Optional[TimeSeriesDataFrame]],
-        pred_time_dict: Dict[str, float],
-    ) -> None:
-        # TODO: Save separate file for each dataset if _cached_predictions file grows large?
-        if self._cached_predictions_path.exists():
-            logger.debug("Extending existing cached predictions")
-            cached_predictions = load_pkl.load(str(self._cached_predictions_path))
-        else:
-            cached_predictions = {}
-        # Do not save results for models that failed
-        cached_predictions[dataset_hash] = {
-            "model_pred_dict": {k: v for k, v in model_pred_dict.items() if v is not None},
-            "pred_time_dict": {k: v for k, v in pred_time_dict.items() if v is not None},
-        }
-        save_pkl.save(str(self._cached_predictions_path), object=cached_predictions)
-        logger.debug(f"Cached predictions saved to {self._cached_predictions_path}")
-
     def _merge_refit_full_data(
-        self, train_data: TimeSeriesDataFrame, val_data: Optional[TimeSeriesDataFrame]
+        self, train_data: TimeSeriesDataFrame, val_data: TimeSeriesDataFrame | None
     ) -> TimeSeriesDataFrame:
         if val_data is None:
             return train_data
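The removed helpers keyed an on-disk pickle of predictions by a content hash of the dataset. `hash_pandas_df` itself is defined elsewhere in the package and not shown in this diff; a plausible, hypothetical reconstruction:

```python
import hashlib
import pandas as pd

def hash_pandas_df(df: pd.DataFrame | None) -> str:
    # Assumption: None contributes an empty token, since the removed code also
    # hashes known_covariates and static_features, which may be absent.
    if df is None:
        return ""
    row_hashes = pd.util.hash_pandas_object(df, index=True)
    return hashlib.sha256(row_hashes.values.tobytes()).hexdigest()
```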
@@ -1153,10 +1193,10 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
 
     def refit_single_full(
         self,
-        train_data: Optional[TimeSeriesDataFrame] = None,
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        models: Optional[List[str]] = None,
-    ) -> List[str]:
+        train_data: TimeSeriesDataFrame | None = None,
+        val_data: TimeSeriesDataFrame | None = None,
+        models: list[str] | None = None,
+    ) -> list[str]:
         train_data = train_data or self.load_train_data()
         val_data = val_data or self.load_val_data()
         refit_full_data = self._merge_refit_full_data(train_data, val_data)
@@ -1164,16 +1204,17 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         if models is None:
             models = self.get_model_names()
 
-        model_to_level = self._get_model_levels()
-        models_sorted_by_level = sorted(models, key=model_to_level.get)  # type: ignore
+        model_to_layer = self._get_model_layers()
+        models_sorted_by_layer = sorted(models, key=model_to_layer.get)  # type: ignore
 
         model_refit_map = {}
         models_trained_full = []
-        for model in models_sorted_by_level:
+        for model in models_sorted_by_layer:
             model = self.load_model(model)
             model_name = model.name
             if model._get_tags()["can_refit_full"]:
                 model_full = model.convert_to_refit_full_template()
+                assert isinstance(model_full, AbstractTimeSeriesModel)
                 logger.info(f"Fitting model: {model_full.name}")
                 models_trained = self._train_and_save(
                     train_data=refit_full_data,
@@ -1199,7 +1240,7 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         self.save()
         return models_trained_full
 
-    def refit_full(self, model: str = "all") -> Dict[str, str]:
+    def refit_full(self, model: str = "all") -> dict[str, str]:
         time_start = time.time()
         existing_models = self.get_model_names()
         if model == "all":
@@ -1231,71 +1272,27 @@ class TimeSeriesTrainer(AbstractTrainer[AbstractTimeSeriesModel]):
         logger.info(f"Total runtime: {time.time() - time_start:.2f} s")
         return copy.deepcopy(self.model_refit_map)
 
-    def construct_model_templates(
+    def get_trainable_base_models(
         self,
-        hyperparameters: Union[str, Dict[str, Any]],
+        hyperparameters: str | dict[str, Any],
         *,
         multi_window: bool = False,
-        freq: Optional[str] = None,
-        excluded_model_types: Optional[List[str]] = None,
+        freq: str | None = None,
+        excluded_model_types: list[str] | None = None,
         hyperparameter_tune: bool = False,
-    ) -> List[AbstractTimeSeriesModel]:
-        return get_preset_models(
+    ) -> list[AbstractTimeSeriesModel]:
+        return TrainableModelSetBuilder(
+            freq=freq,
+            prediction_length=self.prediction_length,
             path=self.path,
             eval_metric=self.eval_metric,
-            eval_metric_seasonal_period=self.eval_metric_seasonal_period,
-            prediction_length=self.prediction_length,
-            freq=freq,
-            hyperparameters=hyperparameters,
-            hyperparameter_tune=hyperparameter_tune,
             quantile_levels=self.quantile_levels,
-            all_assigned_names=self._get_banned_model_names(),
             target=self.target,
-            metadata=self.metadata,
-            excluded_model_types=excluded_model_types,
-            # if skip_model_selection = True, we skip backtesting
+            covariate_metadata=self.covariate_metadata,
             multi_window=multi_window and not self.skip_model_selection,
-        )
-
-    def fit(
-        self,
-        train_data: TimeSeriesDataFrame,
-        hyperparameters: Union[str, Dict[Any, Dict]],
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        hyperparameter_tune_kwargs: Optional[Union[str, Dict]] = None,
-        excluded_model_types: Optional[List[str]] = None,
-        time_limit: Optional[float] = None,
-        random_seed: Optional[int] = None,
-    ):
-        """
-        Fit a set of timeseries models specified by the `hyperparameters`
-        dictionary that maps model names to their specified hyperparameters.
-
-        Parameters
-        ----------
-        train_data: TimeSeriesDataFrame
-            Training data for fitting time series timeseries models.
-        hyperparameters: str or Dict
-            A dictionary mapping selected model names, model classes or model factory to hyperparameter
-            settings. Model names should be present in `trainer.presets.DEFAULT_MODEL_NAMES`. Optionally,
-            the user may provide one of "default", "light" and "very_light" to specify presets.
-        val_data: TimeSeriesDataFrame
-            Optional validation data set to report validation scores on.
-        hyperparameter_tune_kwargs
-            Args for hyperparameter tuning
-        excluded_model_types
-            Names of models that should not be trained, even if listed in `hyperparameters`.
-        time_limit
-            Time limit for training
-        random_seed
-            Random seed that will be set to each model during training
-        """
-        self._train_multi(
-            train_data,
-            val_data=val_data,
+        ).get_model_set(
             hyperparameters=hyperparameters,
-            hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
+            hyperparameter_tune=hyperparameter_tune,
             excluded_model_types=excluded_model_types,
-            time_limit=time_limit,
-            random_seed=random_seed,
+            banned_model_names=self._get_banned_model_names(),
         )
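The old `construct_model_templates` + trainer-level `fit` pair is replaced by `get_trainable_base_models`, which separates task-level configuration (the builder constructor) from per-call arguments (`get_model_set`). A hedged usage sketch (assumes a `TimeSeriesTrainer`; hyperparameter values are hypothetical):

```python
def build_models(trainer):
    # Returns instantiated-but-untrained model templates for this task
    return trainer.get_trainable_base_models(
        hyperparameters={"Naive": {}, "DeepAR": {"max_epochs": 5}},
        multi_window=True,
        freq="h",
    )
```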