autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +84 -0
- autogluon/timeseries/dataset/ts_dataframe.py +339 -186
- autogluon/timeseries/learner.py +192 -60
- autogluon/timeseries/metrics/__init__.py +55 -11
- autogluon/timeseries/metrics/abstract.py +96 -25
- autogluon/timeseries/metrics/point.py +186 -39
- autogluon/timeseries/metrics/quantile.py +47 -20
- autogluon/timeseries/metrics/utils.py +6 -6
- autogluon/timeseries/models/__init__.py +13 -7
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
- autogluon/timeseries/models/abstract/model_trial.py +10 -10
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
- autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
- autogluon/timeseries/models/chronos/__init__.py +4 -0
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +738 -0
- autogluon/timeseries/models/chronos/utils.py +369 -0
- autogluon/timeseries/models/ensemble/__init__.py +35 -2
- autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
- autogluon/timeseries/models/gluonts/__init__.py +3 -1
- autogluon/timeseries/models/gluonts/abstract.py +583 -0
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
- autogluon/timeseries/models/local/__init__.py +1 -10
- autogluon/timeseries/models/local/abstract_local_model.py +150 -97
- autogluon/timeseries/models/local/naive.py +31 -23
- autogluon/timeseries/models/local/npts.py +6 -2
- autogluon/timeseries/models/local/statsforecast.py +99 -112
- autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +826 -305
- autogluon/timeseries/regressor.py +253 -0
- autogluon/timeseries/splitter.py +10 -31
- autogluon/timeseries/trainer/__init__.py +2 -3
- autogluon/timeseries/trainer/ensemble_composer.py +439 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/trainer/trainer.py +1298 -0
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -0
- autogluon/timeseries/transforms/covariate_scaler.py +164 -0
- autogluon/timeseries/transforms/target_scaler.py +149 -0
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/base.py +38 -20
- autogluon/timeseries/utils/datetime/lags.py +18 -16
- autogluon/timeseries/utils/datetime/seasonality.py +14 -14
- autogluon/timeseries/utils/datetime/time_features.py +17 -14
- autogluon/timeseries/utils/features.py +317 -53
- autogluon/timeseries/utils/forecast.py +31 -17
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +44 -6
- autogluon/timeseries/version.py +2 -1
- autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
- autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -11
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -325
- autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
- autogluon/timeseries/trainer/auto_trainer.py +0 -74
- autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
- autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
--- /dev/null
+++ b/autogluon/timeseries/models/gluonts/abstract.py
@@ -0,0 +1,583 @@
+import logging
+import os
+import shutil
+from datetime import timedelta
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Type, cast, overload
+
+import gluonts
+import gluonts.core.settings
+import numpy as np
+import pandas as pd
+from gluonts.core.component import from_hyperparameters
+from gluonts.dataset.common import Dataset as GluonTSDataset
+from gluonts.env import env as gluonts_env
+from gluonts.model.estimator import Estimator as GluonTSEstimator
+from gluonts.model.forecast import Forecast, QuantileForecast, SampleForecast
+from gluonts.model.predictor import Predictor as GluonTSPredictor
+
+from autogluon.common.loaders import load_pkl
+from autogluon.core.hpo.constants import RAY_BACKEND
+from autogluon.tabular.models.tabular_nn.utils.categorical_encoders import (
+    OneHotMergeRaresHandleUnknownEncoder as OneHotEncoder,
+)
+from autogluon.timeseries.dataset import TimeSeriesDataFrame
+from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
+from autogluon.timeseries.utils.warning_filters import disable_root_logger, warning_filter
+
+if TYPE_CHECKING:
+    from gluonts.torch.model.forecast import DistributionForecast
+
+from .dataset import SimpleGluonTSDataset
+
+# NOTE: We avoid imports for torch and lightning.pytorch at the top level and hide them inside class methods.
+# This is done to skip these imports during multiprocessing (which may cause bugs)
+
+logger = logging.getLogger(__name__)
+gts_logger = logging.getLogger(gluonts.__name__)
+
+
+class AbstractGluonTSModel(AbstractTimeSeriesModel):
+    """Abstract class wrapping GluonTS estimators for use in autogluon.timeseries.
+
+    Parameters
+    ----------
+    path
+        directory to store model artifacts.
+    freq
+        string representation (compatible with GluonTS frequency strings) for the data provided.
+        For example, "1D" for daily data, "1H" for hourly data, etc.
+    prediction_length
+        Number of time steps ahead (length of the forecast horizon) the model will be optimized
+        to predict. At inference time, this will be the number of time steps the model will
+        predict.
+    name
+        Name of the model. Also, name of subdirectory inside path where model will be saved.
+    eval_metric
+        objective function the model intends to optimize, will use WQL by default.
+    hyperparameters
+        various hyperparameters that will be used by model (can be search spaces instead of
+        fixed values). See *Other Parameters* in each inheriting model's documentation for
+        possible values.
+    """
+
+    gluonts_model_path = "gluon_ts"
+    # we pass dummy freq compatible with pandas 2.1 & 2.2 to GluonTS models
+    _dummy_gluonts_freq = "D"
+    # default number of samples for prediction
+    default_num_samples: int = 250
+
+    #: whether the GluonTS model supports categorical variables as covariates
+    _supports_cat_covariates: bool = False
+
+    def __init__(
+        self,
+        freq: str | None = None,
+        prediction_length: int = 1,
+        path: str | None = None,
+        name: str | None = None,
+        eval_metric: str | None = None,
+        hyperparameters: dict[str, Any] | None = None,
+        **kwargs,  # noqa
+    ):
+        super().__init__(
+            path=path,
+            freq=freq,
+            prediction_length=prediction_length,
+            name=name,
+            eval_metric=eval_metric,
+            hyperparameters=hyperparameters,
+            **kwargs,
+        )
+        self.gts_predictor: GluonTSPredictor | None = None
+        self._ohe_generator_known: OneHotEncoder | None = None
+        self._ohe_generator_past: OneHotEncoder | None = None
+        self.callbacks = []
+        # Following attributes may be overridden during fit() based on train_data & model parameters
+        self.num_feat_static_cat = 0
+        self.num_feat_static_real = 0
+        self.num_feat_dynamic_cat = 0
+        self.num_feat_dynamic_real = 0
+        self.num_past_feat_dynamic_cat = 0
+        self.num_past_feat_dynamic_real = 0
+        self.feat_static_cat_cardinality: list[int] = []
+        self.feat_dynamic_cat_cardinality: list[int] = []
+        self.past_feat_dynamic_cat_cardinality: list[int] = []
+        self.negative_data = True
+
+    def save(self, path: str | None = None, verbose: bool = True) -> str:
+        # we flush callbacks instance variable if it has been set. it can keep weak references which breaks training
+        self.callbacks = []
+        # The GluonTS predictor is serialized using custom logic
+        predictor = self.gts_predictor
+        self.gts_predictor = None
+        saved_path = Path(super().save(path=path, verbose=verbose))
+
+        with disable_root_logger():
+            if predictor:
+                Path.mkdir(saved_path / self.gluonts_model_path, exist_ok=True)
+                predictor.serialize(saved_path / self.gluonts_model_path)
+
+        self.gts_predictor = predictor
+
+        return str(saved_path)
+
+    @classmethod
+    def load(
+        cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
+    ) -> "AbstractGluonTSModel":
+        from gluonts.torch.model.predictor import PyTorchPredictor
+
+        with warning_filter():
+            model = load_pkl.load(path=os.path.join(path, cls.model_file_name), verbose=verbose)
+            if reset_paths:
+                model.set_contexts(path)
+            model.gts_predictor = PyTorchPredictor.deserialize(Path(path) / cls.gluonts_model_path, device="auto")
+        return model
+
+    @property
+    def supports_cat_covariates(self) -> bool:
+        return self.__class__._supports_cat_covariates
+
+    def _get_hpo_backend(self):
+        return RAY_BACKEND
+
+    def _deferred_init_hyperparameters(self, dataset: TimeSeriesDataFrame) -> None:
+        """Update GluonTS specific hyperparameters with information available only at training time."""
+        model_params = self.get_hyperparameters()
+        disable_static_features = model_params.get("disable_static_features", False)
+        if not disable_static_features:
+            self.num_feat_static_cat = len(self.covariate_metadata.static_features_cat)
+            self.num_feat_static_real = len(self.covariate_metadata.static_features_real)
+            if self.num_feat_static_cat > 0:
+                assert dataset.static_features is not None, (
+                    "Static features must be provided if num_feat_static_cat > 0"
+                )
+                feat_static_cat = dataset.static_features[self.covariate_metadata.static_features_cat]
+                self.feat_static_cat_cardinality = feat_static_cat.nunique().tolist()
+
+        disable_known_covariates = model_params.get("disable_known_covariates", False)
+        if not disable_known_covariates and self.supports_known_covariates:
+            self.num_feat_dynamic_cat = len(self.covariate_metadata.known_covariates_cat)
+            self.num_feat_dynamic_real = len(self.covariate_metadata.known_covariates_real)
+            if self.num_feat_dynamic_cat > 0:
+                feat_dynamic_cat = dataset[self.covariate_metadata.known_covariates_cat]
+                if self.supports_cat_covariates:
+                    self.feat_dynamic_cat_cardinality = feat_dynamic_cat.nunique().tolist()
+                else:
+                    # If model doesn't support categorical covariates, convert them to real via one hot encoding
+                    self._ohe_generator_known = OneHotEncoder(
+                        max_levels=model_params.get("max_cat_cardinality", 100),
+                        sparse=False,
+                        dtype="float32",  # type: ignore
+                    )
+                    feat_dynamic_cat_ohe = self._ohe_generator_known.fit_transform(pd.DataFrame(feat_dynamic_cat))
+                    self.num_feat_dynamic_cat = 0
+                    self.num_feat_dynamic_real += feat_dynamic_cat_ohe.shape[1]
+
+        disable_past_covariates = model_params.get("disable_past_covariates", False)
+        if not disable_past_covariates and self.supports_past_covariates:
+            self.num_past_feat_dynamic_cat = len(self.covariate_metadata.past_covariates_cat)
+            self.num_past_feat_dynamic_real = len(self.covariate_metadata.past_covariates_real)
+            if self.num_past_feat_dynamic_cat > 0:
+                past_feat_dynamic_cat = dataset[self.covariate_metadata.past_covariates_cat]
+                if self.supports_cat_covariates:
+                    self.past_feat_dynamic_cat_cardinality = past_feat_dynamic_cat.nunique().tolist()
+                else:
+                    # If model doesn't support categorical covariates, convert them to real via one hot encoding
+                    self._ohe_generator_past = OneHotEncoder(
+                        max_levels=model_params.get("max_cat_cardinality", 100),
+                        sparse=False,
+                        dtype="float32",  # type: ignore
+                    )
+                    past_feat_dynamic_cat_ohe = self._ohe_generator_past.fit_transform(
+                        pd.DataFrame(past_feat_dynamic_cat)
+                    )
+                    self.num_past_feat_dynamic_cat = 0
+                    self.num_past_feat_dynamic_real += past_feat_dynamic_cat_ohe.shape[1]
+
+        self.negative_data = (dataset[self.target] < 0).any()
+
+    def _get_default_hyperparameters(self):
+        """Gets default parameters for GluonTS estimator initialization that are available after
+        AbstractTimeSeriesModel initialization (i.e., before deferred initialization). Models may
+        override this method to update default parameters.
+        """
+        return {
+            "batch_size": 64,
+            "context_length": min(512, max(10, 2 * self.prediction_length)),
+            "predict_batch_size": 500,
+            "early_stopping_patience": 20,
+            "max_epochs": 100,
+            "lr": 1e-3,
+            "freq": self._dummy_gluonts_freq,
+            "prediction_length": self.prediction_length,
+            "quantiles": self.quantile_levels,
+            "covariate_scaler": "global",
+        }
+
+    def get_hyperparameters(self) -> dict:
+        """Gets params that are passed to the inner model."""
+        # for backward compatibility with the old GluonTS MXNet API
+        parameter_name_aliases = {
+            "epochs": "max_epochs",
+            "learning_rate": "lr",
+        }
+
+        init_args = super().get_hyperparameters()
+        for alias, actual in parameter_name_aliases.items():
+            if alias in init_args:
+                if actual in init_args:
+                    raise ValueError(f"Parameter '{alias}' cannot be specified when '{actual}' is also specified.")
+                else:
+                    init_args[actual] = init_args.pop(alias)
+
+        return self._get_default_hyperparameters() | init_args
+
+    def _get_estimator_init_args(self) -> dict[str, Any]:
+        """Get GluonTS specific constructor arguments for estimator objects, an alias to `self.get_hyperparameters`
+        for better readability."""
+        return self.get_hyperparameters()
+
+    def _get_estimator_class(self) -> Type[GluonTSEstimator]:
+        raise NotImplementedError
+
+    def _get_estimator(self) -> GluonTSEstimator:
+        """Return the GluonTS Estimator object for the model"""
+        # As GluonTSPyTorchLightningEstimator objects do not implement `from_hyperparameters` convenience
+        # constructors, we re-implement the logic here.
+        # we translate the "epochs" parameter to "max_epochs" for consistency in the AbstractGluonTSModel interface
+        init_args = self._get_estimator_init_args()
+
+        default_trainer_kwargs = {
+            "limit_val_batches": 3,
+            "max_epochs": init_args["max_epochs"],
+            "callbacks": self.callbacks,
+            "enable_progress_bar": False,
+            "default_root_dir": self.path,
+        }
+
+        if self._is_gpu_available():
+            default_trainer_kwargs["accelerator"] = "gpu"
+            default_trainer_kwargs["devices"] = 1
+        else:
+            default_trainer_kwargs["accelerator"] = "cpu"
+
+        default_trainer_kwargs.update(init_args.pop("trainer_kwargs", {}))
+        logger.debug(f"\tTraining on device '{default_trainer_kwargs['accelerator']}'")
+
+        return from_hyperparameters(
+            self._get_estimator_class(),
+            trainer_kwargs=default_trainer_kwargs,
+            **init_args,
+        )
+
+    def _is_gpu_available(self) -> bool:
+        import torch.cuda
+
+        return torch.cuda.is_available()
+
+    def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
+        minimum_resources: dict[str, int | float] = {"num_cpus": 1}
+        # if GPU is available, we train with 1 GPU per trial
+        if is_gpu_available:
+            minimum_resources["num_gpus"] = 1
+        return minimum_resources
+
+    @overload
+    def _to_gluonts_dataset(self, time_series_df: None, known_covariates=None) -> None: ...
+    @overload
+    def _to_gluonts_dataset(self, time_series_df: TimeSeriesDataFrame, known_covariates=None) -> GluonTSDataset: ...
+    def _to_gluonts_dataset(
+        self, time_series_df: TimeSeriesDataFrame | None, known_covariates: TimeSeriesDataFrame | None = None
+    ) -> GluonTSDataset | None:
+        if time_series_df is not None:
+            # TODO: Preprocess real-valued features with StdScaler?
+            if self.num_feat_static_cat > 0:
+                assert time_series_df.static_features is not None, (
+                    "Static features must be provided if num_feat_static_cat > 0"
+                )
+                feat_static_cat = time_series_df.static_features[
+                    self.covariate_metadata.static_features_cat
+                ].to_numpy()
+            else:
+                feat_static_cat = None
+
+            if self.num_feat_static_real > 0:
+                assert time_series_df.static_features is not None, (
+                    "Static features must be provided if num_feat_static_real > 0"
+                )
+                feat_static_real = time_series_df.static_features[
+                    self.covariate_metadata.static_features_real
+                ].to_numpy()
+            else:
+                feat_static_real = None
+
+            expected_known_covariates_len = len(time_series_df) + self.prediction_length * time_series_df.num_items
+            # Convert TSDF -> DF to avoid overhead / input validation
+            df = pd.DataFrame(time_series_df)
+            if known_covariates is not None:
+                known_covariates = pd.DataFrame(known_covariates)  # type: ignore
+            if self.num_feat_dynamic_cat > 0:
+                feat_dynamic_cat = df[self.covariate_metadata.known_covariates_cat].to_numpy()
+                if known_covariates is not None:
+                    feat_dynamic_cat = np.concatenate(
+                        [feat_dynamic_cat, known_covariates[self.covariate_metadata.known_covariates_cat].to_numpy()]
+                    )
+                    assert len(feat_dynamic_cat) == expected_known_covariates_len
+            else:
+                feat_dynamic_cat = None
+
+            if self.num_feat_dynamic_real > 0:
+                feat_dynamic_real = df[self.covariate_metadata.known_covariates_real].to_numpy()
+                # Append future values of known covariates
+                if known_covariates is not None:
+                    feat_dynamic_real = np.concatenate(
+                        [feat_dynamic_real, known_covariates[self.covariate_metadata.known_covariates_real].to_numpy()]
+                    )
+                    assert len(feat_dynamic_real) == expected_known_covariates_len
+                # Categorical covariates are one-hot-encoded as real
+                if self._ohe_generator_known is not None:
+                    feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform(
+                        df[self.covariate_metadata.known_covariates_cat]
+                    )  # type: ignore
+                    if known_covariates is not None:
+                        future_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform(  # type: ignore
+                            known_covariates[self.covariate_metadata.known_covariates_cat]
+                        )
+                        feat_dynamic_cat_ohe = np.concatenate([feat_dynamic_cat_ohe, future_dynamic_cat_ohe])
+                        assert len(feat_dynamic_cat_ohe) == expected_known_covariates_len
+                    feat_dynamic_real = np.concatenate([feat_dynamic_real, feat_dynamic_cat_ohe], axis=1)
+            else:
+                feat_dynamic_real = None
+
+            if self.num_past_feat_dynamic_cat > 0:
+                past_feat_dynamic_cat = df[self.covariate_metadata.past_covariates_cat].to_numpy()
+            else:
+                past_feat_dynamic_cat = None
+
+            if self.num_past_feat_dynamic_real > 0:
+                past_feat_dynamic_real = df[self.covariate_metadata.past_covariates_real].to_numpy()
+                if self._ohe_generator_past is not None:
+                    past_feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_past.transform(  # type: ignore
+                        df[self.covariate_metadata.past_covariates_cat]
+                    )
+                    past_feat_dynamic_real = np.concatenate(
+                        [past_feat_dynamic_real, past_feat_dynamic_cat_ohe], axis=1
+                    )
+            else:
+                past_feat_dynamic_real = None
+
+            assert self.freq is not None
+            return SimpleGluonTSDataset(
+                target_df=time_series_df[[self.target]],  # type: ignore
+                freq=self.freq,
+                target_column=self.target,
+                feat_static_cat=feat_static_cat,
+                feat_static_real=feat_static_real,
+                feat_dynamic_cat=feat_dynamic_cat,
+                feat_dynamic_real=feat_dynamic_real,
+                past_feat_dynamic_cat=past_feat_dynamic_cat,
+                past_feat_dynamic_real=past_feat_dynamic_real,
+                includes_future=known_covariates is not None,
+                prediction_length=self.prediction_length,
+            )
+        else:
+            return None
+
+    def _fit(
+        self,
+        train_data: TimeSeriesDataFrame,
+        val_data: TimeSeriesDataFrame | None = None,
+        time_limit: float | None = None,
+        num_cpus: int | None = None,
+        num_gpus: int | None = None,
+        verbosity: int = 2,
+        **kwargs,
+    ) -> None:
+        # necessary to initialize the loggers
+        import lightning.pytorch  # noqa
+
+        for logger_name in logging.root.manager.loggerDict:
+            if "lightning" in logger_name:
+                pl_logger = logging.getLogger(logger_name)
+                pl_logger.setLevel(logging.ERROR if verbosity <= 3 else logging.INFO)
+        gts_logger.setLevel(logging.ERROR if verbosity <= 3 else logging.INFO)
+
+        if verbosity > 3:
+            logger.warning(
+                "GluonTS logging is turned on during training. Note that losses reported by GluonTS "
+                "may not correspond to those specified via `eval_metric`."
+            )
+
+        self._check_fit_params()
+        # update auxiliary parameters
+        init_args = self._get_estimator_init_args()
+        keep_lightning_logs = init_args.pop("keep_lightning_logs", False)
+        self.callbacks = self._get_callbacks(
+            time_limit=time_limit,
+            early_stopping_patience=None if val_data is None else init_args["early_stopping_patience"],
+        )
+        self._deferred_init_hyperparameters(train_data)
+
+        estimator = self._get_estimator()
+        with warning_filter(), disable_root_logger(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
+            self.gts_predictor = estimator.train(
+                self._to_gluonts_dataset(train_data),
+                validation_data=self._to_gluonts_dataset(val_data),
+                cache_data=True,  # type: ignore
+            )
+            # Increase batch size during prediction to speed up inference
+            if init_args["predict_batch_size"] is not None:
+                self.gts_predictor.batch_size = init_args["predict_batch_size"]  # type: ignore
+
+        lightning_logs_dir = Path(self.path) / "lightning_logs"
+        if not keep_lightning_logs and lightning_logs_dir.exists() and lightning_logs_dir.is_dir():
+            logger.debug(f"Removing lightning_logs directory {lightning_logs_dir}")
+            shutil.rmtree(lightning_logs_dir)
+
+    def _get_callbacks(
+        self,
+        time_limit: float | None,
+        early_stopping_patience: int | None = None,
+    ) -> list[Callable]:
+        """Retrieve a list of callback objects for the GluonTS trainer"""
+        from lightning.pytorch.callbacks import EarlyStopping, Timer
+
+        callbacks = []
+        if time_limit is not None:
+            callbacks.append(Timer(timedelta(seconds=time_limit)))
+        if early_stopping_patience is not None:
+            callbacks.append(EarlyStopping(monitor="val_loss", patience=early_stopping_patience))
+        return callbacks
+
+    def _predict(
+        self,
+        data: TimeSeriesDataFrame,
+        known_covariates: TimeSeriesDataFrame | None = None,
+        **kwargs,
+    ) -> TimeSeriesDataFrame:
+        if self.gts_predictor is None:
+            raise ValueError("Please fit the model before predicting.")
+
+        with warning_filter(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
+            predicted_targets = self._predict_gluonts_forecasts(data, known_covariates=known_covariates)
+            df = self._gluonts_forecasts_to_data_frame(
+                predicted_targets,
+                forecast_index=self.get_forecast_horizon_index(data),
+            )
+        return df
+
+    def _predict_gluonts_forecasts(
+        self,
+        data: TimeSeriesDataFrame,
+        known_covariates: TimeSeriesDataFrame | None = None,
+        num_samples: int | None = None,
+    ) -> list[Forecast]:
+        assert self.gts_predictor is not None, "GluonTS models must be fit before predicting."
+        gts_data = self._to_gluonts_dataset(data, known_covariates=known_covariates)
+        return list(
+            self.gts_predictor.predict(
+                dataset=gts_data,
+                num_samples=num_samples or self.default_num_samples,
+            )
+        )
+
+    def _stack_quantile_forecasts(self, forecasts: list[QuantileForecast], item_ids: pd.Index) -> pd.DataFrame:
+        # GluonTS always saves item_id as a string
+        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
+        result_dfs = []
+        for item_id in item_ids:
+            forecast = item_id_to_forecast[str(item_id)]
+            result_dfs.append(pd.DataFrame(forecast.forecast_array.T, columns=forecast.forecast_keys))
+        forecast_df = pd.concat(result_dfs)
+        if "mean" not in forecast_df.columns:
+            forecast_df["mean"] = forecast_df["0.5"]
+        columns_order = ["mean"] + [str(q) for q in self.quantile_levels]
+        return forecast_df[columns_order]
+
+    def _stack_sample_forecasts(self, forecasts: list[SampleForecast], item_ids: pd.Index) -> pd.DataFrame:
+        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
+        samples_per_item = []
+        for item_id in item_ids:
+            forecast = item_id_to_forecast[str(item_id)]
+            samples_per_item.append(forecast.samples.T)
+        samples = np.concatenate(samples_per_item, axis=0)
+        quantiles = np.quantile(samples, self.quantile_levels, axis=1).T
+        mean = samples.mean(axis=1, keepdims=True)
+        forecast_array = np.concatenate([mean, quantiles], axis=1)
+        return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])
+
+    def _stack_distribution_forecasts(
+        self, forecasts: list["DistributionForecast"], item_ids: pd.Index
+    ) -> pd.DataFrame:
+        import torch
+        from gluonts.torch.distributions import AffineTransformed
+        from torch.distributions import Distribution
+
+        # Sort forecasts in the same order as in the dataset
+        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
+        dist_forecasts = [item_id_to_forecast[str(item_id)] for item_id in item_ids]
+
+        assert all(isinstance(f.distribution, AffineTransformed) for f in dist_forecasts), (
+            "Expected forecast.distribution to be an instance of AffineTransformed"
+        )
+
+        def stack_distributions(distributions: list[Distribution]) -> Distribution:
+            """Stack multiple torch.Distribution objects into a single distribution"""
+            last_dist: Distribution = distributions[-1]
+
+            params_per_dist = []
+            for dist in distributions:
+                params = {name: getattr(dist, name) for name in dist.arg_constraints.keys()}
+                params_per_dist.append(params)
+            # Make sure that all distributions have same keys
+            assert len(set(tuple(p.keys()) for p in params_per_dist)) == 1
+
+            stacked_params = {}
+            for key in last_dist.arg_constraints.keys():
+                stacked_params[key] = torch.cat([p[key] for p in params_per_dist])
+            return last_dist.__class__(**stacked_params)
+
+        # We stack all forecast distribution into a single Distribution object.
+        # This dramatically speeds up the quantiles calculation.
+        stacked_base_dist = stack_distributions([f.distribution.base_dist for f in dist_forecasts])  # type: ignore
+
+        stacked_loc = torch.cat([f.distribution.loc for f in dist_forecasts])  # type: ignore
+        if stacked_loc.shape != stacked_base_dist.batch_shape:
+            stacked_loc = stacked_loc.repeat_interleave(self.prediction_length)
+
+        stacked_scale = torch.cat([f.distribution.scale for f in dist_forecasts])  # type: ignore
+        if stacked_scale.shape != stacked_base_dist.batch_shape:
+            stacked_scale = stacked_scale.repeat_interleave(self.prediction_length)
+
+        stacked_dist = AffineTransformed(stacked_base_dist, loc=stacked_loc, scale=stacked_scale)
+
+        mean_prediction = stacked_dist.mean.cpu().detach().numpy()
+        quantiles = torch.tensor(self.quantile_levels, device=stacked_dist.mean.device).reshape(-1, 1)
+        quantile_predictions = stacked_dist.icdf(quantiles).cpu().detach().numpy()  # type: ignore
+        forecast_array = np.vstack([mean_prediction, quantile_predictions]).T
+        return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])
+
+    def _gluonts_forecasts_to_data_frame(
+        self,
+        forecasts: list[Forecast],
+        forecast_index: pd.MultiIndex,
+    ) -> TimeSeriesDataFrame:
+        from gluonts.torch.model.forecast import DistributionForecast
+
+        item_ids = forecast_index.unique(level=TimeSeriesDataFrame.ITEMID)
+        if isinstance(forecasts[0], SampleForecast):
+            forecast_df = self._stack_sample_forecasts(cast(list[SampleForecast], forecasts), item_ids)
+        elif isinstance(forecasts[0], QuantileForecast):
+            forecast_df = self._stack_quantile_forecasts(cast(list[QuantileForecast], forecasts), item_ids)
+        elif isinstance(forecasts[0], DistributionForecast):
+            forecast_df = self._stack_distribution_forecasts(cast(list[DistributionForecast], forecasts), item_ids)
+        else:
+            raise ValueError(f"Unrecognized forecast type {type(forecasts[0])}")
+
+        forecast_df.index = forecast_index
+        return TimeSeriesDataFrame(forecast_df)
+
+    def _more_tags(self) -> dict:
+        return {"allow_nan": True, "can_use_val_data": True}