autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.timeseries might be problematic. Click here for more details.

Files changed (108) hide show
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +339 -186
  5. autogluon/timeseries/learner.py +192 -60
  6. autogluon/timeseries/metrics/__init__.py +55 -11
  7. autogluon/timeseries/metrics/abstract.py +96 -25
  8. autogluon/timeseries/metrics/point.py +186 -39
  9. autogluon/timeseries/metrics/quantile.py +47 -20
  10. autogluon/timeseries/metrics/utils.py +6 -6
  11. autogluon/timeseries/models/__init__.py +13 -7
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
  14. autogluon/timeseries/models/abstract/model_trial.py +10 -10
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
  20. autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
  21. autogluon/timeseries/models/chronos/__init__.py +4 -0
  22. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  23. autogluon/timeseries/models/chronos/model.py +738 -0
  24. autogluon/timeseries/models/chronos/utils.py +369 -0
  25. autogluon/timeseries/models/ensemble/__init__.py +35 -2
  26. autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
  27. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  28. autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
  29. autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  34. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  35. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  36. autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
  37. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  38. autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
  39. autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
  40. autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
  41. autogluon/timeseries/models/gluonts/__init__.py +3 -1
  42. autogluon/timeseries/models/gluonts/abstract.py +583 -0
  43. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  44. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
  45. autogluon/timeseries/models/local/__init__.py +1 -10
  46. autogluon/timeseries/models/local/abstract_local_model.py +150 -97
  47. autogluon/timeseries/models/local/naive.py +31 -23
  48. autogluon/timeseries/models/local/npts.py +6 -2
  49. autogluon/timeseries/models/local/statsforecast.py +99 -112
  50. autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
  51. autogluon/timeseries/models/registry.py +64 -0
  52. autogluon/timeseries/models/toto/__init__.py +3 -0
  53. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  62. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  63. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  64. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  65. autogluon/timeseries/models/toto/dataloader.py +108 -0
  66. autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
  67. autogluon/timeseries/models/toto/model.py +236 -0
  68. autogluon/timeseries/predictor.py +826 -305
  69. autogluon/timeseries/regressor.py +253 -0
  70. autogluon/timeseries/splitter.py +10 -31
  71. autogluon/timeseries/trainer/__init__.py +2 -3
  72. autogluon/timeseries/trainer/ensemble_composer.py +439 -0
  73. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  74. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  75. autogluon/timeseries/trainer/trainer.py +1298 -0
  76. autogluon/timeseries/trainer/utils.py +17 -0
  77. autogluon/timeseries/transforms/__init__.py +2 -0
  78. autogluon/timeseries/transforms/covariate_scaler.py +164 -0
  79. autogluon/timeseries/transforms/target_scaler.py +149 -0
  80. autogluon/timeseries/utils/constants.py +10 -0
  81. autogluon/timeseries/utils/datetime/base.py +38 -20
  82. autogluon/timeseries/utils/datetime/lags.py +18 -16
  83. autogluon/timeseries/utils/datetime/seasonality.py +14 -14
  84. autogluon/timeseries/utils/datetime/time_features.py +17 -14
  85. autogluon/timeseries/utils/features.py +317 -53
  86. autogluon/timeseries/utils/forecast.py +31 -17
  87. autogluon/timeseries/utils/timer.py +173 -0
  88. autogluon/timeseries/utils/warning_filters.py +44 -6
  89. autogluon/timeseries/version.py +2 -1
  90. autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
  91. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
  92. autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
  93. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
  94. autogluon/timeseries/configs/presets_configs.py +0 -11
  95. autogluon/timeseries/evaluator.py +0 -6
  96. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  97. autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
  98. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  99. autogluon/timeseries/models/presets.py +0 -325
  100. autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
  101. autogluon/timeseries/trainer/auto_trainer.py +0 -74
  102. autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
  103. autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
  104. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
