autogluon.timeseries 1.4.1b20250825__py3-none-any.whl → 1.4.1b20250826__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.timeseries might be problematic; see the registry's advisory page for details.

Files changed (20)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/models/__init__.py +2 -0
  5. autogluon/timeseries/predictor.py +2 -2
  6. autogluon/timeseries/trainer/__init__.py +3 -0
  7. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  8. autogluon/timeseries/{trainer.py → trainer/trainer.py} +13 -14
  9. autogluon/timeseries/version.py +1 -1
  10. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/METADATA +5 -5
  11. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/RECORD +18 -16
  12. autogluon/timeseries/configs/presets_configs.py +0 -79
  13. autogluon/timeseries/models/presets.py +0 -280
  14. /autogluon.timeseries-1.4.1b20250825-py3.9-nspkg.pth → /autogluon.timeseries-1.4.1b20250826-py3.9-nspkg.pth +0 -0
  15. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/LICENSE +0 -0
  16. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/NOTICE +0 -0
  17. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/WHEEL +0 -0
  18. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/namespace_packages.txt +0 -0
  19. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/top_level.txt +0 -0
  20. {autogluon.timeseries-1.4.1b20250825.dist-info → autogluon.timeseries-1.4.1b20250826.dist-info}/zip-safe +0 -0
@@ -1,3 +1,4 @@
1
- from .presets_configs import TIMESERIES_PRESETS_CONFIGS
1
+ from .hyperparameter_presets import get_hyperparameter_presets
2
+ from .predictor_presets import get_predictor_presets
2
3
 
3
- __all__ = ["TIMESERIES_PRESETS_CONFIGS"]
4
+ __all__ = ["get_hyperparameter_presets", "get_predictor_presets"]
@@ -0,0 +1,62 @@
1
+ from typing import Any, Union
2
+
3
+
4
+ def get_hyperparameter_presets() -> dict[str, dict[str, Union[dict[str, Any], list[dict[str, Any]]]]]:
5
+ return {
6
+ "very_light": {
7
+ "Naive": {},
8
+ "SeasonalNaive": {},
9
+ "ETS": {},
10
+ "Theta": {},
11
+ "RecursiveTabular": {"max_num_samples": 100_000},
12
+ "DirectTabular": {"max_num_samples": 100_000},
13
+ },
14
+ "light": {
15
+ "Naive": {},
16
+ "SeasonalNaive": {},
17
+ "ETS": {},
18
+ "Theta": {},
19
+ "RecursiveTabular": {},
20
+ "DirectTabular": {},
21
+ "TemporalFusionTransformer": {},
22
+ "Chronos": {"model_path": "bolt_small"},
23
+ },
24
+ "light_inference": {
25
+ "SeasonalNaive": {},
26
+ "DirectTabular": {},
27
+ "RecursiveTabular": {},
28
+ "TemporalFusionTransformer": {},
29
+ "PatchTST": {},
30
+ },
31
+ "default": {
32
+ "SeasonalNaive": {},
33
+ "AutoETS": {},
34
+ "NPTS": {},
35
+ "DynamicOptimizedTheta": {},
36
+ "RecursiveTabular": {},
37
+ "DirectTabular": {},
38
+ "TemporalFusionTransformer": {},
39
+ "PatchTST": {},
40
+ "DeepAR": {},
41
+ "Chronos": [
42
+ {
43
+ "ag_args": {"name_suffix": "ZeroShot"},
44
+ "model_path": "bolt_base",
45
+ },
46
+ {
47
+ "ag_args": {"name_suffix": "FineTuned"},
48
+ "model_path": "bolt_small",
49
+ "fine_tune": True,
50
+ "target_scaler": "standard",
51
+ "covariate_regressor": {"model_name": "CAT", "model_hyperparameters": {"iterations": 1_000}},
52
+ },
53
+ ],
54
+ "TiDE": {
55
+ "encoder_hidden_dim": 256,
56
+ "decoder_hidden_dim": 256,
57
+ "temporal_hidden_dim": 64,
58
+ "num_batches_per_epoch": 100,
59
+ "lr": 1e-4,
60
+ },
61
+ },
62
+ }
@@ -0,0 +1,84 @@
1
+ """Preset configurations for autogluon.timeseries Predictors"""
2
+
3
+ from typing import Any
4
+
5
+ from . import get_hyperparameter_presets
6
+
7
+ TIMESERIES_PRESETS_ALIASES = dict(
8
+ chronos="chronos_small",
9
+ best="best_quality",
10
+ high="high_quality",
11
+ medium="medium_quality",
12
+ bq="best_quality",
13
+ hq="high_quality",
14
+ mq="medium_quality",
15
+ )
16
+
17
+
18
+ def get_predictor_presets() -> dict[str, Any]:
19
+ hp_presets = get_hyperparameter_presets()
20
+
21
+ predictor_presets = dict(
22
+ best_quality={"hyperparameters": "default", "num_val_windows": 2},
23
+ high_quality={"hyperparameters": "default"},
24
+ medium_quality={"hyperparameters": "light"},
25
+ fast_training={"hyperparameters": "very_light"},
26
+ # Chronos-Bolt models
27
+ bolt_tiny={
28
+ "hyperparameters": {"Chronos": {"model_path": "bolt_tiny"}},
29
+ "skip_model_selection": True,
30
+ },
31
+ bolt_mini={
32
+ "hyperparameters": {"Chronos": {"model_path": "bolt_mini"}},
33
+ "skip_model_selection": True,
34
+ },
35
+ bolt_small={
36
+ "hyperparameters": {"Chronos": {"model_path": "bolt_small"}},
37
+ "skip_model_selection": True,
38
+ },
39
+ bolt_base={
40
+ "hyperparameters": {"Chronos": {"model_path": "bolt_base"}},
41
+ "skip_model_selection": True,
42
+ },
43
+ # Original Chronos models
44
+ chronos_tiny={
45
+ "hyperparameters": {"Chronos": {"model_path": "tiny"}},
46
+ "skip_model_selection": True,
47
+ },
48
+ chronos_mini={
49
+ "hyperparameters": {"Chronos": {"model_path": "mini"}},
50
+ "skip_model_selection": True,
51
+ },
52
+ chronos_small={
53
+ "hyperparameters": {"Chronos": {"model_path": "small"}},
54
+ "skip_model_selection": True,
55
+ },
56
+ chronos_base={
57
+ "hyperparameters": {"Chronos": {"model_path": "base"}},
58
+ "skip_model_selection": True,
59
+ },
60
+ chronos_large={
61
+ "hyperparameters": {"Chronos": {"model_path": "large", "batch_size": 8}},
62
+ "skip_model_selection": True,
63
+ },
64
+ chronos_ensemble={
65
+ "hyperparameters": {
66
+ "Chronos": {"model_path": "small"},
67
+ **hp_presets["light_inference"],
68
+ }
69
+ },
70
+ chronos_large_ensemble={
71
+ "hyperparameters": {
72
+ "Chronos": {"model_path": "large", "batch_size": 8},
73
+ **hp_presets["light_inference"],
74
+ }
75
+ },
76
+ )
77
+
78
+ # update with aliases
79
+ predictor_presets = {
80
+ **predictor_presets,
81
+ **{k: predictor_presets[v].copy() for k, v in TIMESERIES_PRESETS_ALIASES.items()},
82
+ }
83
+
84
+ return predictor_presets
@@ -27,6 +27,7 @@ from .local import (
27
27
  ThetaModel,
28
28
  ZeroModel,
29
29
  )
