autogluon.timeseries 1.2.1b20250224__py3-none-any.whl → 1.4.1b20251215__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.timeseries might be problematic; consult the package's registry page for more details.

Files changed (108):
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +106 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +256 -141
  5. autogluon/timeseries/learner.py +86 -52
  6. autogluon/timeseries/metrics/__init__.py +42 -8
  7. autogluon/timeseries/metrics/abstract.py +89 -19
  8. autogluon/timeseries/metrics/point.py +142 -53
  9. autogluon/timeseries/metrics/quantile.py +46 -21
  10. autogluon/timeseries/metrics/utils.py +4 -4
  11. autogluon/timeseries/models/__init__.py +8 -2
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +361 -592
  14. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +282 -194
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +25 -18
  20. autogluon/timeseries/models/chronos/__init__.py +2 -1
  21. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  22. autogluon/timeseries/models/chronos/model.py +219 -138
  23. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +81 -50
  24. autogluon/timeseries/models/ensemble/__init__.py +37 -2
  25. autogluon/timeseries/models/ensemble/abstract.py +107 -0
  26. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  27. autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
  28. autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  34. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  35. autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
  36. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  37. autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
  38. autogluon/timeseries/models/ensemble/weighted/basic.py +91 -0
  39. autogluon/timeseries/models/ensemble/weighted/greedy.py +62 -0
  40. autogluon/timeseries/models/gluonts/__init__.py +1 -1
  41. autogluon/timeseries/models/gluonts/{abstract_gluonts.py → abstract.py} +148 -208
  42. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  43. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +38 -22
  44. autogluon/timeseries/models/local/__init__.py +0 -7
  45. autogluon/timeseries/models/local/abstract_local_model.py +71 -74
  46. autogluon/timeseries/models/local/naive.py +13 -9
  47. autogluon/timeseries/models/local/npts.py +9 -2
  48. autogluon/timeseries/models/local/statsforecast.py +52 -36
  49. autogluon/timeseries/models/multi_window/multi_window_model.py +65 -45
  50. autogluon/timeseries/models/registry.py +64 -0
  51. autogluon/timeseries/models/toto/__init__.py +3 -0
  52. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  62. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  63. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  64. autogluon/timeseries/models/toto/dataloader.py +108 -0
  65. autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
  66. autogluon/timeseries/models/toto/model.py +249 -0
  67. autogluon/timeseries/predictor.py +685 -297
  68. autogluon/timeseries/regressor.py +94 -44
  69. autogluon/timeseries/splitter.py +8 -32
  70. autogluon/timeseries/trainer/__init__.py +3 -0
  71. autogluon/timeseries/trainer/ensemble_composer.py +444 -0
  72. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  73. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  74. autogluon/timeseries/{trainer.py → trainer/trainer.py} +387 -390
  75. autogluon/timeseries/trainer/utils.py +17 -0
  76. autogluon/timeseries/transforms/__init__.py +2 -13
  77. autogluon/timeseries/transforms/covariate_scaler.py +34 -40
  78. autogluon/timeseries/transforms/target_scaler.py +37 -20
  79. autogluon/timeseries/utils/constants.py +10 -0
  80. autogluon/timeseries/utils/datetime/lags.py +3 -5
  81. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  82. autogluon/timeseries/utils/datetime/time_features.py +2 -2
  83. autogluon/timeseries/utils/features.py +70 -47
  84. autogluon/timeseries/utils/forecast.py +19 -14
  85. autogluon/timeseries/utils/timer.py +173 -0
  86. autogluon/timeseries/utils/warning_filters.py +4 -2
  87. autogluon/timeseries/version.py +1 -1
  88. autogluon.timeseries-1.4.1b20251215-py3.11-nspkg.pth +1 -0
  89. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/METADATA +49 -36
  90. autogluon_timeseries-1.4.1b20251215.dist-info/RECORD +103 -0
  91. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/WHEEL +1 -1
  92. autogluon/timeseries/configs/presets_configs.py +0 -79
  93. autogluon/timeseries/evaluator.py +0 -6
  94. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -11
  95. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  96. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -585
  97. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -518
  98. autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -78
  99. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  100. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  101. autogluon/timeseries/models/presets.py +0 -360
  102. autogluon.timeseries-1.2.1b20250224-py3.9-nspkg.pth +0 -1
  103. autogluon.timeseries-1.2.1b20250224.dist-info/RECORD +0 -68
  104. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/zip-safe +0 -0
@@ -3,7 +3,7 @@ import os
3
3
  import shutil
4
4
  from datetime import timedelta
5
5
  from pathlib import Path
6
- from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union
6
+ from typing import TYPE_CHECKING, Any, Callable, Type, cast, overload
7
7
 
8
8
  import gluonts
9
9
  import gluonts.core.settings
@@ -11,7 +11,7 @@ import numpy as np
11
11
  import pandas as pd
12
12
  from gluonts.core.component import from_hyperparameters
13
13
  from gluonts.dataset.common import Dataset as GluonTSDataset
14
- from gluonts.dataset.field_names import FieldName
14
+ from gluonts.env import env as gluonts_env
15
15
  from gluonts.model.estimator import Estimator as GluonTSEstimator