@@ -0,0 +1,583 @@
1
+ import logging
2
+ import os
3
+ import shutil
4
+ from datetime import timedelta
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any, Callable, Type, cast, overload
7
+
8
+ import gluonts
9
+ import gluonts.core.settings
10
+ import numpy as np
11
+ import pandas as pd
12
+ from gluonts.core.component import from_hyperparameters
13
+ from gluonts.dataset.common import Dataset as GluonTSDataset
14
+ from gluonts.env import env as gluonts_env
15
+ from gluonts.model.estimator import Estimator as GluonTSEstimator
16
+ from gluonts.model.forecast import Forecast, QuantileForecast, SampleForecast
17
+ from gluonts.model.predictor import Predictor as GluonTSPredictor
18
+
19
+ from autogluon.common.loaders import load_pkl
20
+ from autogluon.core.hpo.constants import RAY_BACKEND
21
+ from autogluon.tabular.models.tabular_nn.utils.categorical_encoders import (
22
+ OneHotMergeRaresHandleUnknownEncoder as OneHotEncoder,
23
+ )
24
+ from autogluon.timeseries.dataset import TimeSeriesDataFrame
25
+ from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
26
+ from autogluon.timeseries.utils.warning_filters import disable_root_logger, warning_filter
27
+
28
if TYPE_CHECKING:
    # Imported only for type annotations; DistributionForecast lives in gluonts.torch,
    # which we avoid importing at module import time (see NOTE below).
    from gluonts.torch.model.forecast import DistributionForecast

from .dataset import SimpleGluonTSDataset

# NOTE: We avoid imports for torch and lightning.pytorch at the top level and hide them inside class methods.
# This is done to skip these imports during multiprocessing (which may cause bugs)

