autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +84 -0
- autogluon/timeseries/dataset/ts_dataframe.py +339 -186
- autogluon/timeseries/learner.py +192 -60
- autogluon/timeseries/metrics/__init__.py +55 -11
- autogluon/timeseries/metrics/abstract.py +96 -25
- autogluon/timeseries/metrics/point.py +186 -39
- autogluon/timeseries/metrics/quantile.py +47 -20
- autogluon/timeseries/metrics/utils.py +6 -6
- autogluon/timeseries/models/__init__.py +13 -7
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
- autogluon/timeseries/models/abstract/model_trial.py +10 -10
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
- autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
- autogluon/timeseries/models/chronos/__init__.py +4 -0
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +738 -0
- autogluon/timeseries/models/chronos/utils.py +369 -0
- autogluon/timeseries/models/ensemble/__init__.py +35 -2
- autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
- autogluon/timeseries/models/gluonts/__init__.py +3 -1
- autogluon/timeseries/models/gluonts/abstract.py +583 -0
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
- autogluon/timeseries/models/local/__init__.py +1 -10
- autogluon/timeseries/models/local/abstract_local_model.py +150 -97
- autogluon/timeseries/models/local/naive.py +31 -23
- autogluon/timeseries/models/local/npts.py +6 -2
- autogluon/timeseries/models/local/statsforecast.py +99 -112
- autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +826 -305
- autogluon/timeseries/regressor.py +253 -0
- autogluon/timeseries/splitter.py +10 -31
- autogluon/timeseries/trainer/__init__.py +2 -3
- autogluon/timeseries/trainer/ensemble_composer.py +439 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/trainer/trainer.py +1298 -0
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -0
- autogluon/timeseries/transforms/covariate_scaler.py +164 -0
- autogluon/timeseries/transforms/target_scaler.py +149 -0
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/base.py +38 -20
- autogluon/timeseries/utils/datetime/lags.py +18 -16
- autogluon/timeseries/utils/datetime/seasonality.py +14 -14
- autogluon/timeseries/utils/datetime/time_features.py +17 -14
- autogluon/timeseries/utils/features.py +317 -53
- autogluon/timeseries/utils/forecast.py +31 -17
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +44 -6
- autogluon/timeseries/version.py +2 -1
- autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
- autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -11
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -325
- autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
- autogluon/timeseries/trainer/auto_trainer.py +0 -74
- autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
- autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
|
@@ -1,550 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
import os
|
|
3
|
-
import shutil
|
|
4
|
-
from datetime import timedelta
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union
|
|
7
|
-
|
|
8
|
-
import gluonts
|
|
9
|
-
import gluonts.core.settings
|
|
10
|
-
import numpy as np
|
|
11
|
-
import pandas as pd
|
|
12
|
-
from gluonts.core.component import from_hyperparameters
|
|
13
|
-
from gluonts.dataset.common import Dataset as GluonTSDataset
|
|
14
|
-
from gluonts.dataset.field_names import FieldName
|
|
15
|
-
from gluonts.model.estimator import Estimator as GluonTSEstimator
|
|
16
|
-
from gluonts.model.forecast import Forecast, QuantileForecast, SampleForecast
|
|
17
|
-
from gluonts.model.predictor import Predictor as GluonTSPredictor
|
|
18
|
-
from pandas.tseries.frequencies import to_offset
|
|
19
|
-
|
|
20
|
-
from autogluon.common.loaders import load_pkl
|
|
21
|
-
from autogluon.core.hpo.constants import RAY_BACKEND
|
|
22
|
-
from autogluon.timeseries.dataset.ts_dataframe import ITEMID, TIMESTAMP, TimeSeriesDataFrame
|
|
23
|
-
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
|
|
24
|
-
from autogluon.timeseries.utils.datetime import norm_freq_str
|
|
25
|
-
from autogluon.timeseries.utils.forecast import get_forecast_horizon_index_ts_dataframe
|
|
26
|
-
from autogluon.timeseries.utils.warning_filters import disable_root_logger, warning_filter
|
|
27
|
-
|
|
28
|
-
# NOTE: torch and lightning.pytorch are deliberately not imported at module level. They are imported lazily
# inside the methods that need them, so that these imports are skipped during multiprocessing (which may cause bugs).