16
16
  from gluonts.model.forecast import Forecast, QuantileForecast, SampleForecast
17
17
  from gluonts.model.predictor import Predictor as GluonTSPredictor
@@ -21,11 +21,15 @@ from autogluon.core.hpo.constants import RAY_BACKEND
21
21
  from autogluon.tabular.models.tabular_nn.utils.categorical_encoders import (
22
22
  OneHotMergeRaresHandleUnknownEncoder as OneHotEncoder,
23
23
  )
24
- from autogluon.timeseries.dataset.ts_dataframe import ITEMID, TIMESTAMP, TimeSeriesDataFrame
24
+ from autogluon.timeseries.dataset import TimeSeriesDataFrame
25
25
  from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
26
- from autogluon.timeseries.utils.datetime import norm_freq_str
27
26
  from autogluon.timeseries.utils.warning_filters import disable_root_logger, warning_filter
28
27
 
28
+ if TYPE_CHECKING:
29
+ from gluonts.torch.model.forecast import DistributionForecast
30
+
31
+ from .dataset import SimpleGluonTSDataset
32
+
29
33
  # NOTE: We avoid imports for torch and lightning.pytorch at the top level and hide them inside class methods.
30
34
  # This is done to skip these imports during multiprocessing (which may cause bugs)
31
35
 
@@ -33,124 +37,25 @@ logger = logging.getLogger(__name__)
33
37
  gts_logger = logging.getLogger(gluonts.__name__)
34
38
 
35
39
 
36
- class SimpleGluonTSDataset(GluonTSDataset):
37
- """Wrapper for TimeSeriesDataFrame that is compatible with the GluonTS Dataset API."""
38
-
39
- def __init__(
40
- self,
41
- target_df: TimeSeriesDataFrame,
42
- freq: str,
43
- target_column: str = "target",
44
- feat_static_cat: Optional[np.ndarray] = None,
45
- feat_static_real: Optional[np.ndarray] = None,
46
- feat_dynamic_cat: Optional[np.ndarray] = None,
47
- feat_dynamic_real: Optional[np.ndarray] = None,
48
- past_feat_dynamic_cat: Optional[np.ndarray] = None,
49
- past_feat_dynamic_real: Optional[np.ndarray] = None,
50
- includes_future: bool = False,
51
- prediction_length: int = None,
52
- ):
53
- assert target_df is not None
54
- # Convert TimeSeriesDataFrame to pd.Series for faster processing
55
- self.target_array = target_df[target_column].to_numpy(np.float32)
56
- self.feat_static_cat = self._astype(feat_static_cat, dtype=np.int64)
57
- self.feat_static_real = self._astype(feat_static_real, dtype=np.float32)
58
- self.feat_dynamic_cat = self._astype(feat_dynamic_cat, dtype=np.int64)
59
- self.feat_dynamic_real = self._astype(feat_dynamic_real, dtype=np.float32)
60
- self.past_feat_dynamic_cat = self._astype(past_feat_dynamic_cat, dtype=np.int64)
61
- self.past_feat_dynamic_real = self._astype(past_feat_dynamic_real, dtype=np.float32)
62
- self.freq = self._get_freq_for_period(freq)
63
-
64
- # Necessary to compute indptr for known_covariates at prediction time
65
- self.includes_future = includes_future
66
- self.prediction_length = prediction_length
67
-
68
- # Replace inefficient groupby ITEMID with indptr that stores start:end of each time series
69
- item_id_index = target_df.index.get_level_values(ITEMID)
70
- indices_sizes = item_id_index.value_counts(sort=False)
71
- self.item_ids = indices_sizes.index # shape [num_items]
72
- cum_sizes = indices_sizes.to_numpy().cumsum()
73
- self.indptr = np.append(0, cum_sizes).astype(np.int32)
74
- self.start_timestamps = target_df.reset_index(TIMESTAMP).groupby(level=ITEMID, sort=False).first()[TIMESTAMP]
75
- assert len(self.item_ids) == len(self.start_timestamps)
76
-
77
- @staticmethod
78
- def _astype(array: Optional[np.ndarray], dtype: np.dtype) -> Optional[np.ndarray]:
79
- if array is None:
80
- return None
81
- else:
82
- return array.astype(dtype)
83
-
84
- @staticmethod
85
- def _get_freq_for_period(freq: str) -> str:
86
- """Convert freq to format compatible with pd.Period.
87
-
88
- For example, ME freq must be converted to M when creating a pd.Period.
89
- """
90
- offset = pd.tseries.frequencies.to_offset(freq)
91
- freq_name = norm_freq_str(offset)
92
- if freq_name == "SME":
93
- # Replace unsupported frequency "SME" with "2W"
94
- return "2W"
95
- elif freq_name == "bh":
96
- # Replace unsupported frequency "bh" with dummy value "Y"
97
- return "Y"
98
- else:
99
- freq_name_for_period = {"YE": "Y", "QE": "Q", "ME": "M"}.get(freq_name, freq_name)
100
- return f"{offset.n}{freq_name_for_period}"
101
-
102
- def __len__(self):
103
- return len(self.indptr) - 1 # noqa
104
-
105
- def __iter__(self) -> Iterator[Dict[str, Any]]:
106
- for j in range(len(self.indptr) - 1):
107
- start_idx = self.indptr[j]
108
- end_idx = self.indptr[j + 1]
109
- # GluonTS expects item_id to be a string
110
- ts = {
111
- FieldName.ITEM_ID: str(self.item_ids[j]),
112
- FieldName.START: pd.Period(self.start_timestamps.iloc[j], freq=self.freq),
113
- FieldName.TARGET: self.target_array[start_idx:end_idx],
114
- }
115
- if self.feat_static_cat is not None:
116
- ts[FieldName.FEAT_STATIC_CAT] = self.feat_static_cat[j]
117
- if self.feat_static_real is not None:
118
- ts[FieldName.FEAT_STATIC_REAL] = self.feat_static_real[j]
119
- if self.past_feat_dynamic_cat is not None:
120
- ts[FieldName.PAST_FEAT_DYNAMIC_CAT] = self.past_feat_dynamic_cat[start_idx:end_idx].T
121
- if self.past_feat_dynamic_real is not None:
122
- ts[FieldName.PAST_FEAT_DYNAMIC_REAL] = self.past_feat_dynamic_real[start_idx:end_idx].T
123
-
124
- # Dynamic features that may extend into the future
125
- if self.includes_future:
126
- start_idx = start_idx + j * self.prediction_length
127
- end_idx = end_idx + (j + 1) * self.prediction_length
128
- if self.feat_dynamic_cat is not None:
129
- ts[FieldName.FEAT_DYNAMIC_CAT] = self.feat_dynamic_cat[start_idx:end_idx].T
130
- if self.feat_dynamic_real is not None:
131
- ts[FieldName.FEAT_DYNAMIC_REAL] = self.feat_dynamic_real[start_idx:end_idx].T
132
- yield ts
133
-
134
-
135
40
  class AbstractGluonTSModel(AbstractTimeSeriesModel):