# Module-level logger for this file, plus a handle on the root GluonTS logger so
# that _fit() can raise/lower GluonTS verbosity together with lightning's.
logger = logging.getLogger(__name__)
gts_logger = logging.getLogger(gluonts.__name__)
40
class AbstractGluonTSModel(AbstractTimeSeriesModel):
    """Abstract class wrapping GluonTS estimators for use in autogluon.timeseries.

    Parameters
    ----------
    path
        directory to store model artifacts.
    freq
        string representation (compatible with GluonTS frequency strings) for the data provided.
        For example, "1D" for daily data, "1H" for hourly data, etc.
    prediction_length
        Number of time steps ahead (length of the forecast horizon) the model will be optimized
        to predict. At inference time, this will be the number of time steps the model will
        predict.
    name
        Name of the model. Also, name of subdirectory inside path where model will be saved.
    eval_metric
        objective function the model intends to optimize, will use WQL by default.
    hyperparameters
        various hyperparameters that will be used by model (can be search spaces instead of
        fixed values). See *Other Parameters* in each inheriting model's documentation for
        possible values.
    """

    # subdirectory (inside the model path) where the serialized GluonTS predictor is stored
    gluonts_model_path = "gluon_ts"
    # we pass dummy freq compatible with pandas 2.1 & 2.2 to GluonTS models
    _dummy_gluonts_freq = "D"
    # default number of samples for prediction
    default_num_samples: int = 250

    #: whether the GluonTS model supports categorical variables as covariates
    _supports_cat_covariates: bool = False

    def __init__(
        self,
        freq: str | None = None,
        prediction_length: int = 1,
        path: str | None = None,
        name: str | None = None,
        eval_metric: str | None = None,
        hyperparameters: dict[str, Any] | None = None,
        **kwargs,  # noqa
    ):
        super().__init__(
            path=path,
            freq=freq,
            prediction_length=prediction_length,
            name=name,
            eval_metric=eval_metric,
            hyperparameters=hyperparameters,
            **kwargs,
        )
        # Fitted GluonTS predictor; None until fit() completes (or load() restores it).
        self.gts_predictor: GluonTSPredictor | None = None
        # One-hot encoders used to turn categorical covariates into real-valued features
        # for models that do not natively support categorical covariates.
        self._ohe_generator_known: OneHotEncoder | None = None
        self._ohe_generator_past: OneHotEncoder | None = None
        self.callbacks = []
        # Following attributes may be overridden during fit() based on train_data & model parameters
        self.num_feat_static_cat = 0
        self.num_feat_static_real = 0
        self.num_feat_dynamic_cat = 0
        self.num_feat_dynamic_real = 0
        self.num_past_feat_dynamic_cat = 0
        self.num_past_feat_dynamic_real = 0
        self.feat_static_cat_cardinality: list[int] = []
        self.feat_dynamic_cat_cardinality: list[int] = []
        self.past_feat_dynamic_cat_cardinality: list[int] = []
        self.negative_data = True

    def save(self, path: str | None = None, verbose: bool = True) -> str:
        """Save the model to disk.

        The GluonTS predictor is detached before pickling the model object and
        serialized separately via GluonTS's own serialization into the
        ``gluonts_model_path`` subdirectory; it is re-attached afterwards so the
        in-memory model remains usable. Returns the directory the model was saved to.
        """
        # we flush callbacks instance variable if it has been set. it can keep weak references which breaks training
        self.callbacks = []
        # The GluonTS predictor is serialized using custom logic
        predictor = self.gts_predictor
        self.gts_predictor = None
        saved_path = Path(super().save(path=path, verbose=verbose))

        with disable_root_logger():
            if predictor:
                Path.mkdir(saved_path / self.gluonts_model_path, exist_ok=True)
                predictor.serialize(saved_path / self.gluonts_model_path)

        # restore the predictor so that the in-memory object stays fitted after save()
        self.gts_predictor = predictor

        return str(saved_path)

    @classmethod
    def load(
        cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
    ) -> "AbstractGluonTSModel":
        """Load a previously saved model from ``path``.

        Unpickles the model object and then deserializes the GluonTS predictor
        (always as a ``PyTorchPredictor``, with device auto-selection) from the
        ``gluonts_model_path`` subdirectory.
        """
        from gluonts.torch.model.predictor import PyTorchPredictor

        with warning_filter():
            model = load_pkl.load(path=os.path.join(path, cls.model_file_name), verbose=verbose)
            if reset_paths:
                model.set_contexts(path)
            model.gts_predictor = PyTorchPredictor.deserialize(Path(path) / cls.gluonts_model_path, device="auto")
        return model

    @property
    def supports_cat_covariates(self) -> bool:
        # read the class-level flag so subclasses only need to override the attribute
        return self.__class__._supports_cat_covariates

    def _get_hpo_backend(self):
        """Use the Ray backend for hyperparameter optimization."""
        return RAY_BACKEND

    def _deferred_init_hyperparameters(self, dataset: TimeSeriesDataFrame) -> None:
        """Update GluonTS specific hyperparameters with information available only at training time."""
        model_params = self.get_hyperparameters()
        disable_static_features = model_params.get("disable_static_features", False)
        if not disable_static_features:
            self.num_feat_static_cat = len(self.covariate_metadata.static_features_cat)
            self.num_feat_static_real = len(self.covariate_metadata.static_features_real)
            if self.num_feat_static_cat > 0:
                assert dataset.static_features is not None, (
                    "Static features must be provided if num_feat_static_cat > 0"
                )
                feat_static_cat = dataset.static_features[self.covariate_metadata.static_features_cat]
                # per-column cardinalities are needed by GluonTS estimators for embedding sizes
                self.feat_static_cat_cardinality = feat_static_cat.nunique().tolist()

        disable_known_covariates = model_params.get("disable_known_covariates", False)
        if not disable_known_covariates and self.supports_known_covariates:
            self.num_feat_dynamic_cat = len(self.covariate_metadata.known_covariates_cat)
            self.num_feat_dynamic_real = len(self.covariate_metadata.known_covariates_real)
            if self.num_feat_dynamic_cat > 0:
                feat_dynamic_cat = dataset[self.covariate_metadata.known_covariates_cat]
                if self.supports_cat_covariates:
                    self.feat_dynamic_cat_cardinality = feat_dynamic_cat.nunique().tolist()
                else:
                    # If model doesn't support categorical covariates, convert them to real via one hot encoding
                    self._ohe_generator_known = OneHotEncoder(
                        max_levels=model_params.get("max_cat_cardinality", 100),
                        sparse=False,
                        dtype="float32",  # type: ignore
                    )
                    feat_dynamic_cat_ohe = self._ohe_generator_known.fit_transform(pd.DataFrame(feat_dynamic_cat))
                    # categorical covariates are now counted as real-valued features
                    self.num_feat_dynamic_cat = 0
                    self.num_feat_dynamic_real += feat_dynamic_cat_ohe.shape[1]

        disable_past_covariates = model_params.get("disable_past_covariates", False)
        if not disable_past_covariates and self.supports_past_covariates:
            self.num_past_feat_dynamic_cat = len(self.covariate_metadata.past_covariates_cat)
            self.num_past_feat_dynamic_real = len(self.covariate_metadata.past_covariates_real)
            if self.num_past_feat_dynamic_cat > 0:
                past_feat_dynamic_cat = dataset[self.covariate_metadata.past_covariates_cat]
                if self.supports_cat_covariates:
                    self.past_feat_dynamic_cat_cardinality = past_feat_dynamic_cat.nunique().tolist()
                else:
                    # If model doesn't support categorical covariates, convert them to real via one hot encoding
                    self._ohe_generator_past = OneHotEncoder(
                        max_levels=model_params.get("max_cat_cardinality", 100),
                        sparse=False,
                        dtype="float32",  # type: ignore
                    )
                    past_feat_dynamic_cat_ohe = self._ohe_generator_past.fit_transform(
                        pd.DataFrame(past_feat_dynamic_cat)
                    )
                    self.num_past_feat_dynamic_cat = 0
                    self.num_past_feat_dynamic_real += past_feat_dynamic_cat_ohe.shape[1]

        # presumably used by subclasses to pick an output distribution that allows
        # negative values — TODO confirm against concrete estimators
        self.negative_data = (dataset[self.target] < 0).any()

    def _get_default_hyperparameters(self):
        """Gets default parameters for GluonTS estimator initialization that are available after
        AbstractTimeSeriesModel initialization (i.e., before deferred initialization). Models may
        override this method to update default parameters.
        """
        return {
            "batch_size": 64,
            "context_length": min(512, max(10, 2 * self.prediction_length)),
            "predict_batch_size": 500,
            "early_stopping_patience": 20,
            "max_epochs": 100,
            "lr": 1e-3,
            "freq": self._dummy_gluonts_freq,
            "prediction_length": self.prediction_length,
            "quantiles": self.quantile_levels,
            "covariate_scaler": "global",
        }

    def get_hyperparameters(self) -> dict:
        """Gets params that are passed to the inner model."""
        # for backward compatibility with the old GluonTS MXNet API
        parameter_name_aliases = {
            "epochs": "max_epochs",
            "learning_rate": "lr",
        }

        init_args = super().get_hyperparameters()
        for alias, actual in parameter_name_aliases.items():
            if alias in init_args:
                if actual in init_args:
                    raise ValueError(f"Parameter '{alias}' cannot be specified when '{actual}' is also specified.")
                else:
                    init_args[actual] = init_args.pop(alias)

        # user-provided values take precedence over defaults
        return self._get_default_hyperparameters() | init_args

    def _get_estimator_init_args(self) -> dict[str, Any]:
        """Get GluonTS specific constructor arguments for estimator objects, an alias to `self.get_hyperparameters`
        for better readability."""
        return self.get_hyperparameters()

    def _get_estimator_class(self) -> Type[GluonTSEstimator]:
        # must be implemented by each concrete model wrapper
        raise NotImplementedError

    def _get_estimator(self) -> GluonTSEstimator:
        """Return the GluonTS Estimator object for the model"""
        # As GluonTSPyTorchLightningEstimator objects do not implement `from_hyperparameters` convenience
        # constructors, we re-implement the logic here.
        # we translate the "epochs" parameter to "max_epochs" for consistency in the AbstractGluonTSModel interface
        init_args = self._get_estimator_init_args()

        default_trainer_kwargs = {
            "limit_val_batches": 3,
            "max_epochs": init_args["max_epochs"],
            "callbacks": self.callbacks,
            "enable_progress_bar": False,
            "default_root_dir": self.path,
        }

        if self._is_gpu_available():
            default_trainer_kwargs["accelerator"] = "gpu"
            default_trainer_kwargs["devices"] = 1
        else:
            default_trainer_kwargs["accelerator"] = "cpu"

        # user-supplied trainer_kwargs override the defaults above
        default_trainer_kwargs.update(init_args.pop("trainer_kwargs", {}))
        logger.debug(f"\tTraining on device '{default_trainer_kwargs['accelerator']}'")

        return from_hyperparameters(
            self._get_estimator_class(),
            trainer_kwargs=default_trainer_kwargs,
            **init_args,
        )

    def _is_gpu_available(self) -> bool:
        # torch imported lazily (see module-level NOTE about multiprocessing)
        import torch.cuda

        return torch.cuda.is_available()

    def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
        """Return the minimal resource allocation required to train this model."""
        minimum_resources: dict[str, int | float] = {"num_cpus": 1}
        # if GPU is available, we train with 1 GPU per trial
        if is_gpu_available:
            minimum_resources["num_gpus"] = 1
        return minimum_resources

    @overload
    def _to_gluonts_dataset(self, time_series_df: None, known_covariates=None) -> None: ...
    @overload
    def _to_gluonts_dataset(self, time_series_df: TimeSeriesDataFrame, known_covariates=None) -> GluonTSDataset: ...
    def _to_gluonts_dataset(
        self, time_series_df: TimeSeriesDataFrame | None, known_covariates: TimeSeriesDataFrame | None = None
    ) -> GluonTSDataset | None:
        """Convert a TimeSeriesDataFrame into a GluonTS dataset.

        Static/known/past covariate arrays are extracted according to the counts
        computed in `_deferred_init_hyperparameters`. If `known_covariates` is
        given, its future values are appended to the known-covariate arrays,
        which must then cover the full forecast horizon for every item.
        Returns None when `time_series_df` is None (e.g. no validation data).
        """
        if time_series_df is not None:
            # TODO: Preprocess real-valued features with StdScaler?
            if self.num_feat_static_cat > 0:
                assert time_series_df.static_features is not None, (
                    "Static features must be provided if num_feat_static_cat > 0"
                )
                feat_static_cat = time_series_df.static_features[
                    self.covariate_metadata.static_features_cat
                ].to_numpy()
            else:
                feat_static_cat = None

            if self.num_feat_static_real > 0:
                assert time_series_df.static_features is not None, (
                    "Static features must be provided if num_feat_static_real > 0"
                )
                feat_static_real = time_series_df.static_features[
                    self.covariate_metadata.static_features_real
                ].to_numpy()
            else:
                feat_static_real = None

            # history length + prediction_length rows per item once future covariates are appended
            expected_known_covariates_len = len(time_series_df) + self.prediction_length * time_series_df.num_items
            # Convert TSDF -> DF to avoid overhead / input validation
            df = pd.DataFrame(time_series_df)
            if known_covariates is not None:
                known_covariates = pd.DataFrame(known_covariates)  # type: ignore
            if self.num_feat_dynamic_cat > 0:
                feat_dynamic_cat = df[self.covariate_metadata.known_covariates_cat].to_numpy()
                if known_covariates is not None:
                    feat_dynamic_cat = np.concatenate(
                        [feat_dynamic_cat, known_covariates[self.covariate_metadata.known_covariates_cat].to_numpy()]
                    )
                    assert len(feat_dynamic_cat) == expected_known_covariates_len
            else:
                feat_dynamic_cat = None

            if self.num_feat_dynamic_real > 0:
                feat_dynamic_real = df[self.covariate_metadata.known_covariates_real].to_numpy()
                # Append future values of known covariates
                if known_covariates is not None:
                    feat_dynamic_real = np.concatenate(
                        [feat_dynamic_real, known_covariates[self.covariate_metadata.known_covariates_real].to_numpy()]
                    )
                    assert len(feat_dynamic_real) == expected_known_covariates_len
                # Categorical covariates are one-hot-encoded as real
                if self._ohe_generator_known is not None:
                    feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform(
                        df[self.covariate_metadata.known_covariates_cat]
                    )  # type: ignore
                    if known_covariates is not None:
                        future_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform(  # type: ignore
                            known_covariates[self.covariate_metadata.known_covariates_cat]
                        )
                        feat_dynamic_cat_ohe = np.concatenate([feat_dynamic_cat_ohe, future_dynamic_cat_ohe])
                        assert len(feat_dynamic_cat_ohe) == expected_known_covariates_len
                    feat_dynamic_real = np.concatenate([feat_dynamic_real, feat_dynamic_cat_ohe], axis=1)
            else:
                feat_dynamic_real = None

            if self.num_past_feat_dynamic_cat > 0:
                past_feat_dynamic_cat = df[self.covariate_metadata.past_covariates_cat].to_numpy()
            else:
                past_feat_dynamic_cat = None

            if self.num_past_feat_dynamic_real > 0:
                past_feat_dynamic_real = df[self.covariate_metadata.past_covariates_real].to_numpy()
                if self._ohe_generator_past is not None:
                    past_feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_past.transform(  # type: ignore
                        df[self.covariate_metadata.past_covariates_cat]
                    )
                    past_feat_dynamic_real = np.concatenate(
                        [past_feat_dynamic_real, past_feat_dynamic_cat_ohe], axis=1
                    )
            else:
                past_feat_dynamic_real = None

            assert self.freq is not None
            return SimpleGluonTSDataset(
                target_df=time_series_df[[self.target]],  # type: ignore
                freq=self.freq,
                target_column=self.target,
                feat_static_cat=feat_static_cat,
                feat_static_real=feat_static_real,
                feat_dynamic_cat=feat_dynamic_cat,
                feat_dynamic_real=feat_dynamic_real,
                past_feat_dynamic_cat=past_feat_dynamic_cat,
                past_feat_dynamic_real=past_feat_dynamic_real,
                includes_future=known_covariates is not None,
                prediction_length=self.prediction_length,
            )
        else:
            return None

    def _fit(
        self,
        train_data: TimeSeriesDataFrame,
        val_data: TimeSeriesDataFrame | None = None,
        time_limit: float | None = None,
        num_cpus: int | None = None,
        num_gpus: int | None = None,
        verbosity: int = 2,
        **kwargs,
    ) -> None:
        """Fit the underlying GluonTS estimator on ``train_data``.

        Configures lightning/GluonTS logging based on ``verbosity``, builds
        trainer callbacks (time limit timer, early stopping when ``val_data``
        is given), trains the estimator, and stores the resulting predictor in
        ``self.gts_predictor``. Lightning logs are removed after training
        unless the ``keep_lightning_logs`` hyperparameter is set.
        """
        # necessary to initialize the loggers
        import lightning.pytorch  # noqa

        # silence (or enable, at high verbosity) all lightning-related loggers
        for logger_name in logging.root.manager.loggerDict:
            if "lightning" in logger_name:
                pl_logger = logging.getLogger(logger_name)
                pl_logger.setLevel(logging.ERROR if verbosity <= 3 else logging.INFO)
        gts_logger.setLevel(logging.ERROR if verbosity <= 3 else logging.INFO)

        if verbosity > 3:
            logger.warning(
                "GluonTS logging is turned on during training. Note that losses reported by GluonTS "
                "may not correspond to those specified via `eval_metric`."
            )

        self._check_fit_params()
        # update auxiliary parameters
        init_args = self._get_estimator_init_args()
        keep_lightning_logs = init_args.pop("keep_lightning_logs", False)
        self.callbacks = self._get_callbacks(
            time_limit=time_limit,
            early_stopping_patience=None if val_data is None else init_args["early_stopping_patience"],
        )
        self._deferred_init_hyperparameters(train_data)

        estimator = self._get_estimator()
        with warning_filter(), disable_root_logger(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
            self.gts_predictor = estimator.train(
                self._to_gluonts_dataset(train_data),
                validation_data=self._to_gluonts_dataset(val_data),
                cache_data=True,  # type: ignore
            )
            # Increase batch size during prediction to speed up inference
            if init_args["predict_batch_size"] is not None:
                self.gts_predictor.batch_size = init_args["predict_batch_size"]  # type: ignore

        lightning_logs_dir = Path(self.path) / "lightning_logs"
        if not keep_lightning_logs and lightning_logs_dir.exists() and lightning_logs_dir.is_dir():
            logger.debug(f"Removing lightning_logs directory {lightning_logs_dir}")
            shutil.rmtree(lightning_logs_dir)

    def _get_callbacks(
        self,
        time_limit: float | None,
        early_stopping_patience: int | None = None,
    ) -> list[Callable]:
        """Retrieve a list of callback objects for the GluonTS trainer"""
        from lightning.pytorch.callbacks import EarlyStopping, Timer

        callbacks = []
        if time_limit is not None:
            # Timer stops training once the time budget is exhausted
            callbacks.append(Timer(timedelta(seconds=time_limit)))
        if early_stopping_patience is not None:
            callbacks.append(EarlyStopping(monitor="val_loss", patience=early_stopping_patience))
        return callbacks

    def _predict(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: TimeSeriesDataFrame | None = None,
        **kwargs,
    ) -> TimeSeriesDataFrame:
        """Generate forecasts for ``data`` and return them as a TimeSeriesDataFrame
        with a "mean" column and one column per quantile level.

        Raises
        ------
        ValueError
            If the model has not been fit yet (``gts_predictor`` is None).
        """
        if self.gts_predictor is None:
            raise ValueError("Please fit the model before predicting.")

        with warning_filter(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
            predicted_targets = self._predict_gluonts_forecasts(data, known_covariates=known_covariates)
            df = self._gluonts_forecasts_to_data_frame(
                predicted_targets,
                forecast_index=self.get_forecast_horizon_index(data),
            )
        return df

    def _predict_gluonts_forecasts(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: TimeSeriesDataFrame | None = None,
        num_samples: int | None = None,
    ) -> list[Forecast]:
        """Run the GluonTS predictor and return one Forecast object per item.

        ``num_samples`` only affects sample-based predictors; falls back to
        ``default_num_samples`` when not provided.
        """
        assert self.gts_predictor is not None, "GluonTS models must be fit before predicting."
        gts_data = self._to_gluonts_dataset(data, known_covariates=known_covariates)
        return list(
            self.gts_predictor.predict(
                dataset=gts_data,
                num_samples=num_samples or self.default_num_samples,
            )
        )

    def _stack_quantile_forecasts(self, forecasts: list[QuantileForecast], item_ids: pd.Index) -> pd.DataFrame:
        """Stack per-item QuantileForecasts into a single DataFrame ordered by ``item_ids``.

        If the forecast does not include a "mean" column, the median ("0.5")
        is used as the mean prediction.
        """
        # GluonTS always saves item_id as a string
        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
        result_dfs = []
        for item_id in item_ids:
            forecast = item_id_to_forecast[str(item_id)]
            result_dfs.append(pd.DataFrame(forecast.forecast_array.T, columns=forecast.forecast_keys))
        forecast_df = pd.concat(result_dfs)
        if "mean" not in forecast_df.columns:
            forecast_df["mean"] = forecast_df["0.5"]
        columns_order = ["mean"] + [str(q) for q in self.quantile_levels]
        return forecast_df[columns_order]

    def _stack_sample_forecasts(self, forecasts: list[SampleForecast], item_ids: pd.Index) -> pd.DataFrame:
        """Convert per-item SampleForecasts into a DataFrame of mean + empirical quantiles,
        with items ordered according to ``item_ids``."""
        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
        samples_per_item = []
        for item_id in item_ids:
            forecast = item_id_to_forecast[str(item_id)]
            samples_per_item.append(forecast.samples.T)
        samples = np.concatenate(samples_per_item, axis=0)
        quantiles = np.quantile(samples, self.quantile_levels, axis=1).T
        mean = samples.mean(axis=1, keepdims=True)
        forecast_array = np.concatenate([mean, quantiles], axis=1)
        return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])

    def _stack_distribution_forecasts(
        self, forecasts: list["DistributionForecast"], item_ids: pd.Index
    ) -> pd.DataFrame:
        """Convert per-item DistributionForecasts into a DataFrame of mean + quantiles.

        All forecast distributions (assumed AffineTransformed) are merged into
        a single batched torch Distribution so that mean/icdf are computed in
        one vectorized call instead of per item.
        """
        import torch
        from gluonts.torch.distributions import AffineTransformed
        from torch.distributions import Distribution

        # Sort forecasts in the same order as in the dataset
        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
        dist_forecasts = [item_id_to_forecast[str(item_id)] for item_id in item_ids]

        assert all(isinstance(f.distribution, AffineTransformed) for f in dist_forecasts), (
            "Expected forecast.distribution to be an instance of AffineTransformed"
        )

        def stack_distributions(distributions: list[Distribution]) -> Distribution:
            """Stack multiple torch.Distribution objects into a single distribution"""
            last_dist: Distribution = distributions[-1]

            params_per_dist = []
            for dist in distributions:
                params = {name: getattr(dist, name) for name in dist.arg_constraints.keys()}
                params_per_dist.append(params)
            # Make sure that all distributions have same keys
            assert len(set(tuple(p.keys()) for p in params_per_dist)) == 1

            stacked_params = {}
            for key in last_dist.arg_constraints.keys():
                stacked_params[key] = torch.cat([p[key] for p in params_per_dist])
            return last_dist.__class__(**stacked_params)

        # We stack all forecast distribution into a single Distribution object.
        # This dramatically speeds up the quantiles calculation.
        stacked_base_dist = stack_distributions([f.distribution.base_dist for f in dist_forecasts])  # type: ignore

        stacked_loc = torch.cat([f.distribution.loc for f in dist_forecasts])  # type: ignore
        # loc/scale may be per-item scalars; broadcast them across the forecast horizon
        if stacked_loc.shape != stacked_base_dist.batch_shape:
            stacked_loc = stacked_loc.repeat_interleave(self.prediction_length)

        stacked_scale = torch.cat([f.distribution.scale for f in dist_forecasts])  # type: ignore
        if stacked_scale.shape != stacked_base_dist.batch_shape:
            stacked_scale = stacked_scale.repeat_interleave(self.prediction_length)

        stacked_dist = AffineTransformed(stacked_base_dist, loc=stacked_loc, scale=stacked_scale)

        mean_prediction = stacked_dist.mean.cpu().detach().numpy()
        quantiles = torch.tensor(self.quantile_levels, device=stacked_dist.mean.device).reshape(-1, 1)
        quantile_predictions = stacked_dist.icdf(quantiles).cpu().detach().numpy()  # type: ignore
        forecast_array = np.vstack([mean_prediction, quantile_predictions]).T
        return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])

    def _gluonts_forecasts_to_data_frame(
        self,
        forecasts: list[Forecast],
        forecast_index: pd.MultiIndex,
    ) -> TimeSeriesDataFrame:
        """Dispatch on the GluonTS forecast type and assemble the stacked forecasts
        into a TimeSeriesDataFrame indexed by ``forecast_index``.

        Raises
        ------
        ValueError
            If the forecast type is not Sample-, Quantile-, or DistributionForecast.
        """
        from gluonts.torch.model.forecast import DistributionForecast

        item_ids = forecast_index.unique(level=TimeSeriesDataFrame.ITEMID)
        if isinstance(forecasts[0], SampleForecast):
            forecast_df = self._stack_sample_forecasts(cast(list[SampleForecast], forecasts), item_ids)
        elif isinstance(forecasts[0], QuantileForecast):
            forecast_df = self._stack_quantile_forecasts(cast(list[QuantileForecast], forecasts), item_ids)
        elif isinstance(forecasts[0], DistributionForecast):
            forecast_df = self._stack_distribution_forecasts(cast(list[DistributionForecast], forecasts), item_ids)
        else:
            raise ValueError(f"Unrecognized forecast type {type(forecasts[0])}")

        forecast_df.index = forecast_index
        return TimeSeriesDataFrame(forecast_df)

    def _more_tags(self) -> dict:
        # model supports NaN targets and can consume validation data during training
        return {"allow_nan": True, "can_use_val_data": True}