autogluon.timeseries 1.3.2b20250712__py3-none-any.whl → 1.4.1b20251116__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (90)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +98 -72
  5. autogluon/timeseries/learner.py +19 -18
  6. autogluon/timeseries/metrics/__init__.py +5 -5
  7. autogluon/timeseries/metrics/abstract.py +17 -17
  8. autogluon/timeseries/metrics/point.py +1 -1
  9. autogluon/timeseries/metrics/quantile.py +2 -2
  10. autogluon/timeseries/metrics/utils.py +4 -4
  11. autogluon/timeseries/models/__init__.py +4 -0
  12. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -75
  13. autogluon/timeseries/models/abstract/tunable.py +6 -6
  14. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +72 -76
  15. autogluon/timeseries/models/autogluon_tabular/per_step.py +104 -46
  16. autogluon/timeseries/models/autogluon_tabular/transforms.py +9 -7
  17. autogluon/timeseries/models/chronos/model.py +115 -78
  18. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +76 -44
  19. autogluon/timeseries/models/ensemble/__init__.py +29 -2
  20. autogluon/timeseries/models/ensemble/abstract.py +16 -52
  21. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  22. autogluon/timeseries/models/ensemble/array_based/abstract.py +247 -0
  23. autogluon/timeseries/models/ensemble/array_based/models.py +50 -0
  24. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +10 -0
  25. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +87 -0
  26. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +133 -0
  27. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +141 -0
  28. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  29. autogluon/timeseries/models/ensemble/weighted/abstract.py +41 -0
  30. autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +8 -18
  31. autogluon/timeseries/models/ensemble/{greedy.py → weighted/greedy.py} +13 -13
  32. autogluon/timeseries/models/gluonts/abstract.py +26 -26
  33. autogluon/timeseries/models/gluonts/dataset.py +4 -4
  34. autogluon/timeseries/models/gluonts/models.py +27 -12
  35. autogluon/timeseries/models/local/abstract_local_model.py +14 -14
  36. autogluon/timeseries/models/local/naive.py +4 -0
  37. autogluon/timeseries/models/local/npts.py +1 -0
  38. autogluon/timeseries/models/local/statsforecast.py +30 -14
  39. autogluon/timeseries/models/multi_window/multi_window_model.py +34 -23
  40. autogluon/timeseries/models/registry.py +65 -0
  41. autogluon/timeseries/models/toto/__init__.py +3 -0
  42. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  43. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  44. autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
  45. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  46. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  47. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  48. autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
  49. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
  50. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  51. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  52. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  53. autogluon/timeseries/models/toto/dataloader.py +108 -0
  54. autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
  55. autogluon/timeseries/models/toto/model.py +236 -0
  56. autogluon/timeseries/predictor.py +94 -107
  57. autogluon/timeseries/regressor.py +31 -27
  58. autogluon/timeseries/splitter.py +7 -31
  59. autogluon/timeseries/trainer/__init__.py +3 -0
  60. autogluon/timeseries/trainer/ensemble_composer.py +250 -0
  61. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  62. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  63. autogluon/timeseries/{trainer.py → trainer/trainer.py} +182 -307
  64. autogluon/timeseries/trainer/utils.py +18 -0
  65. autogluon/timeseries/transforms/covariate_scaler.py +4 -4
  66. autogluon/timeseries/transforms/target_scaler.py +14 -14
  67. autogluon/timeseries/utils/datetime/lags.py +2 -2
  68. autogluon/timeseries/utils/datetime/time_features.py +2 -2
  69. autogluon/timeseries/utils/features.py +41 -37
  70. autogluon/timeseries/utils/forecast.py +5 -5
  71. autogluon/timeseries/utils/warning_filters.py +3 -1
  72. autogluon/timeseries/version.py +1 -1
  73. autogluon.timeseries-1.4.1b20251116-py3.9-nspkg.pth +1 -0
  74. {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/METADATA +32 -17
  75. autogluon_timeseries-1.4.1b20251116.dist-info/RECORD +96 -0
  76. {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/WHEEL +1 -1
  77. autogluon/timeseries/configs/presets_configs.py +0 -79
  78. autogluon/timeseries/evaluator.py +0 -6
  79. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
  80. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  81. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
  82. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -530
  83. autogluon/timeseries/models/presets.py +0 -358
  84. autogluon.timeseries-1.3.2b20250712-py3.9-nspkg.pth +0 -1
  85. autogluon.timeseries-1.3.2b20250712.dist-info/RECORD +0 -71
  86. {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info/licenses}/LICENSE +0 -0
  87. {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info/licenses}/NOTICE +0 -0
  88. {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/namespace_packages.txt +0 -0
  89. {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/top_level.txt +0 -0
  90. {autogluon.timeseries-1.3.2b20250712.dist-info → autogluon_timeseries-1.4.1b20251116.dist-info}/zip-safe +0 -0
autogluon/timeseries/regressor.py
@@ -1,13 +1,13 @@
  import logging
  import time
- from typing import Any, Dict, Optional, Protocol, Union, overload, runtime_checkable
+ from typing import Any, Optional, Protocol, Union, overload, runtime_checkable

  import numpy as np
  import pandas as pd

  from autogluon.core.models import AbstractModel
  from autogluon.tabular.registry import ag_model_registry as tabular_ag_model_registry
- from autogluon.timeseries.dataset.ts_dataframe import ITEMID, TimeSeriesDataFrame
+ from autogluon.timeseries.dataset import TimeSeriesDataFrame
  from autogluon.timeseries.utils.features import CovariateMetadata

  logger = logging.getLogger(__name__)
@@ -40,42 +40,42 @@ class GlobalCovariateRegressor(CovariateRegressor):

  Parameters
  ----------
- model_name : str
- Name of the tabular regression model. See `autogluon.tabular.registry.ag_model_registry` or
+ model_name
+ Name of the tabular regression model. See ``autogluon.tabular.registry.ag_model_registry`` or
  `the documentation <https://auto.gluon.ai/stable/api/autogluon.tabular.models.html>`_ for the list of available
  tabular models.
- model_hyperparameters : dict or None
+ model_hyperparameters
  Hyperparameters passed to the tabular regression model.
- eval_metric : str
- Metric provided as `eval_metric` to the tabular regression model. Must be compatible with `problem_type="regression"`.
- refit_during_predict : bool
- If True, the model will be re-trained every time `fit_transform` is called. If False, the model will only be
- trained the first time that `fit_transform` is called, and future calls to `fit_transform` will only perform a
- `transform`.
- max_num_samples : int or None
+ eval_metric
+ Metric provided as ``eval_metric`` to the tabular regression model. Must be compatible with `problem_type="regression"`.
+ refit_during_predict
+ If True, the model will be re-trained every time ``fit_transform`` is called. If False, the model will only be
+ trained the first time that ``fit_transform`` is called, and future calls to ``fit_transform`` will only perform a
+ ``transform``.
+ max_num_samples
  If not None, training dataset passed to regression model will contain at most this many rows.
- covariate_metadata : CovariateMetadata
+ covariate_metadata
  Metadata object describing the covariates available in the dataset.
- target : str
+ target
  Name of the target column.
- validation_fraction : float, optional
+ validation_fraction
  Fraction of observations that are reserved as the validation set during training (starting from the end of each
  time series).
- fit_time_fraction: float
+ fit_time_fraction
  The fraction of the time_limit that will be reserved for model training. The remainder (1 - fit_time_fraction)
  will be reserved for prediction.

- If the estimated prediction time exceeds `(1 - fit_time_fraction) * time_limit`, the regressor will be disabled.
- include_static_features: bool
+ If the estimated prediction time exceeds ``(1 - fit_time_fraction) * time_limit``, the regressor will be disabled.
+ include_static_features
  If True, static features will be included as features for the regressor.
- include_item_id: bool
+ include_item_id
  If True, item_id will be included as a categorical feature for the regressor.
  """

  def __init__(
  self,
  model_name: str = "CAT",
- model_hyperparameters: Optional[Dict[str, Any]] = None,
+ model_hyperparameters: Optional[dict[str, Any]] = None,
  eval_metric: str = "mean_absolute_error",
  refit_during_predict: bool = False,
  max_num_samples: Optional[int] = 500_000,
@@ -119,9 +119,9 @@ class GlobalCovariateRegressor(CovariateRegressor):
  median_ts_length = data.num_timesteps_per_item().median()
  features_to_drop = [self.target]
  if not self.include_item_id:
- features_to_drop += [ITEMID]
+ features_to_drop += [TimeSeriesDataFrame.ITEMID]
  if self.validation_fraction is not None:
- grouped_df = tabular_df.groupby(ITEMID, observed=False, sort=False)
+ grouped_df = tabular_df.groupby(TimeSeriesDataFrame.ITEMID, observed=False, sort=False)
  val_size = max(int(self.validation_fraction * median_ts_length), 1)
  train_df = self._subsample_df(grouped_df.head(-val_size))
  val_df = self._subsample_df(grouped_df.tail(val_size))
@@ -201,7 +201,7 @@ class GlobalCovariateRegressor(CovariateRegressor):
  assert self.model is not None, "CovariateRegressor must be fit before calling predict."
  tabular_df = self._get_tabular_df(data, static_features=static_features)
  if not self.include_item_id:
- tabular_df = tabular_df.drop(columns=[ITEMID])
+ tabular_df = tabular_df.drop(columns=[TimeSeriesDataFrame.ITEMID])
  return self.model.predict(X=tabular_df)

  def _get_tabular_df(
@@ -211,12 +211,14 @@ class GlobalCovariateRegressor(CovariateRegressor):
  include_target: bool = False,
  ) -> pd.DataFrame:
  """Construct a tabular dataframe from known covariates and static features."""
- available_columns = [ITEMID] + self.covariate_metadata.known_covariates
+ available_columns = [TimeSeriesDataFrame.ITEMID] + self.covariate_metadata.known_covariates
  if include_target:
  available_columns += [self.target]
- tabular_df = pd.DataFrame(data).reset_index()[available_columns].astype({ITEMID: "category"})
+ tabular_df = (
+ pd.DataFrame(data).reset_index()[available_columns].astype({TimeSeriesDataFrame.ITEMID: "category"})
+ )
  if static_features is not None and self.include_static_features:
- tabular_df = pd.merge(tabular_df, static_features, on=ITEMID)
+ tabular_df = pd.merge(tabular_df, static_features, on=TimeSeriesDataFrame.ITEMID)
  return tabular_df

  def _subsample_df(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -239,7 +241,9 @@ def get_covariate_regressor(
  if covariate_regressor is None:
  return None
  elif len(covariate_metadata.known_covariates + covariate_metadata.static_features) == 0:
- logger.info("\tSkipping covariate_regressor since the dataset contains no covariates or static features.")
+ logger.info(
+ "\tSkipping covariate_regressor since the dataset contains no known_covariates or static_features."
+ )
  return None
  else:
  if isinstance(covariate_regressor, str):
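The hunks above drop the module-level ITEMID import in favor of the TimeSeriesDataFrame.ITEMID class attribute. A minimal sketch of that pattern, using illustrative data that is not part of this diff:

import pandas as pd
from autogluon.timeseries.dataset import TimeSeriesDataFrame

# Toy dataset; from_data_frame expects "item_id" and "timestamp" columns by default.
df = pd.DataFrame(
    {
        "item_id": ["A", "A", "B", "B"],
        "timestamp": pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-01", "2024-01-02"]),
        "target": [1.0, 2.0, 3.0, 4.0],
    }
)
ts_df = TimeSeriesDataFrame.from_data_frame(df)

# Same idea as _get_tabular_df above: flatten to a regular DataFrame and treat the
# item_id column (now referenced as TimeSeriesDataFrame.ITEMID) as categorical.
tabular_df = pd.DataFrame(ts_df).reset_index().astype({TimeSeriesDataFrame.ITEMID: "category"})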
autogluon/timeseries/splitter.py
@@ -1,6 +1,6 @@
- from typing import Iterator, Optional, Tuple
+ from typing import Iterator, Optional

- from .dataset.ts_dataframe import TimeSeriesDataFrame
+ from autogluon.timeseries.dataset import TimeSeriesDataFrame

  __all__ = [
  "AbstractWindowSplitter",
@@ -13,7 +13,7 @@ class AbstractWindowSplitter:
  self.prediction_length = prediction_length
  self.num_val_windows = num_val_windows

- def split(self, data: TimeSeriesDataFrame) -> Iterator[Tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]]:
+ def split(self, data: TimeSeriesDataFrame) -> Iterator[tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]]:
  raise NotImplementedError


@@ -33,11 +33,11 @@ class ExpandingWindowSplitter(AbstractWindowSplitter):

  Parameters
  ----------
- prediction_length : int
+ prediction_length
  Length of the forecast horizon.
- num_val_windows: int, default = 1
+ num_val_windows
  Number of windows to generate from each time series in the dataset.
- val_step_size : int, optional
+ val_step_size
  The end of each subsequent window is moved this many time steps forward.
  """

@@ -47,7 +47,7 @@ class ExpandingWindowSplitter(AbstractWindowSplitter):
  val_step_size = prediction_length
  self.val_step_size = val_step_size

- def split(self, data: TimeSeriesDataFrame) -> Iterator[Tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]]:
+ def split(self, data: TimeSeriesDataFrame) -> Iterator[tuple[TimeSeriesDataFrame, TimeSeriesDataFrame]]:
  """Generate train and validation folds for a time series dataset."""
  for window_idx in range(1, self.num_val_windows + 1):
  val_end = -(self.num_val_windows - window_idx) * self.val_step_size
@@ -57,27 +57,3 @@ class ExpandingWindowSplitter(AbstractWindowSplitter):
  train_data = data.slice_by_timestep(None, train_end)
  val_data = data.slice_by_timestep(None, val_end)
  yield train_data, val_data
-
-
- class AbstractTimeSeriesSplitter:
- def __init__(self, *args, **kwargs):
- raise ValueError(
- "`AbstractTimeSeriesSplitter` has been deprecated. "
- "Please use `autogluon.timeseries.splitter.ExpandingWindowSplitter` instead."
- )
-
-
- class MultiWindowSplitter(AbstractTimeSeriesSplitter):
- def __init__(self, *args, **kwargs):
- raise ValueError(
- "`MultiWindowSplitter` has been deprecated. "
- "Please use `autogluon.timeseries.splitter.ExpandingWindowSplitter` instead."
- )
-
-
- class LastWindowSplitter(MultiWindowSplitter):
- def __init__(self, *args, **kwargs):
- raise ValueError(
- "`LastWindowSplitter` has been deprecated. "
- "Please use `autogluon.timeseries.splitter.ExpandingWindowSplitter` instead."
- )
autogluon/timeseries/trainer/__init__.py (new)
@@ -0,0 +1,3 @@
+ from .trainer import TimeSeriesTrainer
+
+ __all__ = ["TimeSeriesTrainer"]
autogluon/timeseries/trainer/ensemble_composer.py (new)
@@ -0,0 +1,250 @@
+ import logging
+ import os
+ import time
+ import traceback
+ from typing import Iterator, Optional
+
+ import networkx as nx
+ import numpy as np
+ from typing_extensions import Self
+
+ from autogluon.timeseries import TimeSeriesDataFrame
+ from autogluon.timeseries.metrics import TimeSeriesScorer
+ from autogluon.timeseries.models.ensemble import AbstractTimeSeriesEnsembleModel, get_ensemble_class
+ from autogluon.timeseries.splitter import AbstractWindowSplitter
+ from autogluon.timeseries.utils.warning_filters import warning_filter
+
+ from .utils import log_scores_and_times
+
+ logger = logging.getLogger("autogluon.timeseries.trainer")
+
+
+ class EnsembleComposer:
+ """Helper class for TimeSeriesTrainer to build multi-layer stack ensembles."""
+
+ def __init__(
+ self,
+ path,
+ prediction_length: int,
+ eval_metric: TimeSeriesScorer,
+ target: str,
+ quantile_levels: list[float],
+ model_graph: nx.DiGraph,
+ ensemble_hyperparameters: dict,
+ window_splitter: AbstractWindowSplitter,
+ ):
+ self.eval_metric = eval_metric
+ self.path = path
+ self.prediction_length = prediction_length
+ self.target = target
+ self.quantile_levels = quantile_levels
+
+ self.ensemble_hyperparameters = ensemble_hyperparameters
+
+ self.window_splitter = window_splitter
+
+ self.banned_model_names = list(model_graph.nodes)
+ self.model_graph = self._get_base_model_graph(source_graph=model_graph)
+
+ @staticmethod
+ def _get_base_model_graph(source_graph: nx.DiGraph) -> nx.DiGraph:
+ """Return a model graph by copying only base models (nodes without predecessors)
+ This ensures we start fresh for ensemble building.
+ """
+ rootset = EnsembleComposer._get_rootset(source_graph)
+
+ dst_graph = nx.DiGraph()
+ for node in rootset:
+ dst_graph.add_node(node, **source_graph.nodes[node])
+
+ return dst_graph
+
+ @staticmethod
+ def _get_rootset(graph: nx.DiGraph) -> list[str]:
+ return [n for n in graph.nodes if not list(graph.predecessors(n))]
+
+ def iter_ensembles(self) -> Iterator[tuple[int, AbstractTimeSeriesEnsembleModel, list[str]]]:
+ """Iterate over trained ensemble models, layer by layer.
+
+ Yields
+ ------
+ layer_ix
+ The layer index of the ensemble.
+ model
+ The ensemble model object
+ base_model_names
+ The names of the base models that are part of the ensemble.
+ """
+ rootset = self._get_rootset(self.model_graph)
+
+ for layer_ix, layer in enumerate(nx.traversal.bfs_layers(self.model_graph, rootset)):
+ if layer_ix == 0:  # we don't need base models
+ continue
+
+ for model_name in layer:
+ attrs = self.model_graph.nodes[model_name]
+ model_path = os.path.join(self.path, *attrs["path"])
+ model = attrs["type"].load(path=model_path)
+
+ yield (
+ layer_ix,
+ model,
+ list(self.model_graph.predecessors(model_name)),
+ )
+
+ def fit(
+ self,
+ train_data: TimeSeriesDataFrame,
+ val_data: Optional[TimeSeriesDataFrame] = None,
+ time_limit: Optional[float] = None,
+ ) -> Self:
+ base_model_scores = {k: self.model_graph.nodes[k]["val_score"] for k in self.model_graph.nodes}
+ model_names = list(base_model_scores.keys())
+
+ if not self._can_fit_ensemble(time_limit, len(model_names)):
+ return self
+
+ logger.info(f"Fitting {len(self.ensemble_hyperparameters)} ensemble(s).")
+
+ # get target and base model prediction data for ensemble training
+ data_per_window = self._get_validation_windows(train_data=train_data, val_data=val_data)
+ predictions_per_window = self._get_base_model_predictions(model_names)
+
+ for ensemble_name, ensemble_hp_dict in self.ensemble_hyperparameters.items():
+ try:
+ time_start = time.monotonic()
+ ensemble_class = get_ensemble_class(ensemble_name)
+ ensemble = ensemble_class(
+ eval_metric=self.eval_metric,
+ target=self.target,
+ prediction_length=self.prediction_length,
+ path=self.path,
+ freq=data_per_window[0].freq,
+ quantile_levels=self.quantile_levels,
+ hyperparameters=ensemble_hp_dict,
+ )
+ # update name to prevent name collisions
+ ensemble.name = self._get_ensemble_model_name(ensemble.name)
+
+ with warning_filter():
+ ensemble.fit(
+ predictions_per_window=predictions_per_window,
+ data_per_window=data_per_window,
+ model_scores=base_model_scores,
+ time_limit=time_limit,
+ )
+ ensemble.fit_time = time.monotonic() - time_start
+
+ score_per_fold = []
+ for window_idx, data in enumerate(data_per_window):
+ predictions = ensemble.predict(
+ {n: predictions_per_window[n][window_idx] for n in ensemble.model_names}
+ )
+ score_per_fold.append(self.eval_metric.score(data, predictions, self.target))
+ ensemble.val_score = float(np.mean(score_per_fold, dtype=np.float64))
+
+ # TODO: add ensemble's own time to predict_time
+ ensemble.predict_time = self._calculate_base_models_predict_time(ensemble.model_names)
+
+ log_scores_and_times(
+ ensemble.val_score,
+ ensemble.fit_time,
+ ensemble.predict_time,
+ eval_metric_name=self.eval_metric.name_with_sign,
+ )
+
+ self._add_model(ensemble, base_models=ensemble.model_names)
+
+ # Save the ensemble model to disk
+ ensemble.save()
+ except Exception as err:  # noqa
+ logger.error(
+ f"\tWarning: Exception caused {ensemble_name} to fail during training... Skipping this model."
+ )
+ logger.error(f"\t{err}")
+ logger.debug(traceback.format_exc())
+
+ return self
+
+ def _add_model(self, model, base_models: list[str]):
+ self.model_graph.add_node(
+ model.name,
+ path=os.path.relpath(model.path, self.path).split(os.sep),
+ type=type(model),
+ fit_time=model.fit_time,
+ predict_time=model.predict_time,
+ val_score=model.val_score,
+ )
+ for base_model in base_models:
+ self.model_graph.add_edge(base_model, model.name)
+
+ def _can_fit_ensemble(
+ self,
+ time_limit: Optional[float],
+ num_models_available_for_ensemble: int,
+ ) -> bool:
+ if time_limit is not None and time_limit <= 0:
+ logger.info(f"Not fitting ensemble due to lack of time remaining. Time left: {time_limit:.1f} seconds")
+ return False
+
+ if num_models_available_for_ensemble <= 1:
+ logger.info(
+ "Not fitting ensemble as "
+ + (
+ "no models were successfully trained."
+ if not num_models_available_for_ensemble
+ else "only 1 model was trained."
+ )
+ )
+ return False
+
+ return True
+
+ def _get_validation_windows(
+ self, train_data: TimeSeriesDataFrame, val_data: Optional[TimeSeriesDataFrame]
+ ) -> list[TimeSeriesDataFrame]:
+ # TODO: update for window/stack-layer logic and refit logic
+ if val_data is None:
+ return [val_fold for _, val_fold in self.window_splitter.split(train_data)]
+ else:
+ return [val_data]
+
+ def _get_ensemble_model_name(self, name: str) -> str:
+ """Revise name for an ensemble model, ensuring we don't have name collisions"""
+ base_name = name
+ increment = 1
+ while name in self.banned_model_names:
+ increment += 1
+ name = f"{base_name}_{increment}"
+ return name
+
+ def _get_base_model_predictions(self, model_names: list[str]) -> dict[str, list[TimeSeriesDataFrame]]:
+ """Get base model predictions for ensemble training / inference."""
+ # TODO: update for window/stack-layer logic and refit logic
+ predictions_per_window = {}
+
+ for model_name in model_names:
+ model_attrs = self.model_graph.nodes[model_name]
+
+ model_path = os.path.join(self.path, *model_attrs["path"])
+ model_type = model_attrs["type"]
+
+ predictions_per_window[model_name] = model_type.load_oof_predictions(path=model_path)
+
+ return predictions_per_window
+
+ def _calculate_base_models_predict_time(self, model_names: list[str]) -> float:
+ """Calculate ensemble predict time as sum of base model predict times."""
+ return sum(self.model_graph.nodes[name]["predict_time"] for name in model_names)
+
+
+ def validate_ensemble_hyperparameters(hyperparameters) -> dict:
+ """Validate ensemble hyperparameters dict."""
+ if not isinstance(hyperparameters, dict):
+ raise ValueError(f"ensemble_hyperparameters must be dict, got {type(hyperparameters)}")
+
+ # Validate all ensemble names are known
+ for ensemble_name, ensemble_hyperparameters in hyperparameters.items():
+ get_ensemble_class(ensemble_name)  # Will raise if unknown
+ assert isinstance(ensemble_hyperparameters, dict)
+ return hyperparameters
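The new module also exposes validate_ensemble_hyperparameters, presumably called before the trainer hands the configuration to EnsembleComposer.fit() and reads results back via iter_ensembles(), though that wiring is not shown in this hunk. A small sketch of the structural check above; the ensemble name mentioned in the final comment is an assumption, not something this diff confirms:

from autogluon.timeseries.trainer.ensemble_composer import validate_ensemble_hyperparameters

# A non-dict value is rejected outright.
try:
    validate_ensemble_hyperparameters(["not-a-dict"])
except ValueError as err:
    print(err)  # ensemble_hyperparameters must be dict, got <class 'list'>

# A valid value maps ensemble names (resolved via get_ensemble_class) to hyperparameter
# dicts, e.g. {"GreedyEnsemble": {}} -- "GreedyEnsemble" is an assumed name used only
# for illustration here.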