136
41
  """Abstract class wrapping GluonTS estimators for use in autogluon.timeseries.
137
42
 
138
43
  Parameters
139
44
  ----------
140
- path: str
45
+ path
141
46
  directory to store model artifacts.
142
- freq: str
47
+ freq
143
48
  string representation (compatible with GluonTS frequency strings) for the data provided.
144
49
  For example, "1D" for daily data, "1H" for hourly data, etc.
145
- prediction_length: int
50
+ prediction_length
146
51
  Number of time steps ahead (length of the forecast horizon) the model will be optimized
147
52
  to predict. At inference time, this will be the number of time steps the model will
148
53
  predict.
149
- name: str
54
+ name
150
55
  Name of the model. Also, name of subdirectory inside path where model will be saved.
151
- eval_metric: str
56
+ eval_metric
152
57
  objective function the model intends to optimize, will use WQL by default.
153
- hyperparameters:
58
+ hyperparameters
154
59
  various hyperparameters that will be used by model (can be search spaces instead of
155
60
  fixed values). See *Other Parameters* in each inheriting model's documentation for
156
61
  possible values.
@@ -167,12 +72,12 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
167
72
 
168
73
  def __init__(
169
74
  self,
170
- freq: Optional[str] = None,
75
+ freq: str | None = None,
171
76
  prediction_length: int = 1,
172
- path: Optional[str] = None,
173
- name: Optional[str] = None,
174
- eval_metric: str = None,
175
- hyperparameters: Dict[str, Any] = None,
77
+ path: str | None = None,
78
+ name: str | None = None,
79
+ eval_metric: str | None = None,
80
+ hyperparameters: dict[str, Any] | None = None,
176
81
  **kwargs, # noqa
177
82
  ):
178
83
  super().__init__(
@@ -184,9 +89,9 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
184
89
  hyperparameters=hyperparameters,
185
90
  **kwargs,
186
91
  )
187
- self.gts_predictor: Optional[GluonTSPredictor] = None
188
- self._ohe_generator_known: Optional[OneHotEncoder] = None
189
- self._ohe_generator_past: Optional[OneHotEncoder] = None
92
+ self.gts_predictor: GluonTSPredictor | None = None
93
+ self._ohe_generator_known: OneHotEncoder | None = None
94
+ self._ohe_generator_past: OneHotEncoder | None = None
190
95
  self.callbacks = []
191
96
  # Following attributes may be overridden during fit() based on train_data & model parameters
192
97
  self.num_feat_static_cat = 0
@@ -195,30 +100,32 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
195
100
  self.num_feat_dynamic_real = 0
196
101
  self.num_past_feat_dynamic_cat = 0
197
102
  self.num_past_feat_dynamic_real = 0
198
- self.feat_static_cat_cardinality: List[int] = []
199
- self.feat_dynamic_cat_cardinality: List[int] = []
200
- self.past_feat_dynamic_cat_cardinality: List[int] = []
103
+ self.feat_static_cat_cardinality: list[int] = []
104
+ self.feat_dynamic_cat_cardinality: list[int] = []
105
+ self.past_feat_dynamic_cat_cardinality: list[int] = []
201
106
  self.negative_data = True
202
107
 
203
- def save(self, path: str = None, verbose: bool = True) -> str:
108
+ def save(self, path: str | None = None, verbose: bool = True) -> str:
204
109
  # we flush callbacks instance variable if it has been set. it can keep weak references which breaks training
205
110
  self.callbacks = []
206
111
  # The GluonTS predictor is serialized using custom logic
207
112
  predictor = self.gts_predictor
208
113
  self.gts_predictor = None
209
- path = Path(super().save(path=path, verbose=verbose))
114
+ saved_path = Path(super().save(path=path, verbose=verbose))
210
115
 
211
116
  with disable_root_logger():
212
117
  if predictor:
213
- Path.mkdir(path / self.gluonts_model_path, exist_ok=True)
214
- predictor.serialize(path / self.gluonts_model_path)
118
+ Path.mkdir(saved_path / self.gluonts_model_path, exist_ok=True)
119
+ predictor.serialize(saved_path / self.gluonts_model_path)
215
120
 
216
121
  self.gts_predictor = predictor
217
122
 
218
- return str(path)
123
+ return str(saved_path)
219
124
 
220
125
  @classmethod
221
- def load(cls, path: str, reset_paths: bool = True, verbose: bool = True) -> "AbstractGluonTSModel":
126
+ def load(
127
+ cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
128
+ ) -> "AbstractGluonTSModel":
222
129
  from gluonts.torch.model.predictor import PyTorchPredictor
223
130
 
224
131
  with warning_filter():
@@ -235,31 +142,33 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
235
142
  def _get_hpo_backend(self):
236
143
  return RAY_BACKEND
237
144
 
238
- def _deferred_init_params_aux(self, dataset: TimeSeriesDataFrame) -> None:
239
- """Update GluonTS specific parameters with information available only at training time."""
240
- model_params = self._get_model_params()
145
+ def _deferred_init_hyperparameters(self, dataset: TimeSeriesDataFrame) -> None:
146
+ """Update GluonTS specific hyperparameters with information available only at training time."""
147
+ model_params = self.get_hyperparameters()
241
148
  disable_static_features = model_params.get("disable_static_features", False)
242
149
  if not disable_static_features:
243
- self.num_feat_static_cat = len(self.metadata.static_features_cat)
244
- self.num_feat_static_real = len(self.metadata.static_features_real)
150
+ self.num_feat_static_cat = len(self.covariate_metadata.static_features_cat)
151
+ self.num_feat_static_real = len(self.covariate_metadata.static_features_real)
245
152
  if self.num_feat_static_cat > 0:
246
- feat_static_cat = dataset.static_features[self.metadata.static_features_cat]
247
- self.feat_static_cat_cardinality = feat_static_cat.nunique().tolist()
153
+ assert dataset.static_features is not None, (
154
+ "Static features must be provided if num_feat_static_cat > 0"
155
+ )
156
+ self.feat_static_cat_cardinality = list(self.covariate_metadata.static_cat_cardinality.values())
248
157
 
249
158
  disable_known_covariates = model_params.get("disable_known_covariates", False)
250
159
  if not disable_known_covariates and self.supports_known_covariates:
251
- self.num_feat_dynamic_cat = len(self.metadata.known_covariates_cat)
252
- self.num_feat_dynamic_real = len(self.metadata.known_covariates_real)
160
+ self.num_feat_dynamic_cat = len(self.covariate_metadata.known_covariates_cat)
161
+ self.num_feat_dynamic_real = len(self.covariate_metadata.known_covariates_real)
253
162
  if self.num_feat_dynamic_cat > 0:
254
- feat_dynamic_cat = dataset[self.metadata.known_covariates_cat]
255
163
  if self.supports_cat_covariates:
256
- self.feat_dynamic_cat_cardinality = feat_dynamic_cat.nunique().tolist()
164
+ self.feat_dynamic_cat_cardinality = list(self.covariate_metadata.known_cat_cardinality.values())
257
165
  else:
166
+ feat_dynamic_cat = dataset[self.covariate_metadata.known_covariates_cat]
258
167
  # If model doesn't support categorical covariates, convert them to real via one hot encoding
259
168
  self._ohe_generator_known = OneHotEncoder(
260
169
  max_levels=model_params.get("max_cat_cardinality", 100),
261
170
  sparse=False,
262
- dtype="float32",
171
+ dtype="float32", # type: ignore
263
172
  )
264
173
  feat_dynamic_cat_ohe = self._ohe_generator_known.fit_transform(pd.DataFrame(feat_dynamic_cat))
265
174
  self.num_feat_dynamic_cat = 0
@@ -267,18 +176,20 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
267
176
 
268
177
  disable_past_covariates = model_params.get("disable_past_covariates", False)
269
178
  if not disable_past_covariates and self.supports_past_covariates:
270
- self.num_past_feat_dynamic_cat = len(self.metadata.past_covariates_cat)
271
- self.num_past_feat_dynamic_real = len(self.metadata.past_covariates_real)
179
+ self.num_past_feat_dynamic_cat = len(self.covariate_metadata.past_covariates_cat)
180
+ self.num_past_feat_dynamic_real = len(self.covariate_metadata.past_covariates_real)
272
181
  if self.num_past_feat_dynamic_cat > 0:
273
- past_feat_dynamic_cat = dataset[self.metadata.past_covariates_cat]
274
182
  if self.supports_cat_covariates:
275
- self.past_feat_dynamic_cat_cardinality = past_feat_dynamic_cat.nunique().tolist()
183
+ self.past_feat_dynamic_cat_cardinality = list(
184
+ self.covariate_metadata.past_cat_cardinality.values()
185
+ )
276
186
  else:
187
+ past_feat_dynamic_cat = dataset[self.covariate_metadata.past_covariates_cat]
277
188
  # If model doesn't support categorical covariates, convert them to real via one hot encoding
278
189
  self._ohe_generator_past = OneHotEncoder(
279
190
  max_levels=model_params.get("max_cat_cardinality", 100),
280
191
  sparse=False,
281
- dtype="float32",
192
+ dtype="float32", # type: ignore
282
193
  )
283
194
  past_feat_dynamic_cat_ohe = self._ohe_generator_past.fit_transform(
284
195
  pd.DataFrame(past_feat_dynamic_cat)
@@ -288,7 +199,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
288
199
 
289
200
  self.negative_data = (dataset[self.target] < 0).any()
290
201
 
291
- def _get_default_params(self):
202
+ def _get_default_hyperparameters(self):
292
203
  """Gets default parameters for GluonTS estimator initialization that are available after