logger = logging.getLogger(name=__name__)
gts_logger = logging.getLogger(name=gluonts.__name__)

# Base frequency aliases (as normalized by `norm_freq_str`) that are passed to GluonTS unchanged.
# Frequencies outside this list are replaced by a yearly placeholder when building GluonTS data sets.
GLUONTS_SUPPORTED_OFFSETS = "Y Q M W D B H T min S".split()
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class SimpleGluonTSDataset(GluonTSDataset):
    """Wrapper for TimeSeriesDataFrame that is compatible with the GluonTS Dataset API.

    All columns are converted to numpy arrays once. Instead of grouping by item_id on every pass, the boundaries
    of the time series are stored in ``indptr``: the rows of the ``j``-th item are ``indptr[j]:indptr[j + 1]``
    of ``target_array``.

    Parameters
    ----------
    target_df
        Data frame with the target values. Must have a ``freq``. The rows of each item are assumed to be stored
        contiguously.
    target_column
        Name of the target column in ``target_df``.
    feat_static_cat, feat_static_real
        Static features with one row per item. Row ``j`` is used for the ``j``-th item of ``target_df``.
    feat_dynamic_real
        Known covariates with the same rows as ``target_df``. If ``includes_future=True``, it must additionally
        contain ``prediction_length`` future rows for each item (indexed by item_id like ``target_df``).
    past_feat_dynamic_real
        Past covariates with the same rows as ``target_df``.
    includes_future
        Whether ``feat_dynamic_real`` includes future values of the known covariates.
    prediction_length
        Forecast horizon. Required if ``includes_future=True`` and ``feat_dynamic_real`` is provided.

    Raises
    ------
    ValueError
        If ``target_df`` is missing or has no ``freq``, or if ``prediction_length`` is missing while future
        values of ``feat_dynamic_real`` are provided.
    """

    def __init__(
        self,
        target_df: TimeSeriesDataFrame,
        target_column: str = "target",
        feat_static_cat: Optional[pd.DataFrame] = None,
        feat_static_real: Optional[pd.DataFrame] = None,
        feat_dynamic_real: Optional[pd.DataFrame] = None,
        past_feat_dynamic_real: Optional[pd.DataFrame] = None,
        includes_future: bool = False,
        prediction_length: Optional[int] = None,
    ):
        # Explicit checks instead of `assert`, which is stripped when Python runs with -O
        if target_df is None:
            raise ValueError("target_df must be provided")
        if not target_df.freq:
            raise ValueError("Initializing GluonTS data sets without freq is not allowed")
        if includes_future and feat_dynamic_real is not None and prediction_length is None:
            raise ValueError("prediction_length must be provided when includes_future=True")

        # Necessary to compute indptr for known_covariates at prediction time
        self.includes_future = includes_future
        self.prediction_length = prediction_length

        # Replace inefficient groupby ITEMID with indptr that stores start:end of each time series
        item_id_index = target_df.index.get_level_values(ITEMID)
        indices_sizes = item_id_index.value_counts(sort=False)
        self.item_ids = indices_sizes.index  # shape [num_items]
        cum_sizes = indices_sizes.values.cumsum()
        # int64 avoids overflow of the row offsets (shifted further by prediction_length in __iter__)
        self.indptr = np.append(0, cum_sizes).astype(np.int64)
        self.start_timestamps = target_df.reset_index(TIMESTAMP).groupby(level=ITEMID, sort=False).first()[TIMESTAMP]
        assert len(self.item_ids) == len(self.start_timestamps)

        if includes_future and feat_dynamic_real is not None:
            # __iter__ slices each item's known covariates as one contiguous block [past rows, future rows],
            # in the same item order as target_df. Appending future values with pd.concat(axis=0) instead
            # produces [all past rows, all future rows], so regroup the rows by item here.
            feat_dynamic_real = self._group_rows_by_item(feat_dynamic_real, self.item_ids)

        # Convert TimeSeriesDataFrame to numpy arrays for faster processing
        self.target_array = self._to_array(target_df[target_column], dtype=np.float32)
        self.feat_static_cat = self._to_array(feat_static_cat, dtype=np.int64)
        self.feat_static_real = self._to_array(feat_static_real, dtype=np.float32)
        self.feat_dynamic_real = self._to_array(feat_dynamic_real, dtype=np.float32)
        self.past_feat_dynamic_real = self._to_array(past_feat_dynamic_real, dtype=np.float32)
        self.freq = self._to_gluonts_freq(target_df.freq)

    @staticmethod
    def _group_rows_by_item(df: pd.DataFrame, item_ids: pd.Index) -> pd.DataFrame:
        """Reorder rows of ``df`` so that each item's rows are contiguous and items follow ``item_ids`` order.

        The sort is stable, so the original relative order of the rows within each item is preserved. If the
        rows are already grouped in this order, ``df`` is returned unchanged.
        """
        positions = pd.Index(item_ids).get_indexer(df.index.get_level_values(ITEMID))
        if np.all(positions[1:] >= positions[:-1]):
            return df
        return df.iloc[np.argsort(positions, kind="stable")]

    @staticmethod
    def _to_array(df: Optional[pd.DataFrame], dtype: np.dtype) -> Optional[np.ndarray]:
        """Convert ``df`` to a numpy array of the given dtype, passing ``None`` through."""
        if df is None:
            return None
        else:
            return df.to_numpy(dtype=dtype)

    @staticmethod
    def _to_gluonts_freq(freq: str) -> str:
        """Convert a pandas frequency string into a frequency string accepted by GluonTS."""
        # FIXME: GluonTS expects a frequency string, but only supports a limited number of such strings
        # for feature generation. If the frequency string doesn't match or is not provided, it raises an exception.
        # Here we bypass this by issuing a default "yearly" frequency, tricking it into not producing
        # any lags or features.
        pd_offset = to_offset(freq)

        # normalize freq str to handle peculiarities such as W-SUN
        offset_base_alias = norm_freq_str(pd_offset)
        if offset_base_alias not in GLUONTS_SUPPORTED_OFFSETS:
            return "A"
        else:
            return f"{pd_offset.n}{offset_base_alias}"

    def __len__(self):
        return len(self.indptr) - 1  # noqa

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield one GluonTS data entry (a dict keyed by GluonTS field names) per item."""
        for j in range(len(self.indptr) - 1):
            start_idx = self.indptr[j]
            end_idx = self.indptr[j + 1]
            # GluonTS expects item_id to be a string
            ts = {
                FieldName.ITEM_ID: str(self.item_ids[j]),
                FieldName.START: pd.Period(self.start_timestamps.iloc[j], freq=self.freq),
                FieldName.TARGET: self.target_array[start_idx:end_idx],
            }
            if self.feat_static_cat is not None:
                ts[FieldName.FEAT_STATIC_CAT] = self.feat_static_cat[j]
            if self.feat_static_real is not None:
                ts[FieldName.FEAT_STATIC_REAL] = self.feat_static_real[j]
            # Dynamic features are transposed to shape [num_features, num_timesteps]
            if self.past_feat_dynamic_real is not None:
                ts[FieldName.PAST_FEAT_DYNAMIC_REAL] = self.past_feat_dynamic_real[start_idx:end_idx].T
            if self.feat_dynamic_real is not None:
                if self.includes_future:
                    # Each of the j preceding items contributes prediction_length extra (future) rows
                    start_idx = start_idx + j * self.prediction_length
                    end_idx = end_idx + (j + 1) * self.prediction_length
                ts[FieldName.FEAT_DYNAMIC_REAL] = self.feat_dynamic_real[start_idx:end_idx].T
            yield ts
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
class AbstractGluonTSModel(AbstractTimeSeriesModel):
    """Abstract class wrapping GluonTS estimators for use in autogluon.timeseries.

    Parameters
    ----------
    path: str
        directory to store model artifacts.
    freq: str
        string representation (compatible with GluonTS frequency strings) for the data provided.
        For example, "1D" for daily data, "1H" for hourly data, etc.
    prediction_length: int
        Number of time steps ahead (length of the forecast horizon) the model will be optimized
        to predict. At inference time, this will be the number of time steps the model will
        predict.
    name: str
        Name of the model. Also, name of subdirectory inside path where model will be saved.
    eval_metric: str
        objective function the model intends to optimize, will use WQL by default.
    hyperparameters:
        various hyperparameters that will be used by model (can be search spaces instead of
        fixed values). See *Other Parameters* in each inheriting model's documentation for
        possible values.
    """

    # Subdirectory (inside the model directory) where the GluonTS predictor is serialized
    gluonts_model_path = "gluon_ts"
    # default number of samples for prediction
    default_num_samples: int = 250
    supports_known_covariates: bool = False
    supports_past_covariates: bool = False

    def __init__(
        self,
        freq: Optional[str] = None,
        prediction_length: int = 1,
        path: Optional[str] = None,
        name: Optional[str] = None,
        eval_metric: Optional[str] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        **kwargs,  # noqa
    ):
        super().__init__(
            path=path,
            freq=freq,
            prediction_length=prediction_length,
            name=name,
            eval_metric=eval_metric,
            hyperparameters=hyperparameters,
            **kwargs,
        )
        self.gts_predictor: Optional[GluonTSPredictor] = None
        # Lightning callbacks (Timer / EarlyStopping) handed to the trainer; populated during `_fit`
        self.callbacks = []
        # Feature dimensions are only known at training time, see `_deferred_init_params_aux`
        self.num_feat_static_cat = 0
        self.num_feat_static_real = 0
        self.num_feat_dynamic_real = 0
        self.num_past_feat_dynamic_real = 0
        self.feat_static_cat_cardinality: List[int] = []
        self.negative_data = True

    def save(self, path: Optional[str] = None, verbose: bool = True) -> str:
        """Save the model and return the directory it was saved to.

        The GluonTS predictor is excluded from the pickled model and serialized separately into the
        ``gluonts_model_path`` subdirectory. The in-memory predictor is restored even if saving fails.
        """
        # we flush callbacks instance variable if it has been set. it can keep weak references which breaks training
        self.callbacks = []
        # The GluonTS predictor is serialized using custom logic
        predictor = self.gts_predictor
        self.gts_predictor = None
        try:
            path = Path(super().save(path=path, verbose=verbose))
            with disable_root_logger():
                if predictor is not None:
                    predictor_path = path / self.gluonts_model_path
                    predictor_path.mkdir(exist_ok=True)
                    predictor.serialize(predictor_path)
        finally:
            # Without this, a failed save would silently discard the trained predictor from memory
            self.gts_predictor = predictor
        return str(path)

    @classmethod
    def load(cls, path: str, reset_paths: bool = True, verbose: bool = True) -> "AbstractGluonTSModel":
        """Load a model saved with `save`, including its serialized GluonTS predictor (if any)."""
        from gluonts.torch.model.predictor import PyTorchPredictor

        with warning_filter():
            model = load_pkl.load(path=os.path.join(path, cls.model_file_name), verbose=verbose)
            if reset_paths:
                model.set_contexts(path)
            predictor_path = Path(path) / cls.gluonts_model_path
            # `save` only writes the predictor directory if the model held a fitted predictor
            if predictor_path.exists():
                model.gts_predictor = PyTorchPredictor.deserialize(predictor_path)
        return model

    def _get_hpo_backend(self):
        return RAY_BACKEND

    def _deferred_init_params_aux(self, **kwargs) -> None:
        """Update GluonTS specific parameters with information available
        only at training time.

        Recognized keyword arguments: ``dataset`` (training TimeSeriesDataFrame used to infer frequency,
        feature dimensions and whether the target has negative values) and ``callbacks`` (appended to
        ``self.callbacks``).
        """
        if "dataset" in kwargs:
            ds = kwargs.get("dataset")
            self.freq = ds.freq or self.freq
            if not self.freq:
                raise ValueError(
                    "Dataset frequency not provided in the dataset, fit arguments or "
                    "during initialization. Please provide a `freq` string to `fit`."
                )

            model_params = self._get_model_params()
            disable_static_features = model_params.get("disable_static_features", False)
            if not disable_static_features:
                self.num_feat_static_cat = len(self.metadata.static_features_cat)
                self.num_feat_static_real = len(self.metadata.static_features_real)
                if self.num_feat_static_cat > 0:
                    feat_static_cat = ds.static_features[self.metadata.static_features_cat]
                    self.feat_static_cat_cardinality = feat_static_cat.nunique().tolist()
            disable_known_covariates = model_params.get("disable_known_covariates", False)
            if not disable_known_covariates and self.supports_known_covariates:
                self.num_feat_dynamic_real = len(self.metadata.known_covariates_real)
            disable_past_covariates = model_params.get("disable_past_covariates", False)
            if not disable_past_covariates and self.supports_past_covariates:
                self.num_past_feat_dynamic_real = len(self.metadata.past_covariates_real)
            self.negative_data = bool((ds[self.target] < 0).any())

        if "callbacks" in kwargs:
            self.callbacks += kwargs["callbacks"]

    @property
    def default_context_length(self) -> int:
        return min(512, max(10, 2 * self.prediction_length))

    def _get_model_params(self) -> dict:
        """Gets params that are passed to the inner model."""
        init_args = super()._get_model_params().copy()
        init_args.setdefault("batch_size", 64)
        init_args.setdefault("context_length", self.default_context_length)
        init_args.setdefault("predict_batch_size", 500)
        init_args.setdefault("early_stopping_patience", 20)
        init_args.update(
            dict(
                freq=self.freq,
                prediction_length=self.prediction_length,
                quantiles=self.quantile_levels,
                callbacks=self.callbacks,
            )
        )
        # Support MXNet kwarg names for backwards compatibility
        init_args.setdefault("lr", init_args.get("learning_rate", 1e-3))
        init_args.setdefault("max_epochs", init_args.get("epochs", 100))
        return init_args

    def _get_estimator_init_args(self) -> Dict[str, Any]:
        """Get GluonTS specific constructor arguments for estimator objects, an alias to `self._get_model_params`
        for better readability."""
        return self._get_model_params()

    def _get_estimator_class(self) -> Type[GluonTSEstimator]:
        raise NotImplementedError

    def _get_estimator(self) -> GluonTSEstimator:
        """Return the GluonTS Estimator object for the model"""
        # As GluonTSPyTorchLightningEstimator objects do not implement `from_hyperparameters` convenience
        # constructors, we re-implement the logic here.
        # The "epochs" -> "max_epochs" translation is done in `_get_model_params`.
        init_args = self._get_estimator_init_args()

        default_trainer_kwargs = {
            "limit_val_batches": 3,
            "max_epochs": init_args["max_epochs"],
            "callbacks": init_args["callbacks"],
            "enable_progress_bar": False,
            "default_root_dir": self.path,
        }

        if self._is_gpu_available():
            default_trainer_kwargs["accelerator"] = "gpu"
            default_trainer_kwargs["devices"] = 1
        else:
            default_trainer_kwargs["accelerator"] = "cpu"

        # User-provided trainer_kwargs take precedence over the defaults above
        default_trainer_kwargs.update(init_args.pop("trainer_kwargs", {}))
        logger.debug(f"\tTraining on device '{default_trainer_kwargs['accelerator']}'")

        return from_hyperparameters(
            self._get_estimator_class(),
            trainer_kwargs=default_trainer_kwargs,
            **init_args,
        )

    def _is_gpu_available(self) -> bool:
        import torch.cuda

        return torch.cuda.is_available()

    def get_minimum_resources(self, is_gpu_available: bool = False) -> Dict[str, Union[int, float]]:
        minimum_resources = {"num_cpus": 1}
        # if GPU is available, we train with 1 GPU per trial
        if is_gpu_available:
            minimum_resources["num_gpus"] = 1
        return minimum_resources

    def _to_gluonts_dataset(
        self, time_series_df: Optional[TimeSeriesDataFrame], known_covariates: Optional[TimeSeriesDataFrame] = None
    ) -> Optional[GluonTSDataset]:
        """Convert a TimeSeriesDataFrame (plus optional future known covariates) into a GluonTS data set.

        Returns ``None`` if ``time_series_df`` is ``None``.

        Raises
        ------
        ValueError
            If ``known_covariates`` does not contain ``prediction_length`` rows for each item.
        """
        if time_series_df is not None:
            # TODO: Preprocess real-valued features with StdScaler?
            if self.num_feat_static_cat > 0:
                feat_static_cat = time_series_df.static_features[self.metadata.static_features_cat]
            else:
                feat_static_cat = None

            if self.num_feat_static_real > 0:
                feat_static_real = time_series_df.static_features[self.metadata.static_features_real]
                if feat_static_real.isna().values.any():
                    feat_static_real = feat_static_real.fillna(feat_static_real.mean())
            else:
                feat_static_real = None

            if self.num_feat_dynamic_real > 0:
                # Convert TSDF -> DF to avoid overhead / input validation
                feat_dynamic_real = pd.DataFrame(time_series_df[self.metadata.known_covariates_real])
                # Append future values of known covariates
                if known_covariates is not None:
                    # Use exactly the same columns (in the same order) as the past values. pd.concat aligns
                    # columns by name, so any extra columns in known_covariates would otherwise be added as
                    # NaN-filled features and change the feature dimension.
                    future_feat_dynamic_real = pd.DataFrame(known_covariates[self.metadata.known_covariates_real])
                    feat_dynamic_real = pd.concat([feat_dynamic_real, future_feat_dynamic_real], axis=0)
                    expected_length = len(time_series_df) + self.prediction_length * time_series_df.num_items
                    if len(feat_dynamic_real) != expected_length:
                        raise ValueError(
                            f"known_covariates must contain values for the next prediction_length = "
                            f"{self.prediction_length} time steps in each time series."
                        )
            else:
                feat_dynamic_real = None

            if self.num_past_feat_dynamic_real > 0:
                # Convert TSDF -> DF to avoid overhead / input validation
                past_feat_dynamic_real = pd.DataFrame(time_series_df[self.metadata.past_covariates_real])
            else:
                past_feat_dynamic_real = None

            return SimpleGluonTSDataset(
                target_df=time_series_df,
                target_column=self.target,
                feat_static_cat=feat_static_cat,
                feat_static_real=feat_static_real,
                feat_dynamic_real=feat_dynamic_real,
                past_feat_dynamic_real=past_feat_dynamic_real,
                includes_future=known_covariates is not None,
                prediction_length=self.prediction_length,
            )
        else:
            return None

    def _fit(
        self,
        train_data: TimeSeriesDataFrame,
        val_data: Optional[TimeSeriesDataFrame] = None,
        time_limit: Optional[float] = None,
        **kwargs,
    ) -> None:
        # necessary to initialize the loggers
        import lightning.pytorch  # noqa

        verbosity = kwargs.get("verbosity", 2)
        log_level = logging.ERROR if verbosity <= 3 else logging.INFO
        # Iterate over a snapshot: the logger registry may be modified while we iterate over it
        for logger_name in list(logging.root.manager.loggerDict):
            if "lightning" in logger_name:
                pl_logger = logging.getLogger(logger_name)
                pl_logger.setLevel(log_level)
        gts_logger.setLevel(log_level)

        if verbosity > 3:
            logger.warning(
                "GluonTS logging is turned on during training. Note that losses reported by GluonTS "
                "may not correspond to those specified via `eval_metric`."
            )

        self._check_fit_params()
        # update auxiliary parameters
        init_args = self._get_estimator_init_args()
        keep_lightning_logs = init_args.pop("keep_lightning_logs", False)
        callbacks = self._get_callbacks(
            time_limit=time_limit,
            # Early stopping monitors the validation loss, so it is only possible with validation data
            early_stopping_patience=None if val_data is None else init_args["early_stopping_patience"],
        )
        self._deferred_init_params_aux(dataset=train_data, callbacks=callbacks)

        estimator = self._get_estimator()
        with warning_filter(), disable_root_logger(), gluonts.core.settings.let(gluonts.env.env, use_tqdm=False):
            self.gts_predictor = estimator.train(
                self._to_gluonts_dataset(train_data),
                validation_data=self._to_gluonts_dataset(val_data),
                cache_data=True,
            )
            # Increase batch size during prediction to speed up inference
            if init_args["predict_batch_size"] is not None:
                self.gts_predictor.batch_size = init_args["predict_batch_size"]

        lightning_logs_dir = Path(self.path) / "lightning_logs"
        if not keep_lightning_logs and lightning_logs_dir.is_dir():
            logger.debug(f"Removing lightning_logs directory {lightning_logs_dir}")
            shutil.rmtree(lightning_logs_dir)

    def _get_callbacks(
        self,
        time_limit: Optional[float],
        early_stopping_patience: Optional[int] = None,
    ) -> List[Callable]:
        """Retrieve a list of callback objects for the GluonTS trainer"""
        from lightning.pytorch.callbacks import EarlyStopping, Timer

        callbacks = []
        if time_limit is not None:
            callbacks.append(Timer(timedelta(seconds=time_limit)))
        if early_stopping_patience is not None:
            callbacks.append(EarlyStopping(monitor="val_loss", patience=early_stopping_patience))
        return callbacks

    def _predict(
        self,
        data: TimeSeriesDataFrame,
        known_covariates: Optional[TimeSeriesDataFrame] = None,
        **kwargs,
    ) -> TimeSeriesDataFrame:
        if self.gts_predictor is None:
            raise ValueError("Please fit the model before predicting.")

        with warning_filter(), gluonts.core.settings.let(gluonts.env.env, use_tqdm=False):
            predicted_targets = self._predict_gluonts_forecasts(data, known_covariates=known_covariates, **kwargs)
            df = self._gluonts_forecasts_to_data_frame(
                predicted_targets,
                forecast_index=get_forecast_horizon_index_ts_dataframe(data, self.prediction_length),
            )
        return df

    def _predict_gluonts_forecasts(
        self, data: TimeSeriesDataFrame, known_covariates: Optional[TimeSeriesDataFrame] = None, **kwargs
    ) -> List[Forecast]:
        gts_data = self._to_gluonts_dataset(data, known_covariates=known_covariates)

        predictor_kwargs = dict(dataset=gts_data)
        predictor_kwargs["num_samples"] = kwargs.get("num_samples", self.default_num_samples)

        return list(self.gts_predictor.predict(**predictor_kwargs))

    def _stack_quantile_forecasts(self, forecasts: List[QuantileForecast], item_ids: pd.Index) -> pd.DataFrame:
        """Combine quantile forecasts into one frame with columns ["mean", *quantile_levels], ordered by item_ids."""
        # GluonTS always saves item_id as a string
        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
        result_dfs = []
        for item_id in item_ids:
            forecast = item_id_to_forecast[str(item_id)]
            result_dfs.append(pd.DataFrame(forecast.forecast_array.T, columns=forecast.forecast_keys))
        forecast_df = pd.concat(result_dfs)
        if "mean" not in forecast_df.columns:
            # Fall back to the median if the model does not produce a mean forecast
            forecast_df["mean"] = forecast_df["0.5"]
        columns_order = ["mean"] + [str(q) for q in self.quantile_levels]
        return forecast_df[columns_order]

    def _stack_sample_forecasts(self, forecasts: List[SampleForecast], item_ids: pd.Index) -> pd.DataFrame:
        """Compute mean and quantiles from sample forecasts, ordered by item_ids."""
        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
        samples_per_item = []
        for item_id in item_ids:
            forecast = item_id_to_forecast[str(item_id)]
            samples_per_item.append(forecast.samples.T)
        # Rows = forecast time steps of all items, columns = samples
        samples = np.concatenate(samples_per_item, axis=0)
        quantiles = np.quantile(samples, self.quantile_levels, axis=1).T
        mean = samples.mean(axis=1, keepdims=True)
        forecast_array = np.concatenate([mean, quantiles], axis=1)
        return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])

    def _stack_distribution_forecasts(self, forecasts: List[Forecast], item_ids: pd.Index) -> pd.DataFrame:
        """Compute mean and quantiles from distribution forecasts, ordered by item_ids."""
        import torch
        from gluonts.torch.distributions import AffineTransformed
        from torch.distributions import Distribution

        # Sort forecasts in the same order as in the dataset
        item_id_to_forecast = {str(f.item_id): f for f in forecasts}
        forecasts = [item_id_to_forecast[str(item_id)] for item_id in item_ids]

        def stack_distributions(distributions: List[Distribution]) -> Distribution:
            """Stack multiple torch.Distribution objects into a single distribution"""
            # All distributions are expected to share the same class; the first one serves as the template
            template = distributions[0]
            param_names = list(template.arg_constraints.keys())
            params_per_dist = [
                {name: getattr(dist, name) for name in dist.arg_constraints.keys()} for dist in distributions
            ]
            # Make sure that all distributions have same keys
            assert len(set(tuple(p.keys()) for p in params_per_dist)) == 1

            stacked_params = {key: torch.cat([p[key] for p in params_per_dist]) for key in param_names}
            return template.__class__(**stacked_params)

        if not isinstance(forecasts[0].distribution, AffineTransformed):
            raise AssertionError("Expected forecast.distribution to be an instance of AffineTransformed")

        # We stack all forecast distribution into a single Distribution object.
        # This dramatically speeds up the quantiles calculation.
        stacked_base_dist = stack_distributions([f.distribution.base_dist for f in forecasts])

        # loc / scale may be given per item instead of per time step; expand them to match the base distribution
        stacked_loc = torch.cat([f.distribution.loc for f in forecasts])
        if stacked_loc.shape != stacked_base_dist.batch_shape:
            stacked_loc = stacked_loc.repeat_interleave(self.prediction_length)

        stacked_scale = torch.cat([f.distribution.scale for f in forecasts])
        if stacked_scale.shape != stacked_base_dist.batch_shape:
            stacked_scale = stacked_scale.repeat_interleave(self.prediction_length)

        stacked_dist = AffineTransformed(stacked_base_dist, loc=stacked_loc, scale=stacked_scale)

        mean_prediction = stacked_dist.mean.cpu().detach().numpy()
        quantiles = torch.tensor(self.quantile_levels, device=stacked_dist.mean.device).reshape(-1, 1)
        quantile_predictions = stacked_dist.icdf(quantiles).cpu().detach().numpy()
        forecast_array = np.vstack([mean_prediction, quantile_predictions]).T
        return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])

    def _gluonts_forecasts_to_data_frame(
        self,
        forecasts: List[Forecast],
        forecast_index: pd.MultiIndex,
    ) -> TimeSeriesDataFrame:
        """Convert GluonTS forecasts into a TimeSeriesDataFrame indexed by ``forecast_index``."""
        from gluonts.torch.model.forecast import DistributionForecast

        item_ids = forecast_index.unique(level=ITEMID)
        if isinstance(forecasts[0], SampleForecast):
            forecast_df = self._stack_sample_forecasts(forecasts, item_ids)
        elif isinstance(forecasts[0], QuantileForecast):
            forecast_df = self._stack_quantile_forecasts(forecasts, item_ids)
        elif isinstance(forecasts[0], DistributionForecast):
            forecast_df = self._stack_distribution_forecasts(forecasts, item_ids)
        else:
            raise ValueError(f"Unrecognized forecast type {type(forecasts[0])}")

        forecast_df.index = forecast_index
        return TimeSeriesDataFrame(forecast_df)
|
|
File without changes
|