30
+ from .registry import ModelRegistry
30
31
 
31
32
  __all__ = [
32
33
  "ADIDAModel",
@@ -43,6 +44,7 @@ __all__ = [
43
44
  "ETSModel",
44
45
  "IMAPAModel",
45
46
  "ChronosModel",
47
+ "ModelRegistry",
46
48
  "NPTSModel",
47
49
  "NaiveModel",
48
50
  "PatchTSTModel",
@@ -21,7 +21,7 @@ from autogluon.core.utils.decorators import apply_presets
21
21
  from autogluon.core.utils.loaders import load_pkl, load_str
22
22
  from autogluon.core.utils.savers import save_pkl, save_str
23
23
  from autogluon.timeseries import __version__ as current_ag_version
24
- from autogluon.timeseries.configs import TIMESERIES_PRESETS_CONFIGS
24
+ from autogluon.timeseries.configs import get_predictor_presets
25
25
  from autogluon.timeseries.dataset.ts_dataframe import ITEMID, TimeSeriesDataFrame
26
26
  from autogluon.timeseries.learner import TimeSeriesLearner
27
27
  from autogluon.timeseries.metrics import TimeSeriesScorer, check_get_evaluation_metric
@@ -432,7 +432,7 @@ class TimeSeriesPredictor:
432
432
  )
433
433
  return train_data
434
434
 
435
- @apply_presets(TIMESERIES_PRESETS_CONFIGS)
435
+ @apply_presets(get_predictor_presets())
436
436
  def fit(
437
437
  self,
438
438
  train_data: Union[TimeSeriesDataFrame, pd.DataFrame, Path, str],
@@ -0,0 +1,3 @@
1
+ from .trainer import TimeSeriesTrainer
2
+
3
+ __all__ = ["TimeSeriesTrainer"]
@@ -0,0 +1,256 @@
1
+ import copy
2
+ import logging
3
+ import re
4
+ from collections import defaultdict
5
+ from typing import Any, Optional, Type, Union
6
+
7
+ from autogluon.common import space
8
+ from autogluon.core import constants
9
+ from autogluon.timeseries.configs import get_hyperparameter_presets
10
+ from autogluon.timeseries.metrics import TimeSeriesScorer
11
+ from autogluon.timeseries.models import ModelRegistry
12
+ from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
13
+ from autogluon.timeseries.models.multi_window import MultiWindowBacktestingModel
14
+ from autogluon.timeseries.utils.features import CovariateMetadata
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ ModelKey = Union[str, Type[AbstractTimeSeriesModel]]
20
+ ModelHyperparameters = dict[str, Any]
21
+ TrainerHyperparameterSpec = dict[ModelKey, list[ModelHyperparameters]]
22
+
23
+
24
+ class TrainableModelSetBuilder:
25
+ """Responsible for building a list of model objects, in priority order, that will be trained by the
26
+ Trainer."""
27
+
28
+ VALID_AG_ARGS_KEYS = {
29
+ "name",
30
+ "name_prefix",
31
+ "name_suffix",
32
+ }
33
+
34
+ def __init__(
35
+ self,
36
+ path: str,
37
+ freq: Optional[str],
38
+ prediction_length: int,
39
+ eval_metric: TimeSeriesScorer,
40
+ target: str,
41
+ quantile_levels: list[float],
42
+ covariate_metadata: CovariateMetadata,
43
+ multi_window: bool,
44
+ ):
45
+ self.path = path
46
+ self.freq = freq
47
+ self.prediction_length = prediction_length
48
+ self.eval_metric = eval_metric
49
+ self.target = target
50
+ self.quantile_levels = quantile_levels
51
+ self.covariate_metadata = covariate_metadata
52
+ self.multi_window = multi_window
53
+
54
+ def get_model_set(
55
+ self,
56
+ hyperparameters: Union[str, dict, None],
57
+ hyperparameter_tune: bool,
58
+ excluded_model_types: Optional[list[str]],
59
+ banned_model_names: Optional[list[str]] = None,
60
+ ) -> list[AbstractTimeSeriesModel]:
61
+ """Resolve hyperparameters and create the requested list of models"""
62
+ models = []
63
+ banned_model_names = [] if banned_model_names is None else banned_model_names.copy()
64
+
65
+ # resolve and normalize hyperparameters
66
+ model_hp_map: TrainerHyperparameterSpec = HyperparameterBuilder(
67
+ hyperparameters=hyperparameters,
68
+ hyperparameter_tune=hyperparameter_tune,
69
+ excluded_model_types=excluded_model_types,
70
+ ).get_hyperparameters()
71
+
72
+ for k in model_hp_map.keys():
73
+ if isinstance(k, type) and not issubclass(k, AbstractTimeSeriesModel):
74
+ raise ValueError(f"Custom model type {k} must inherit from `AbstractTimeSeriesModel`.")
75
+
76
+ model_priority_list = sorted(
77
+ model_hp_map.keys(), key=lambda x: ModelRegistry.get_model_priority(x), reverse=True
78
+ )
79
+
80
+ for model_key in model_priority_list:
81
+ model_type = self._get_model_type(model_key)
82
+
83
+ for model_hps in model_hp_map[model_key]:
84
+ ag_args = model_hps.pop(constants.AG_ARGS, {})
85
+
86
+ for key in ag_args:
87
+ if key not in self.VALID_AG_ARGS_KEYS:
88
+ raise ValueError(
89
+ f"Model {model_type} received unknown ag_args key: {key} (valid keys {self.VALID_AG_ARGS_KEYS})"
90
+ )
91
+ model_name_base = self._get_model_name(ag_args, model_type)
92
+
93
+ model_type_kwargs: dict[str, Any] = dict(
94
+ name=model_name_base,
95
+ hyperparameters=model_hps,
96
+ **self._get_default_model_init_kwargs(),
97
+ )
98
+
99
+ # add models while preventing name collisions
100
+ model = model_type(**model_type_kwargs)
101
+ model_type_kwargs.pop("name", None)
102
+
103
+ increment = 1
104
+ while model.name in banned_model_names:
105
+ increment += 1
106
+ model = model_type(name=f"{model_name_base}_{increment}", **model_type_kwargs)
107
+
108
+ if self.multi_window:
109
+ model = MultiWindowBacktestingModel(model_base=model, name=model.name, **model_type_kwargs) # type: ignore
110
+
111
+ banned_model_names.append(model.name)
112
+ models.append(model)
113
+
114
+ return models
115
+
116
+ def _get_model_type(self, model: ModelKey) -> Type[AbstractTimeSeriesModel]:
117
+ if isinstance(model, str):
118
+ model_type: Type[AbstractTimeSeriesModel] = ModelRegistry.get_model_class(model)
119
+ elif isinstance(model, type):
120
+ model_type = model
121
+ else:
122
+ raise ValueError(
123
+ f"Keys of the `hyperparameters` dictionary must be strings or types, received {type(model)}."
124
+ )
125
+
126
+ return model_type
127
+
128
+ def _get_default_model_init_kwargs(self) -> dict[str, Any]:
129
+ return dict(
130
+ path=self.path,
131
+ freq=self.freq,
132
+ prediction_length=self.prediction_length,
133
+ eval_metric=self.eval_metric,
134
+ target=self.target,
135
+ quantile_levels=self.quantile_levels,
136
+ covariate_metadata=self.covariate_metadata,
137
+ )
138
+
139
+ def _get_model_name(self, ag_args: dict[str, Any], model_type: Type[AbstractTimeSeriesModel]) -> str:
140
+ name = ag_args.get("name")
141
+ if name is None:
142
+ name_stem = re.sub(r"Model$", "", model_type.__name__)
143
+ name_prefix = ag_args.get("name_prefix", "")
144
+ name_suffix = ag_args.get("name_suffix", "")
145
+ name = name_prefix + name_stem + name_suffix
146
+ return name
147
+
148
+
149
+ class HyperparameterBuilder:
150
+ """Given user hyperparameter specifications, this class resolves them against presets, removes
151
+ excluded model types and canonicalizes the hyperparameter specification.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ hyperparameters: Union[str, dict, None],
157
+ hyperparameter_tune: bool,
158
+ excluded_model_types: Optional[list[str]],
159
+ ):
160
+ self.hyperparameters = hyperparameters
161
+ self.hyperparameter_tune = hyperparameter_tune
162
+ self.excluded_model_types = excluded_model_types
163
+
164
+ def get_hyperparameters(self) -> TrainerHyperparameterSpec:
165
+ hyperparameter_dict = {}
166
+ hp_presets = get_hyperparameter_presets()
167
+
168
+ if self.hyperparameters is None:
169
+ hyperparameter_dict = hp_presets["default"]
170
+ elif isinstance(self.hyperparameters, str):
171
+ try:
172
+ hyperparameter_dict = hp_presets[self.hyperparameters]
173
+ except KeyError:
174
+ raise ValueError(f"{self.hyperparameters} is not a valid preset.")
175
+ elif isinstance(self.hyperparameters, dict):
176
+ hyperparameter_dict = copy.deepcopy(self.hyperparameters)
177
+ else:
178
+ raise ValueError(
179
+ f"hyperparameters must be a dict, a string or None (received {type(self.hyperparameters)}). "
180
+ f"Please see the documentation for TimeSeriesPredictor.fit"
181
+ )
182
+
183
+ return self._check_and_clean_hyperparameters(hyperparameter_dict) # type: ignore
184
+
185
+ def _check_and_clean_hyperparameters(
186
+ self,
187
+ hyperparameters: dict[ModelKey, Union[ModelHyperparameters, list[ModelHyperparameters]]],
188
+ ) -> TrainerHyperparameterSpec:
189
+ """Convert the hyperparameters dictionary to a unified format:
190
+ - Remove 'Model' suffix from model names, if present
191
+ - Make sure that each value in the hyperparameters dict is a list with model configurations
192
+ - Checks if hyperparameters contain searchspaces
193
+ """
194
+ excluded_models = self._get_excluded_models()
195
+ hyperparameters_clean = defaultdict(list)
196
+ for model_name, model_hyperparameters in hyperparameters.items():
197
+ # Handle model names ending with "Model", e.g., "DeepARModel" is mapped to "DeepAR"
198
+ if isinstance(model_name, str):
199
+ model_name = self._normalize_model_type_name(model_name)
200
+ if model_name in excluded_models:
201
+ logger.info(
202
+ f"\tFound '{model_name}' model in `hyperparameters`, but '{model_name}' "
203
+ "is present in `excluded_model_types` and will be removed."
204
+ )
205
+ continue
206
+ if not isinstance(model_hyperparameters, list):
207
+ model_hyperparameters = [model_hyperparameters]
208
+ hyperparameters_clean[model_name].extend(model_hyperparameters)
209
+
210
+ self._verify_searchspaces(hyperparameters_clean)
211
+
212
+ return dict(hyperparameters_clean)
213
+
214
+ def _get_excluded_models(self) -> set[str]:
215
+ excluded_models = set()
216
+ if self.excluded_model_types is not None and len(self.excluded_model_types) > 0:
217
+ if not isinstance(self.excluded_model_types, list):
218
+ raise ValueError(f"`excluded_model_types` must be a list, received {type(self.excluded_model_types)}")
219
+ logger.info(f"Excluded model types: {self.excluded_model_types}")
220
+ for model in self.excluded_model_types:
221
+ if not isinstance(model, str):
222
+ raise ValueError(f"Each entry in `excluded_model_types` must be a string, received {type(model)}")
223
+ excluded_models.add(self._normalize_model_type_name(model))
224
+ return excluded_models
225
+
226
+ @staticmethod
227
+ def _normalize_model_type_name(model_name: str) -> str:
228
+ return model_name.removesuffix("Model")
229
+
230
+ def _verify_searchspaces(self, hyperparameters: dict[str, list[ModelHyperparameters]]):
231
+ if self.hyperparameter_tune:
232
+ for model, model_hps_list in hyperparameters.items():
233
+ for model_hps in model_hps_list:
234
+ if contains_searchspace(model_hps):
235
+ return
236
+
237
+ raise ValueError(
238
+ "Hyperparameter tuning specified, but no model contains a hyperparameter search space. "
239
+ "Please disable hyperparameter tuning with `hyperparameter_tune_kwargs=None` or provide a search space "
240
+ "for at least one model."
241
+ )
242
+ else:
243
+ for model, model_hps_list in hyperparameters.items():
244
+ for model_hps in model_hps_list:
245
+ if contains_searchspace(model_hps):
246
+ raise ValueError(
247
+ f"Hyperparameter tuning not specified, so hyperparameters must have fixed values. "
248
+ f"However, for model {model} hyperparameters {model_hps} contain a search space."
249
+ )
250
+
251
+
252
+ def contains_searchspace(model_hyperparameters: ModelHyperparameters) -> bool:
253
+ for hp_value in model_hyperparameters.values():
254
+ if isinstance(hp_value, space.Space):
255
+ return True
256
+ return False
@@ -22,7 +22,6 @@ from autogluon.timeseries.metrics import TimeSeriesScorer, check_get_evaluation_
22
22
  from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel, TimeSeriesModelBase
23
23
  from autogluon.timeseries.models.ensemble import AbstractTimeSeriesEnsembleModel, GreedyEnsemble
24
24
  from autogluon.timeseries.models.multi_window import MultiWindowBacktestingModel
25
- from autogluon.timeseries.models.presets import contains_searchspace, get_preset_models
26
25
  from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
27
26
  from autogluon.timeseries.utils.features import (
28
27
  ConstantReplacementFeatureImportanceTransform,
@@ -31,6 +30,8 @@ from autogluon.timeseries.utils.features import (
31
30
  )
32
31
  from autogluon.timeseries.utils.warning_filters import disable_tqdm, warning_filter
33
32
 
33
+ from .model_set_builder import TrainableModelSetBuilder, contains_searchspace
34
+
34
35
  logger = logging.getLogger("autogluon.timeseries.trainer")
35
36
 
36
37
 
@@ -416,7 +417,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
416
417
  self.save_val_data(val_data)
417
418
  self.is_data_saved = True
418
419
 
419
- models = self.construct_model_templates(
420
+ models = self.get_trainable_base_models(
420
421
  hyperparameters=hyperparameters,
421
422
  hyperparameter_tune=hyperparameter_tune_kwargs is not None, # TODO: remove hyperparameter_tune
422
423
  freq=train_data.freq,
@@ -440,8 +441,6 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
440
441
  num_base_models = len(models)
441
442
  model_names_trained = []
442
443
  for i, model in enumerate(models):
443
- assert isinstance(model, AbstractTimeSeriesModel)
444
-
445
444
  if time_limit is None:
446
445
  time_left = None
447
446
  time_left_for_model = None
@@ -1261,7 +1260,7 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
1261
1260
  logger.info(f"Total runtime: {time.time() - time_start:.2f} s")
1262
1261
  return copy.deepcopy(self.model_refit_map)
1263
1262
 
1264
- def construct_model_templates(
1263
+ def get_trainable_base_models(
1265
1264
  self,
1266
1265
  hyperparameters: Union[str, dict[str, Any]],
1267
1266
  *,
@@ -1269,21 +1268,21 @@ class TimeSeriesTrainer(AbstractTrainer[TimeSeriesModelBase]):
1269
1268
  freq: Optional[str] = None,
1270
1269
  excluded_model_types: Optional[list[str]] = None,
1271
1270
  hyperparameter_tune: bool = False,
1272
- ) -> list[TimeSeriesModelBase]:
1273
- return get_preset_models(
1271
+ ) -> list[AbstractTimeSeriesModel]:
1272
+ return TrainableModelSetBuilder(
1273
+ freq=freq,
1274
+ prediction_length=self.prediction_length,
1274
1275
  path=self.path,
1275
1276
  eval_metric=self.eval_metric,
1276
- prediction_length=self.prediction_length,
1277
- freq=freq,
1278
- hyperparameters=hyperparameters,
1279
- hyperparameter_tune=hyperparameter_tune,
1280
1277
  quantile_levels=self.quantile_levels,
1281
- all_assigned_names=self._get_banned_model_names(),
1282
1278
  target=self.target,
1283
1279
  covariate_metadata=self.covariate_metadata,
1284
- excluded_model_types=excluded_model_types,
1285
- # if skip_model_selection = True, we skip backtesting
1286
1280
  multi_window=multi_window and not self.skip_model_selection,
1281
+ ).get_model_set(
1282
+ hyperparameters=hyperparameters,
1283
+ hyperparameter_tune=hyperparameter_tune,
1284
+ excluded_model_types=excluded_model_types,
1285
+ banned_model_names=self._get_banned_model_names(),
1287
1286
  )
1288
1287
 
1289
1288
  def fit(
@@ -1,4 +1,4 @@
1
1
  """This is the autogluon version file."""
2
2
 
3
- __version__ = "1.4.1b20250825"
3
+ __version__ = "1.4.1b20250826"
4
4
  __lite__ = False
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: autogluon.timeseries
3
- Version: 1.4.1b20250825
3
+ Version: 1.4.1b20250826
4
4
  Summary: Fast and Accurate ML in 3 Lines of Code
5
5
  Home-page: https://github.com/autogluon/autogluon
6
6
  Author: AutoGluon Community
@@ -55,10 +55,10 @@ Requires-Dist: fugue>=0.9.0
55
55
  Requires-Dist: tqdm<5,>=4.38
56
56
  Requires-Dist: orjson~=3.9
57
57
  Requires-Dist: tensorboard<3,>=2.9
58
- Requires-Dist: autogluon.core[raytune]==1.4.1b20250825
59
- Requires-Dist: autogluon.common==1.4.1b20250825
60
- Requires-Dist: autogluon.features==1.4.1b20250825
61
- Requires-Dist: autogluon.tabular[catboost,lightgbm,xgboost]==1.4.1b20250825
58
+ Requires-Dist: autogluon.core[raytune]==1.4.1b20250826
59
+ Requires-Dist: autogluon.common==1.4.1b20250826
60
+ Requires-Dist: autogluon.features==1.4.1b20250826
61
+ Requires-Dist: autogluon.tabular[catboost,lightgbm,xgboost]==1.4.1b20250826
62
62
  Provides-Extra: all
63
63
  Provides-Extra: tests
64
64
  Requires-Dist: pytest; extra == "tests"
@@ -1,14 +1,14 @@
1
- autogluon.timeseries-1.4.1b20250825-py3.9-nspkg.pth,sha256=cQGwpuGPqg1GXscIwt-7PmME1OnSpD-7ixkikJ31WAY,554
1
+ autogluon.timeseries-1.4.1b20250826-py3.9-nspkg.pth,sha256=cQGwpuGPqg1GXscIwt-7PmME1OnSpD-7ixkikJ31WAY,554
2
2
  autogluon/timeseries/__init__.py,sha256=_CrLLc1fkjen7UzWoO0Os8WZoHOgvZbHKy46I8v_4k4,304
3
3
  autogluon/timeseries/evaluator.py,sha256=l642tYfTHsl8WVIq_vV6qhgAFVFr9UuZD7gLra3A_Kc,250
4
4
  autogluon/timeseries/learner.py,sha256=eQrqFVOmL-2JC85LgCMkbyoLpKS02Dilg1T8RUeS_LI,13887
5
- autogluon/timeseries/predictor.py,sha256=o1RNR0vzoYSy00gsmkMkbeFKucaASf-XyVcfHkvsUbQ,88435
5
+ autogluon/timeseries/predictor.py,sha256=7X4YsWYa3Xk2RI1Irf2O-c3-I82Zqhg-cgj8cj_4AoA,88427
6
6
  autogluon/timeseries/regressor.py,sha256=lc8Qr3-8v4oxajtCnV3sxpUaW6vxXXJOA6Kr-qVne4k,11926
7
7
  autogluon/timeseries/splitter.py,sha256=8ACkuCXeUhQGUx4jz_Vv17q814WrHJQeKvq2v4-oE6s,3158
8
- autogluon/timeseries/trainer.py,sha256=Xy7A6mf7tpv2HCLRiHFxIdLlwBZXw3Z8VcKkI8_HVE8,57856
9
- autogluon/timeseries/version.py,sha256=mylW8Ea3rlNFNF9qjkkMyoxirNfX7qeiYvayEOTcpl0,91
10
- autogluon/timeseries/configs/__init__.py,sha256=BTtHIPCYeGjqgOcvqb8qPD4VNX-ICKOg6wnkew1cPOE,98
11
- autogluon/timeseries/configs/presets_configs.py,sha256=cLat8ecLlWrI-SC5KLBDCX2SbVXaucemy2pjxJAtSY0,2543
8
+ autogluon/timeseries/version.py,sha256=wbFqn083c4qhnCxA7njOp5N1chGkZNX6Wv69k5TPI18,91
9
+ autogluon/timeseries/configs/__init__.py,sha256=wiLBwxZkDTQBJkSJ9-xz3p_yJxX0dbHe108dS1P5O6A,183
10
+ autogluon/timeseries/configs/hyperparameter_presets.py,sha256=GbI2sd3uakWtaeaMyF7B5z_lmyfb6ToK6PZEUZTyG9w,2031
11
+ autogluon/timeseries/configs/predictor_presets.py,sha256=B5HFHIelh91hhG0YYE5SJ7_14P7sylFAABgHX8n_53M,2712
12
12
  autogluon/timeseries/dataset/__init__.py,sha256=UvnhAN5tjgxXTHoZMQDy64YMDj4Xxa68yY7NP4vAw0o,81
13
13
  autogluon/timeseries/dataset/ts_dataframe.py,sha256=EwxKBScspwKnJTqIk2Icukk8vIrbKYObOMAkNIn4zc8,51760
14
14
  autogluon/timeseries/metrics/__init__.py,sha256=YJPXxsJ0tRDXq7p-sTZSLb0DuXMJH6sT1PgbZ3tMt30,3594
@@ -16,8 +16,7 @@ autogluon/timeseries/metrics/abstract.py,sha256=3172nIzBko6kJl7Z5SPz8btNc_mkqNqt
16
16
  autogluon/timeseries/metrics/point.py,sha256=sS__n_Em7m4CUaBu3PNWQ_dHw1YCOHbEyC15fhytFL8,18308
17
17
  autogluon/timeseries/metrics/quantile.py,sha256=x0cq44fXRoMiuI4BVQ7mpWk1YgrK4OwLTlJAhCHQ7Xg,4634
18
18
  autogluon/timeseries/metrics/utils.py,sha256=HuDe1BNe8yJU4f_DKM913nNrUueoRaw6zhxm1-S20s0,910
19
- autogluon/timeseries/models/__init__.py,sha256=nx61eXLCxWIb-eJXpYgCw3C7naNklh_FAaKImb8EdvI,1237
20
- autogluon/timeseries/models/presets.py,sha256=HpEFA35_S9fchO5OV1aIbTD9TAIZs9nJOTU4WKCAAWc,10445
19
+ autogluon/timeseries/models/__init__.py,sha256=9YnqkOILtVEkbICk7J3VlMkMNySs-f5ErIUKrE5-fys,1294
21
20
  autogluon/timeseries/models/registry.py,sha256=8n7W04ql0ckNQUzKcAW7bxreLI8wTAUTymACgLklH9M,2158
22
21
  autogluon/timeseries/models/abstract/__init__.py,sha256=Htfkjjc3vo92RvyM8rIlQ0PLWt3jcrCKZES07UvCMV0,146
23
22
  autogluon/timeseries/models/abstract/abstract_timeseries_model.py,sha256=97HOi7fRPxtx8Y9hq-xdJI-kLMp6Z-8LUSvcfBjXFsM,31978
@@ -50,6 +49,9 @@ autogluon/timeseries/models/local/npts.py,sha256=VRZk5tEJOIentt0tLM6lxyoU8US736n
50
49
  autogluon/timeseries/models/local/statsforecast.py,sha256=sZ6aEFzAyPNZX3rMULGWFht0Toapjb3EwHe5Rb76ZxA,33318
51
50
  autogluon/timeseries/models/multi_window/__init__.py,sha256=Bq7AT2Jxdd4WNqmjTdzeqgNiwn1NCyWp4tBIWaM-zfI,60
52
51
  autogluon/timeseries/models/multi_window/multi_window_model.py,sha256=IEfQaa1_qUi8WgzjMZ_u9qx8OgWMEDe_5Plui0R2q7A,11720
52
+ autogluon/timeseries/trainer/__init__.py,sha256=_tw3iioJfvtIV7wnjtEMv0yS8oabmCFxDnGRodYE7RI,72
53
+ autogluon/timeseries/trainer/model_set_builder.py,sha256=s6tozfND3lLfst6Vxa_oP_wgCmDapyCJYFmCjkEn-es,10788
54
+ autogluon/timeseries/trainer/trainer.py,sha256=4_0IOzBL64OsJhgfvJNPRXbPWO4OQ2E6DYZNxYVNZbs,57754
53
55
  autogluon/timeseries/transforms/__init__.py,sha256=fKlT4pkJ_8Gl7IUTc3uSDzt2Xow5iH5w6fPB3ePNrTg,127
54
56
  autogluon/timeseries/transforms/covariate_scaler.py,sha256=9lEfDS4wnVZohQNnm9OcAXr3voUl83RCnctKR3O66iU,7030
55
57
  autogluon/timeseries/transforms/target_scaler.py,sha256=kTQrXAsDHCnYuqfpaVuvefyTgyp_ylDpUIPz7pArjeY,6043
@@ -62,11 +64,11 @@ autogluon/timeseries/utils/datetime/base.py,sha256=3NdsH3NDq4cVAOSoy3XpaNixyNlbj
62
64
  autogluon/timeseries/utils/datetime/lags.py,sha256=rjJtdBU0M41R1jwfmvCbo045s-6XBjhGVnGBQJ9-U1E,5997
63
65
  autogluon/timeseries/utils/datetime/seasonality.py,sha256=YK_2k8hvYIMW-sJPnjGWRtCnvIOthwA2hATB3nwVoD4,834
64
66
  autogluon/timeseries/utils/datetime/time_features.py,sha256=kEOFls4Nzh8nO0Pcz1DwLsC_NA3hMI4JUlZI3kuvuts,2666
65
- autogluon.timeseries-1.4.1b20250825.dist-info/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
66
- autogluon.timeseries-1.4.1b20250825.dist-info/METADATA,sha256=_ioD6D6hICj1WsyrLKT0kL3is6Mz92jc_0ZFbRJuGys,12463
67
- autogluon.timeseries-1.4.1b20250825.dist-info/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
68
- autogluon.timeseries-1.4.1b20250825.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
69
- autogluon.timeseries-1.4.1b20250825.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
70
- autogluon.timeseries-1.4.1b20250825.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
71
- autogluon.timeseries-1.4.1b20250825.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
72
- autogluon.timeseries-1.4.1b20250825.dist-info/RECORD,,
67
+ autogluon.timeseries-1.4.1b20250826.dist-info/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
68
+ autogluon.timeseries-1.4.1b20250826.dist-info/METADATA,sha256=33gw_Z2-JOn3FMsfY7Nrageqd0fcbzXItqKErVgS5Sg,12463
69
+ autogluon.timeseries-1.4.1b20250826.dist-info/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
70
+ autogluon.timeseries-1.4.1b20250826.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
71
+ autogluon.timeseries-1.4.1b20250826.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
72
+ autogluon.timeseries-1.4.1b20250826.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
73
+ autogluon.timeseries-1.4.1b20250826.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
74
+ autogluon.timeseries-1.4.1b20250826.dist-info/RECORD,,
@@ -1,79 +0,0 @@
1
- """Preset configurations for autogluon.timeseries Predictors"""
2
-
3
- from autogluon.timeseries.models.presets import get_default_hps
4
-
5
- # TODO: change default HPO settings when other HPO strategies (e.g., Ray tune) are available
6
- # TODO: add refit_full arguments once refitting is available
7
-
8
- TIMESERIES_PRESETS_CONFIGS = dict(
9
- best_quality={"hyperparameters": "default", "num_val_windows": 2},
10
- high_quality={"hyperparameters": "default"},
11
- medium_quality={"hyperparameters": "light"},
12
- fast_training={"hyperparameters": "very_light"},
13
- # Chronos-Bolt models
14
- bolt_tiny={
15
- "hyperparameters": {"Chronos": {"model_path": "bolt_tiny"}},
16
- "skip_model_selection": True,
17
- },
18
- bolt_mini={
19
- "hyperparameters": {"Chronos": {"model_path": "bolt_mini"}},
20
- "skip_model_selection": True,
21
- },
22
- bolt_small={
23
- "hyperparameters": {"Chronos": {"model_path": "bolt_small"}},
24
- "skip_model_selection": True,
25
- },
26
- bolt_base={
27
- "hyperparameters": {"Chronos": {"model_path": "bolt_base"}},
28
- "skip_model_selection": True,
29
- },
30
- # Original Chronos models
31
- chronos_tiny={
32
- "hyperparameters": {"Chronos": {"model_path": "tiny"}},
33
- "skip_model_selection": True,
34
- },
35
- chronos_mini={
36
- "hyperparameters": {"Chronos": {"model_path": "mini"}},
37
- "skip_model_selection": True,
38
- },
39
- chronos_small={
40
- "hyperparameters": {"Chronos": {"model_path": "small"}},
41
- "skip_model_selection": True,
42
- },
43
- chronos_base={
44
- "hyperparameters": {"Chronos": {"model_path": "base"}},
45
- "skip_model_selection": True,
46
- },
47
- chronos_large={
48
- "hyperparameters": {"Chronos": {"model_path": "large", "batch_size": 8}},
49
- "skip_model_selection": True,
50
- },
51
- chronos_ensemble={
52
- "hyperparameters": {
53
- "Chronos": {"model_path": "small"},
54
- **get_default_hps("light_inference"),
55
- }
56
- },
57
- chronos_large_ensemble={
58
- "hyperparameters": {
59
- "Chronos": {"model_path": "large", "batch_size": 8},
60
- **get_default_hps("light_inference"),
61
- }
62
- },
63
- )
64
-
65
- TIMESERIES_PRESETS_ALIASES = dict(
66
- chronos="chronos_small",
67
- best="best_quality",
68
- high="high_quality",
69
- medium="medium_quality",
70
- bq="best_quality",
71
- hq="high_quality",
72
- mq="medium_quality",
73
- )
74
-
75
- # update with aliases
76
- TIMESERIES_PRESETS_CONFIGS = {
77
- **TIMESERIES_PRESETS_CONFIGS,
78
- **{k: TIMESERIES_PRESETS_CONFIGS[v].copy() for k, v in TIMESERIES_PRESETS_ALIASES.items()},
79
- }
@@ -1,280 +0,0 @@
1
- import copy
2
- import logging
3
- import re
4
- from collections import defaultdict
5
- from typing import Any, Optional, Type, Union
6
-
7
- from autogluon.common import space
8
- from autogluon.core import constants
9
- from autogluon.timeseries.metrics import TimeSeriesScorer
10
- from autogluon.timeseries.utils.features import CovariateMetadata
11
-
12
- from .abstract import AbstractTimeSeriesModel
13
- from .multi_window.multi_window_model import MultiWindowBacktestingModel
14
- from .registry import ModelRegistry
15
-
16
- logger = logging.getLogger(__name__)
17
-
18
- ModelHyperparameters = dict[str, Any]
19
-
20
-
21
- VALID_AG_ARGS_KEYS = {
22
- "name",
23
- "name_prefix",
24
- "name_suffix",
25
- }
26
-
27
-
28
- def get_default_hps(key):
29
- default_model_hps = {
30
- "very_light": {
31
- "Naive": {},
32
- "SeasonalNaive": {},
33
- "ETS": {},
34
- "Theta": {},
35
- "RecursiveTabular": {"max_num_samples": 100_000},
36
- "DirectTabular": {"max_num_samples": 100_000},
37
- },
38
- "light": {
39
- "Naive": {},
40
- "SeasonalNaive": {},
41
- "ETS": {},
42
- "Theta": {},
43
- "RecursiveTabular": {},
44
- "DirectTabular": {},
45
- "TemporalFusionTransformer": {},
46
- "Chronos": {"model_path": "bolt_small"},
47
- },
48
- "light_inference": {
49
- "SeasonalNaive": {},
50
- "DirectTabular": {},
51
- "RecursiveTabular": {},
52
- "TemporalFusionTransformer": {},
53
- "PatchTST": {},
54
- },
55
- "default": {
56
- "SeasonalNaive": {},
57
- "AutoETS": {},
58
- "NPTS": {},
59
- "DynamicOptimizedTheta": {},
60
- "RecursiveTabular": {},
61
- "DirectTabular": {},
62
- "TemporalFusionTransformer": {},
63
- "PatchTST": {},
64
- "DeepAR": {},
65
- "Chronos": [
66
- {
67
- "ag_args": {"name_suffix": "ZeroShot"},
68
- "model_path": "bolt_base",
69
- },
70
- {
71
- "ag_args": {"name_suffix": "FineTuned"},
72
- "model_path": "bolt_small",
73
- "fine_tune": True,
74
- "target_scaler": "standard",
75
- "covariate_regressor": {"model_name": "CAT", "model_hyperparameters": {"iterations": 1_000}},
76
- },
77
- ],
78
- "TiDE": {
79
- "encoder_hidden_dim": 256,
80
- "decoder_hidden_dim": 256,
81
- "temporal_hidden_dim": 64,
82
- "num_batches_per_epoch": 100,
83
- "lr": 1e-4,
84
- },
85
- },
86
- }
87
- return default_model_hps[key]
88
-
89
-
90
- def get_preset_models(
91
- freq: Optional[str],
92
- prediction_length: int,
93
- path: str,
94
- eval_metric: Union[str, TimeSeriesScorer],
95
- hyperparameters: Union[str, dict, None],
96
- hyperparameter_tune: bool,
97
- covariate_metadata: CovariateMetadata,
98
- all_assigned_names: list[str],
99
- excluded_model_types: Optional[list[str]],
100
- multi_window: bool = False,
101
- **kwargs,
102
- ):
103
- """
104
- Create a list of models according to hyperparameters. If hyperparamaters=None,
105
- will create models according to presets.
106
- """
107
- models = []
108
- hyperparameter_dict = get_hyperparameter_dict(hyperparameters, hyperparameter_tune)
109
-
110
- model_priority_list = sorted(
111
- hyperparameter_dict.keys(), key=lambda x: ModelRegistry.get_model_priority(x), reverse=True
112
- )
113
- excluded_models = get_excluded_models(excluded_model_types)
114
- all_assigned_names = all_assigned_names.copy()
115
-
116
- for model in model_priority_list:
117
- if isinstance(model, str):
118
- if model in excluded_models:
119
- logger.info(
120
- f"\tFound '{model}' model in `hyperparameters`, but '{model}' "
121
- "is present in `excluded_model_types` and will be removed."
122
- )
123
- continue
124
- model_type: Type[AbstractTimeSeriesModel] = ModelRegistry.get_model_class(model)
125
- elif isinstance(model, type):
126
- if not issubclass(model, AbstractTimeSeriesModel):
127
- raise ValueError(f"Custom model type {model} must inherit from `AbstractTimeSeriesModel`.")
128
- model_type = model
129
- else:
130
- raise ValueError(
131
- f"Keys of the `hyperparameters` dictionary must be strings or types, received {type(model)}."
132
- )
133
-
134
- for model_hps in hyperparameter_dict[model]:
135
- ag_args = model_hps.pop(constants.AG_ARGS, {})
136
- for key in ag_args:
137
- if key not in VALID_AG_ARGS_KEYS:
138
- raise ValueError(
139
- f"Model {model_type} received unknown ag_args key: {key} (valid keys {VALID_AG_ARGS_KEYS})"
140
- )
141
- model_name_base = get_model_name(ag_args, model_type)
142
-
143
- model_type_kwargs: dict[str, Any] = dict(
144
- name=model_name_base,
145
- path=path,
146
- freq=freq,
147
- prediction_length=prediction_length,
148
- eval_metric=eval_metric,
149
- covariate_metadata=covariate_metadata,
150
- hyperparameters=model_hps,
151
- **kwargs,
152
- )
153
-
154
- # add models while preventing name collisions
155
- model = model_type(**model_type_kwargs)
156
- model_type_kwargs.pop("name", None)
157
-
158
- increment = 1
159
- while model.name in all_assigned_names:
160
- increment += 1
161
- model = model_type(name=f"{model_name_base}_{increment}", **model_type_kwargs)
162
-
163
- if multi_window:
164
- model = MultiWindowBacktestingModel(model_base=model, name=model.name, **model_type_kwargs) # type: ignore
165
-
166
- all_assigned_names.append(model.name)
167
- models.append(model)
168
-
169
- return models
170
-
171
-
172
- def get_excluded_models(excluded_model_types: Optional[list[str]]) -> set[str]:
173
- excluded_models = set()
174
- if excluded_model_types is not None and len(excluded_model_types) > 0:
175
- if not isinstance(excluded_model_types, list):
176
- raise ValueError(f"`excluded_model_types` must be a list, received {type(excluded_model_types)}")
177
- logger.info(f"Excluded model types: {excluded_model_types}")
178
- for model in excluded_model_types:
179
- if not isinstance(model, str):
180
- raise ValueError(f"Each entry in `excluded_model_types` must be a string, received {type(model)}")
181
- excluded_models.add(normalize_model_type_name(model))
182
- return excluded_models
183
-
184
-
185
- def get_hyperparameter_dict(
186
- hyperparameters: Union[str, dict[str, Union[ModelHyperparameters, list[ModelHyperparameters]]], None],
187
- hyperparameter_tune: bool,
188
- ) -> dict[str, list[ModelHyperparameters]]:
189
- hyperparameter_dict = {}
190
-
191
- if hyperparameters is None:
192
- hyperparameter_dict = copy.deepcopy(get_default_hps("default"))
193
- elif isinstance(hyperparameters, str):
194
- hyperparameter_dict = copy.deepcopy(get_default_hps(hyperparameters))
195
- elif isinstance(hyperparameters, dict):
196
- hyperparameter_dict = copy.deepcopy(hyperparameters)
197
- else:
198
- raise ValueError(
199
- f"hyperparameters must be a dict, a string or None (received {type(hyperparameters)}). "
200
- f"Please see the documentation for TimeSeriesPredictor.fit"
201
- )
202
-
203
- hyperparameter_dict = check_and_clean_hyperparameters(
204
- hyperparameter_dict, must_contain_searchspace=hyperparameter_tune
205
- )
206
-
207
- return hyperparameter_dict
208
-
209
-
210
- def normalize_model_type_name(model_name: str) -> str:
211
- """Remove 'Model' suffix from the end of the string, if it's present."""
212
- if model_name.endswith("Model"):
213
- model_name = model_name[: -len("Model")]
214
- return model_name
215
-
216
-
217
- def check_and_clean_hyperparameters(
218
- hyperparameters: dict[str, Union[ModelHyperparameters, list[ModelHyperparameters]]],
219
- must_contain_searchspace: bool,
220
- ) -> dict[str, list[ModelHyperparameters]]:
221
- """Convert the hyperparameters dictionary to a unified format:
222
- - Remove 'Model' suffix from model names, if present
223
- - Make sure that each value in the hyperparameters dict is a list with model configurations
224
- - Checks if hyperparameters contain searchspaces
225
- """
226
- hyperparameters_clean = defaultdict(list)
227
- for key, value in hyperparameters.items():
228
- # Handle model names ending with "Model", e.g., "DeepARModel" is mapped to "DeepAR"
229
- if isinstance(key, str):
230
- key = normalize_model_type_name(key)
231
- if not isinstance(value, list):
232
- value = [value]
233
- hyperparameters_clean[key].extend(value)
234
-
235
- if must_contain_searchspace:
236
- verify_contains_at_least_one_searchspace(hyperparameters_clean)
237
- else:
238
- verify_contains_no_searchspaces(hyperparameters_clean)
239
-
240
- return dict(hyperparameters_clean)
241
-
242
-
243
- def get_model_name(ag_args: dict[str, Any], model_type: Type[AbstractTimeSeriesModel]) -> str:
244
- name = ag_args.get("name")
245
- if name is None:
246
- name_stem = re.sub(r"Model$", "", model_type.__name__)
247
- name_prefix = ag_args.get("name_prefix", "")
248
- name_suffix = ag_args.get("name_suffix", "")
249
- name = name_prefix + name_stem + name_suffix
250
- return name
251
-
252
-
253
- def contains_searchspace(model_hyperparameters: ModelHyperparameters) -> bool:
254
- for hp_value in model_hyperparameters.values():
255
- if isinstance(hp_value, space.Space):
256
- return True
257
- return False
258
-
259
-
260
- def verify_contains_at_least_one_searchspace(hyperparameters: dict[str, list[ModelHyperparameters]]):
261
- for model, model_hps_list in hyperparameters.items():
262
- for model_hps in model_hps_list:
263
- if contains_searchspace(model_hps):
264
- return
265
-
266
- raise ValueError(
267
- "Hyperparameter tuning specified, but no model contains a hyperparameter search space. "
268
- "Please disable hyperparameter tuning with `hyperparameter_tune_kwargs=None` or provide a search space "
269
- "for at least one model."
270
- )
271
-
272
-
273
- def verify_contains_no_searchspaces(hyperparameters: dict[str, list[ModelHyperparameters]]):
274
- for model, model_hps_list in hyperparameters.items():
275
- for model_hps in model_hps_list:
276
- if contains_searchspace(model_hps):
277
- raise ValueError(
278
- f"Hyperparameter tuning not specified, so hyperparameters must have fixed values. "
279
- f"However, for model {model} hyperparameters {model_hps} contain a search space."
280
- )