293
204
  AbstractTimeSeriesModel initialization (i.e., before deferred initialization). Models may
294
205
  override this method to update default parameters.
@@ -306,7 +217,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
306
217
  "covariate_scaler": "global",
307
218
  }
308
219
 
309
- def _get_model_params(self) -> dict:
220
+ def get_hyperparameters(self) -> dict:
310
221
  """Gets params that are passed to the inner model."""
311
222
  # for backward compatibility with the old GluonTS MXNet API
312
223
  parameter_name_aliases = {
@@ -314,7 +225,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
314
225
  "learning_rate": "lr",
315
226
  }
316
227
 
317
- init_args = super()._get_model_params()
228
+ init_args = super().get_hyperparameters()
318
229
  for alias, actual in parameter_name_aliases.items():
319
230
  if alias in init_args:
320
231
  if actual in init_args:
@@ -322,12 +233,12 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
322
233
  else:
323
234
  init_args[actual] = init_args.pop(alias)
324
235
 
325
- return self._get_default_params() | init_args
236
+ return self._get_default_hyperparameters() | init_args
326
237
 
327
- def _get_estimator_init_args(self) -> Dict[str, Any]:
328
- """Get GluonTS specific constructor arguments for estimator objects, an alias to `self._get_model_params`
238
+ def _get_estimator_init_args(self) -> dict[str, Any]:
239
+ """Get GluonTS specific constructor arguments for estimator objects, an alias to `self.get_hyperparameters`
329
240
  for better readability."""
330
- return self._get_model_params()
241
+ return self.get_hyperparameters()
331
242
 
332
243
  def _get_estimator_class(self) -> Type[GluonTSEstimator]:
333
244
  raise NotImplementedError
@@ -367,25 +278,39 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
367
278
 
368
279
  return torch.cuda.is_available()
369
280
 
370
- def get_minimum_resources(self, is_gpu_available: bool = False) -> Dict[str, Union[int, float]]:
371
- minimum_resources = {"num_cpus": 1}
281
+ def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
282
+ minimum_resources: dict[str, int | float] = {"num_cpus": 1}
372
283
  # if GPU is available, we train with 1 GPU per trial
373
284
  if is_gpu_available:
374
285
  minimum_resources["num_gpus"] = 1
375
286
  return minimum_resources
376
287
 
288
+ @overload
289
+ def _to_gluonts_dataset(self, time_series_df: None, known_covariates=None) -> None: ...
290
+ @overload
291
+ def _to_gluonts_dataset(self, time_series_df: TimeSeriesDataFrame, known_covariates=None) -> GluonTSDataset: ...
377
292
  def _to_gluonts_dataset(
378
- self, time_series_df: Optional[TimeSeriesDataFrame], known_covariates: Optional[TimeSeriesDataFrame] = None
379
- ) -> Optional[GluonTSDataset]:
293
+ self, time_series_df: TimeSeriesDataFrame | None, known_covariates: TimeSeriesDataFrame | None = None
294
+ ) -> GluonTSDataset | None:
380
295
  if time_series_df is not None:
381
296
  # TODO: Preprocess real-valued features with StdScaler?
382
297
  if self.num_feat_static_cat > 0:
383
- feat_static_cat = time_series_df.static_features[self.metadata.static_features_cat].to_numpy()
298
+ assert time_series_df.static_features is not None, (
299
+ "Static features must be provided if num_feat_static_cat > 0"
300
+ )
301
+ feat_static_cat = time_series_df.static_features[
302
+ self.covariate_metadata.static_features_cat
303
+ ].to_numpy()
384
304
  else:
385
305
  feat_static_cat = None
386
306
 
387
307
  if self.num_feat_static_real > 0:
388
- feat_static_real = time_series_df.static_features[self.metadata.static_features_real].to_numpy()
308
+ assert time_series_df.static_features is not None, (
309
+ "Static features must be provided if num_feat_static_real > 0"
310
+ )
311
+ feat_static_real = time_series_df.static_features[
312
+ self.covariate_metadata.static_features_real
313
+ ].to_numpy()
389
314
  else:
390
315
  feat_static_real = None
391
316
 
@@ -393,31 +318,33 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
393
318
  # Convert TSDF -> DF to avoid overhead / input validation
394
319
  df = pd.DataFrame(time_series_df)
395
320
  if known_covariates is not None:
396
- known_covariates = pd.DataFrame(known_covariates)
321
+ known_covariates = pd.DataFrame(known_covariates) # type: ignore
397
322
  if self.num_feat_dynamic_cat > 0:
398
- feat_dynamic_cat = df[self.metadata.known_covariates_cat].to_numpy()
323
+ feat_dynamic_cat = df[self.covariate_metadata.known_covariates_cat].to_numpy()
399
324
  if known_covariates is not None:
400
325
  feat_dynamic_cat = np.concatenate(
401
- [feat_dynamic_cat, known_covariates[self.metadata.known_covariates_cat].to_numpy()]
326
+ [feat_dynamic_cat, known_covariates[self.covariate_metadata.known_covariates_cat].to_numpy()]
402
327
  )
403
328
  assert len(feat_dynamic_cat) == expected_known_covariates_len
404
329
  else:
405
330
  feat_dynamic_cat = None
406
331
 
407
332
  if self.num_feat_dynamic_real > 0:
408
- feat_dynamic_real = df[self.metadata.known_covariates_real].to_numpy()
333
+ feat_dynamic_real = df[self.covariate_metadata.known_covariates_real].to_numpy()
409
334
  # Append future values of known covariates
410
335
  if known_covariates is not None:
411
336
  feat_dynamic_real = np.concatenate(
412
- [feat_dynamic_real, known_covariates[self.metadata.known_covariates_real].to_numpy()]
337
+ [feat_dynamic_real, known_covariates[self.covariate_metadata.known_covariates_real].to_numpy()]
413
338
  )
414
339
  assert len(feat_dynamic_real) == expected_known_covariates_len
415
340
  # Categorical covariates are one-hot-encoded as real
416
341
  if self._ohe_generator_known is not None:
417
- feat_dynamic_cat_ohe = self._ohe_generator_known.transform(df[self.metadata.known_covariates_cat])
342
+ feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform(
343
+ df[self.covariate_metadata.known_covariates_cat]
344
+ ) # type: ignore
418
345
  if known_covariates is not None:
419
- future_dynamic_cat_ohe = self._ohe_generator_known.transform(
420
- known_covariates[self.metadata.known_covariates_cat]
346
+ future_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform( # type: ignore
347
+ known_covariates[self.covariate_metadata.known_covariates_cat]
421
348
  )
422
349
  feat_dynamic_cat_ohe = np.concatenate([feat_dynamic_cat_ohe, future_dynamic_cat_ohe])
423
350
  assert len(feat_dynamic_cat_ohe) == expected_known_covariates_len
@@ -426,15 +353,15 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
426
353
  feat_dynamic_real = None
427
354
 
428
355
  if self.num_past_feat_dynamic_cat > 0:
429
- past_feat_dynamic_cat = df[self.metadata.past_covariates_cat].to_numpy()
356
+ past_feat_dynamic_cat = df[self.covariate_metadata.past_covariates_cat].to_numpy()
430
357
  else:
431
358
  past_feat_dynamic_cat = None
432
359
 
433
360
  if self.num_past_feat_dynamic_real > 0:
434
- past_feat_dynamic_real = df[self.metadata.past_covariates_real].to_numpy()
361
+ past_feat_dynamic_real = df[self.covariate_metadata.past_covariates_real].to_numpy()
435
362
  if self._ohe_generator_past is not None:
436
- past_feat_dynamic_cat_ohe = self._ohe_generator_past.transform(
437
- df[self.metadata.past_covariates_cat]
363
+ past_feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_past.transform( # type: ignore
364
+ df[self.covariate_metadata.past_covariates_cat]
438
365
  )
439
366
  past_feat_dynamic_real = np.concatenate(
440
367
  [past_feat_dynamic_real, past_feat_dynamic_cat_ohe], axis=1
@@ -442,8 +369,9 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
442
369
  else:
443
370
  past_feat_dynamic_real = None
444
371
 
372
+ assert self.freq is not None
445
373
  return SimpleGluonTSDataset(
446
- target_df=time_series_df[[self.target]],
374
+ target_df=time_series_df[[self.target]], # type: ignore
447
375
  freq=self.freq,
448
376
  target_column=self.target,
449
377
  feat_static_cat=feat_static_cat,
@@ -461,14 +389,16 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
461
389
  def _fit(
462
390
  self,
463
391
  train_data: TimeSeriesDataFrame,
464
- val_data: Optional[TimeSeriesDataFrame] = None,
465
- time_limit: int = None,
392
+ val_data: TimeSeriesDataFrame | None = None,
393
+ time_limit: float | None = None,
394
+ num_cpus: int | None = None,
395
+ num_gpus: int | None = None,
396
+ verbosity: int = 2,
466
397
  **kwargs,
467
398
  ) -> None:
468
399
  # necessary to initialize the loggers
469
400
  import lightning.pytorch # noqa
470
401
 
471
- verbosity = kwargs.get("verbosity", 2)
472
402
  for logger_name in logging.root.manager.loggerDict:
473
403
  if "lightning" in logger_name:
474
404
  pl_logger = logging.getLogger(logger_name)
@@ -489,18 +419,18 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
489
419
  time_limit=time_limit,
490
420
  early_stopping_patience=None if val_data is None else init_args["early_stopping_patience"],
491
421
  )
492
- self._deferred_init_params_aux(train_data)
422
+ self._deferred_init_hyperparameters(train_data)
493
423
 
494
424
  estimator = self._get_estimator()
495
- with warning_filter(), disable_root_logger(), gluonts.core.settings.let(gluonts.env.env, use_tqdm=False):
425
+ with warning_filter(), disable_root_logger(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
496
426
  self.gts_predictor = estimator.train(
497
427
  self._to_gluonts_dataset(train_data),
498
428
  validation_data=self._to_gluonts_dataset(val_data),
499
- cache_data=True,
429
+ cache_data=True, # type: ignore
500
430
  )
501
431
  # Increase batch size during prediction to speed up inference
502
432
  if init_args["predict_batch_size"] is not None:
503
- self.gts_predictor.batch_size = init_args["predict_batch_size"]
433
+ self.gts_predictor.batch_size = init_args["predict_batch_size"] # type: ignore
504
434
 
505
435
  lightning_logs_dir = Path(self.path) / "lightning_logs"
506
436
  if not keep_lightning_logs and lightning_logs_dir.exists() and lightning_logs_dir.is_dir():
@@ -509,9 +439,9 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
509
439
 
510
440
  def _get_callbacks(
511
441
  self,
512
- time_limit: int,
513
- early_stopping_patience: Optional[int] = None,
514
- ) -> List[Callable]:
442
+ time_limit: float | None,
443
+ early_stopping_patience: int | None = None,
444
+ ) -> list[Callable]:
515
445
  """Retrieve a list of callback objects for the GluonTS trainer"""
516
446
  from lightning.pytorch.callbacks import EarlyStopping, Timer
517
447
 
@@ -525,14 +455,14 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
525
455
  def _predict(
526
456
  self,
527
457
  data: TimeSeriesDataFrame,
528
- known_covariates: Optional[TimeSeriesDataFrame] = None,
458
+ known_covariates: TimeSeriesDataFrame | None = None,
529
459
  **kwargs,
530
460
  ) -> TimeSeriesDataFrame:
531
461
  if self.gts_predictor is None:
532
462
  raise ValueError("Please fit the model before predicting.")
533
463
 
534
- with warning_filter(), gluonts.core.settings.let(gluonts.env.env, use_tqdm=False):
535
- predicted_targets = self._predict_gluonts_forecasts(data, known_covariates=known_covariates, **kwargs)
464
+ with warning_filter(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
465
+ predicted_targets = self._predict_gluonts_forecasts(data, known_covariates=known_covariates)
536
466
  df = self._gluonts_forecasts_to_data_frame(
537
467
  predicted_targets,
538
468
  forecast_index=self.get_forecast_horizon_index(data),
@@ -540,16 +470,21 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
540
470
  return df
541
471
 
542
472
  def _predict_gluonts_forecasts(
543
- self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame] = None, **kwargs
544
- ) -> List[Forecast]:
473
+ self,
474
+ data: TimeSeriesDataFrame,
475
+ known_covariates: TimeSeriesDataFrame | None = None,
476
+ num_samples: int | None = None,
477
+ ) -> list[Forecast]:
478
+ assert self.gts_predictor is not None, "GluonTS models must be fit before predicting."
545
479
  gts_data = self._to_gluonts_dataset(data, known_covariates=known_covariates)
480
+ return list(
481
+ self.gts_predictor.predict(
482
+ dataset=gts_data,
483
+ num_samples=num_samples or self.default_num_samples,
484
+ )
485
+ )
546
486
 
547
- predictor_kwargs = dict(dataset=gts_data)
548
- predictor_kwargs["num_samples"] = kwargs.get("num_samples", self.default_num_samples)
549
-
550
- return list(self.gts_predictor.predict(**predictor_kwargs))
551
-
552
- def _stack_quantile_forecasts(self, forecasts: List[QuantileForecast], item_ids: pd.Index) -> pd.DataFrame:
487
+ def _stack_quantile_forecasts(self, forecasts: list[QuantileForecast], item_ids: pd.Index) -> pd.DataFrame:
553
488
  # GluonTS always saves item_id as a string
554
489
  item_id_to_forecast = {str(f.item_id): f for f in forecasts}
555
490
  result_dfs = []
@@ -562,7 +497,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
562
497
  columns_order = ["mean"] + [str(q) for q in self.quantile_levels]
563
498
  return forecast_df[columns_order]
564
499
 
565
- def _stack_sample_forecasts(self, forecasts: List[SampleForecast], item_ids: pd.Index) -> pd.DataFrame:
500
+ def _stack_sample_forecasts(self, forecasts: list[SampleForecast], item_ids: pd.Index) -> pd.DataFrame:
566
501
  item_id_to_forecast = {str(f.item_id): f for f in forecasts}
567
502
  samples_per_item = []
568
503
  for item_id in item_ids:
@@ -574,17 +509,25 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
574
509
  forecast_array = np.concatenate([mean, quantiles], axis=1)
575
510
  return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])
576
511
 
577
- def _stack_distribution_forecasts(self, forecasts: List[Forecast], item_ids: pd.Index) -> pd.DataFrame:
512
+ def _stack_distribution_forecasts(
513
+ self, forecasts: list["DistributionForecast"], item_ids: pd.Index
514
+ ) -> pd.DataFrame:
578
515
  import torch
579
516
  from gluonts.torch.distributions import AffineTransformed
580
517
  from torch.distributions import Distribution
581
518
 
582
519
  # Sort forecasts in the same order as in the dataset
583
520
  item_id_to_forecast = {str(f.item_id): f for f in forecasts}
584
- forecasts = [item_id_to_forecast[str(item_id)] for item_id in item_ids]
521
+ dist_forecasts = [item_id_to_forecast[str(item_id)] for item_id in item_ids]
522
+
523
+ assert all(isinstance(f.distribution, AffineTransformed) for f in dist_forecasts), (
524
+ "Expected forecast.distribution to be an instance of AffineTransformed"
525
+ )
585
526
 
586
- def stack_distributions(distributions: List[Distribution]) -> Distribution:
527
+ def stack_distributions(distributions: list[Distribution]) -> Distribution:
587
528
  """Stack multiple torch.Distribution objects into a single distribution"""
529
+ last_dist: Distribution = distributions[-1]
530
+
588
531
  params_per_dist = []
589
532
  for dist in distributions:
590
533
  params = {name: getattr(dist, name) for name in dist.arg_constraints.keys()}
@@ -593,22 +536,19 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
593
536
  assert len(set(tuple(p.keys()) for p in params_per_dist)) == 1
594
537
 
595
538
  stacked_params = {}
596
- for key in dist.arg_constraints.keys():
539
+ for key in last_dist.arg_constraints.keys():
597
540
  stacked_params[key] = torch.cat([p[key] for p in params_per_dist])
598
- return dist.__class__(**stacked_params)
599
-
600
- if not isinstance(forecasts[0].distribution, AffineTransformed):
601
- raise AssertionError("Expected forecast.distribution to be an instance of AffineTransformed")
541
+ return last_dist.__class__(**stacked_params)
602
542
 
603
543
  # We stack all forecast distribution into a single Distribution object.
604
544
  # This dramatically speeds up the quantiles calculation.
605
- stacked_base_dist = stack_distributions([f.distribution.base_dist for f in forecasts])
545
+ stacked_base_dist = stack_distributions([f.distribution.base_dist for f in dist_forecasts]) # type: ignore
606
546
 
607
- stacked_loc = torch.cat([f.distribution.loc for f in forecasts])
547
+ stacked_loc = torch.cat([f.distribution.loc for f in dist_forecasts]) # type: ignore
608
548
  if stacked_loc.shape != stacked_base_dist.batch_shape:
609
549
  stacked_loc = stacked_loc.repeat_interleave(self.prediction_length)
610
550
 
611
- stacked_scale = torch.cat([f.distribution.scale for f in forecasts])
551
+ stacked_scale = torch.cat([f.distribution.scale for f in dist_forecasts]) # type: ignore
612
552
  if stacked_scale.shape != stacked_base_dist.batch_shape:
613
553
  stacked_scale = stacked_scale.repeat_interleave(self.prediction_length)
614
554
 
@@ -616,24 +556,24 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
616
556
 
617
557
  mean_prediction = stacked_dist.mean.cpu().detach().numpy()
618
558
  quantiles = torch.tensor(self.quantile_levels, device=stacked_dist.mean.device).reshape(-1, 1)
619
- quantile_predictions = stacked_dist.icdf(quantiles).cpu().detach().numpy()
559
+ quantile_predictions = stacked_dist.icdf(quantiles).cpu().detach().numpy() # type: ignore
620
560
  forecast_array = np.vstack([mean_prediction, quantile_predictions]).T
621
561
  return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])
622
562
 
623
563
  def _gluonts_forecasts_to_data_frame(
624
564
  self,
625
- forecasts: List[Forecast],
565
+ forecasts: list[Forecast],
626
566
  forecast_index: pd.MultiIndex,
627
567
  ) -> TimeSeriesDataFrame:
628
568
  from gluonts.torch.model.forecast import DistributionForecast
629
569
 
630
- item_ids = forecast_index.unique(level=ITEMID)
570
+ item_ids = forecast_index.unique(level=TimeSeriesDataFrame.ITEMID)
631
571
  if isinstance(forecasts[0], SampleForecast):
632
- forecast_df = self._stack_sample_forecasts(forecasts, item_ids)
572
+ forecast_df = self._stack_sample_forecasts(cast(list[SampleForecast], forecasts), item_ids)
633
573
  elif isinstance(forecasts[0], QuantileForecast):
634
- forecast_df = self._stack_quantile_forecasts(forecasts, item_ids)
574
+ forecast_df = self._stack_quantile_forecasts(cast(list[QuantileForecast], forecasts), item_ids)
635
575
  elif isinstance(forecasts[0], DistributionForecast):
636
- forecast_df = self._stack_distribution_forecasts(forecasts, item_ids)
576
+ forecast_df = self._stack_distribution_forecasts(cast(list[DistributionForecast], forecasts), item_ids)
637
577
  else:
638
578
  raise ValueError(f"Unrecognized forecast type {type(forecasts[0])}")
639
579