autogluon.timeseries 1.2.1b20250224__py3-none-any.whl → 1.4.1b20251215__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +106 -0
- autogluon/timeseries/dataset/ts_dataframe.py +256 -141
- autogluon/timeseries/learner.py +86 -52
- autogluon/timeseries/metrics/__init__.py +42 -8
- autogluon/timeseries/metrics/abstract.py +89 -19
- autogluon/timeseries/metrics/point.py +142 -53
- autogluon/timeseries/metrics/quantile.py +46 -21
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +8 -2
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +361 -592
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +282 -194
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +25 -18
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +219 -138
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +81 -50
- autogluon/timeseries/models/ensemble/__init__.py +37 -2
- autogluon/timeseries/models/ensemble/abstract.py +107 -0
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +91 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +62 -0
- autogluon/timeseries/models/gluonts/__init__.py +1 -1
- autogluon/timeseries/models/gluonts/{abstract_gluonts.py → abstract.py} +148 -208
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +38 -22
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +71 -74
- autogluon/timeseries/models/local/naive.py +13 -9
- autogluon/timeseries/models/local/npts.py +9 -2
- autogluon/timeseries/models/local/statsforecast.py +52 -36
- autogluon/timeseries/models/multi_window/multi_window_model.py +65 -45
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
- autogluon/timeseries/models/toto/model.py +249 -0
- autogluon/timeseries/predictor.py +685 -297
- autogluon/timeseries/regressor.py +94 -44
- autogluon/timeseries/splitter.py +8 -32
- autogluon/timeseries/trainer/__init__.py +3 -0
- autogluon/timeseries/trainer/ensemble_composer.py +444 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/{trainer.py → trainer/trainer.py} +387 -390
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -13
- autogluon/timeseries/transforms/covariate_scaler.py +34 -40
- autogluon/timeseries/transforms/target_scaler.py +37 -20
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +3 -5
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/datetime/time_features.py +2 -2
- autogluon/timeseries/utils/features.py +70 -47
- autogluon/timeseries/utils/forecast.py +19 -14
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +4 -2
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.4.1b20251215-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/METADATA +49 -36
- autogluon_timeseries-1.4.1b20251215.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -79
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -11
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -585
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -518
- autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -78
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -360
- autogluon.timeseries-1.2.1b20250224-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.2.1b20250224.dist-info/RECORD +0 -68
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/zip-safe +0 -0
|
@@ -3,7 +3,7 @@ import os
|
|
|
3
3
|
import shutil
|
|
4
4
|
from datetime import timedelta
|
|
5
5
|
from pathlib import Path
|
|
6
|
-
from typing import Any, Callable,
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Callable, Type, cast, overload
|
|
7
7
|
|
|
8
8
|
import gluonts
|
|
9
9
|
import gluonts.core.settings
|
|
@@ -11,7 +11,7 @@ import numpy as np
|
|
|
11
11
|
import pandas as pd
|
|
12
12
|
from gluonts.core.component import from_hyperparameters
|
|
13
13
|
from gluonts.dataset.common import Dataset as GluonTSDataset
|
|
14
|
-
from gluonts.
|
|
14
|
+
from gluonts.env import env as gluonts_env
|
|
15
15
|
from gluonts.model.estimator import Estimator as GluonTSEstimator
|
|
16
16
|
from gluonts.model.forecast import Forecast, QuantileForecast, SampleForecast
|
|
17
17
|
from gluonts.model.predictor import Predictor as GluonTSPredictor
|
|
@@ -21,11 +21,15 @@ from autogluon.core.hpo.constants import RAY_BACKEND
|
|
|
21
21
|
from autogluon.tabular.models.tabular_nn.utils.categorical_encoders import (
|
|
22
22
|
OneHotMergeRaresHandleUnknownEncoder as OneHotEncoder,
|
|
23
23
|
)
|
|
24
|
-
from autogluon.timeseries.dataset
|
|
24
|
+
from autogluon.timeseries.dataset import TimeSeriesDataFrame
|
|
25
25
|
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
|
|
26
|
-
from autogluon.timeseries.utils.datetime import norm_freq_str
|
|
27
26
|
from autogluon.timeseries.utils.warning_filters import disable_root_logger, warning_filter
|
|
28
27
|
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from gluonts.torch.model.forecast import DistributionForecast
|
|
30
|
+
|
|
31
|
+
from .dataset import SimpleGluonTSDataset
|
|
32
|
+
|
|
29
33
|
# NOTE: We avoid imports for torch and lightning.pytorch at the top level and hide them inside class methods.
|
|
30
34
|
# This is done to skip these imports during multiprocessing (which may cause bugs)
|
|
31
35
|
|
|
@@ -33,124 +37,25 @@ logger = logging.getLogger(__name__)
|
|
|
33
37
|
gts_logger = logging.getLogger(gluonts.__name__)
|
|
34
38
|
|
|
35
39
|
|
|
36
|
-
class SimpleGluonTSDataset(GluonTSDataset):
|
|
37
|
-
"""Wrapper for TimeSeriesDataFrame that is compatible with the GluonTS Dataset API."""
|
|
38
|
-
|
|
39
|
-
def __init__(
|
|
40
|
-
self,
|
|
41
|
-
target_df: TimeSeriesDataFrame,
|
|
42
|
-
freq: str,
|
|
43
|
-
target_column: str = "target",
|
|
44
|
-
feat_static_cat: Optional[np.ndarray] = None,
|
|
45
|
-
feat_static_real: Optional[np.ndarray] = None,
|
|
46
|
-
feat_dynamic_cat: Optional[np.ndarray] = None,
|
|
47
|
-
feat_dynamic_real: Optional[np.ndarray] = None,
|
|
48
|
-
past_feat_dynamic_cat: Optional[np.ndarray] = None,
|
|
49
|
-
past_feat_dynamic_real: Optional[np.ndarray] = None,
|
|
50
|
-
includes_future: bool = False,
|
|
51
|
-
prediction_length: int = None,
|
|
52
|
-
):
|
|
53
|
-
assert target_df is not None
|
|
54
|
-
# Convert TimeSeriesDataFrame to pd.Series for faster processing
|
|
55
|
-
self.target_array = target_df[target_column].to_numpy(np.float32)
|
|
56
|
-
self.feat_static_cat = self._astype(feat_static_cat, dtype=np.int64)
|
|
57
|
-
self.feat_static_real = self._astype(feat_static_real, dtype=np.float32)
|
|
58
|
-
self.feat_dynamic_cat = self._astype(feat_dynamic_cat, dtype=np.int64)
|
|
59
|
-
self.feat_dynamic_real = self._astype(feat_dynamic_real, dtype=np.float32)
|
|
60
|
-
self.past_feat_dynamic_cat = self._astype(past_feat_dynamic_cat, dtype=np.int64)
|
|
61
|
-
self.past_feat_dynamic_real = self._astype(past_feat_dynamic_real, dtype=np.float32)
|
|
62
|
-
self.freq = self._get_freq_for_period(freq)
|
|
63
|
-
|
|
64
|
-
# Necessary to compute indptr for known_covariates at prediction time
|
|
65
|
-
self.includes_future = includes_future
|
|
66
|
-
self.prediction_length = prediction_length
|
|
67
|
-
|
|
68
|
-
# Replace inefficient groupby ITEMID with indptr that stores start:end of each time series
|
|
69
|
-
item_id_index = target_df.index.get_level_values(ITEMID)
|
|
70
|
-
indices_sizes = item_id_index.value_counts(sort=False)
|
|
71
|
-
self.item_ids = indices_sizes.index # shape [num_items]
|
|
72
|
-
cum_sizes = indices_sizes.to_numpy().cumsum()
|
|
73
|
-
self.indptr = np.append(0, cum_sizes).astype(np.int32)
|
|
74
|
-
self.start_timestamps = target_df.reset_index(TIMESTAMP).groupby(level=ITEMID, sort=False).first()[TIMESTAMP]
|
|
75
|
-
assert len(self.item_ids) == len(self.start_timestamps)
|
|
76
|
-
|
|
77
|
-
@staticmethod
|
|
78
|
-
def _astype(array: Optional[np.ndarray], dtype: np.dtype) -> Optional[np.ndarray]:
|
|
79
|
-
if array is None:
|
|
80
|
-
return None
|
|
81
|
-
else:
|
|
82
|
-
return array.astype(dtype)
|
|
83
|
-
|
|
84
|
-
@staticmethod
|
|
85
|
-
def _get_freq_for_period(freq: str) -> str:
|
|
86
|
-
"""Convert freq to format compatible with pd.Period.
|
|
87
|
-
|
|
88
|
-
For example, ME freq must be converted to M when creating a pd.Period.
|
|
89
|
-
"""
|
|
90
|
-
offset = pd.tseries.frequencies.to_offset(freq)
|
|
91
|
-
freq_name = norm_freq_str(offset)
|
|
92
|
-
if freq_name == "SME":
|
|
93
|
-
# Replace unsupported frequency "SME" with "2W"
|
|
94
|
-
return "2W"
|
|
95
|
-
elif freq_name == "bh":
|
|
96
|
-
# Replace unsupported frequency "bh" with dummy value "Y"
|
|
97
|
-
return "Y"
|
|
98
|
-
else:
|
|
99
|
-
freq_name_for_period = {"YE": "Y", "QE": "Q", "ME": "M"}.get(freq_name, freq_name)
|
|
100
|
-
return f"{offset.n}{freq_name_for_period}"
|
|
101
|
-
|
|
102
|
-
def __len__(self):
|
|
103
|
-
return len(self.indptr) - 1 # noqa
|
|
104
|
-
|
|
105
|
-
def __iter__(self) -> Iterator[Dict[str, Any]]:
|
|
106
|
-
for j in range(len(self.indptr) - 1):
|
|
107
|
-
start_idx = self.indptr[j]
|
|
108
|
-
end_idx = self.indptr[j + 1]
|
|
109
|
-
# GluonTS expects item_id to be a string
|
|
110
|
-
ts = {
|
|
111
|
-
FieldName.ITEM_ID: str(self.item_ids[j]),
|
|
112
|
-
FieldName.START: pd.Period(self.start_timestamps.iloc[j], freq=self.freq),
|
|
113
|
-
FieldName.TARGET: self.target_array[start_idx:end_idx],
|
|
114
|
-
}
|
|
115
|
-
if self.feat_static_cat is not None:
|
|
116
|
-
ts[FieldName.FEAT_STATIC_CAT] = self.feat_static_cat[j]
|
|
117
|
-
if self.feat_static_real is not None:
|
|
118
|
-
ts[FieldName.FEAT_STATIC_REAL] = self.feat_static_real[j]
|
|
119
|
-
if self.past_feat_dynamic_cat is not None:
|
|
120
|
-
ts[FieldName.PAST_FEAT_DYNAMIC_CAT] = self.past_feat_dynamic_cat[start_idx:end_idx].T
|
|
121
|
-
if self.past_feat_dynamic_real is not None:
|
|
122
|
-
ts[FieldName.PAST_FEAT_DYNAMIC_REAL] = self.past_feat_dynamic_real[start_idx:end_idx].T
|
|
123
|
-
|
|
124
|
-
# Dynamic features that may extend into the future
|
|
125
|
-
if self.includes_future:
|
|
126
|
-
start_idx = start_idx + j * self.prediction_length
|
|
127
|
-
end_idx = end_idx + (j + 1) * self.prediction_length
|
|
128
|
-
if self.feat_dynamic_cat is not None:
|
|
129
|
-
ts[FieldName.FEAT_DYNAMIC_CAT] = self.feat_dynamic_cat[start_idx:end_idx].T
|
|
130
|
-
if self.feat_dynamic_real is not None:
|
|
131
|
-
ts[FieldName.FEAT_DYNAMIC_REAL] = self.feat_dynamic_real[start_idx:end_idx].T
|
|
132
|
-
yield ts
|
|
133
|
-
|
|
134
|
-
|
|
135
40
|
class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
136
41
|
"""Abstract class wrapping GluonTS estimators for use in autogluon.timeseries.
|
|
137
42
|
|
|
138
43
|
Parameters
|
|
139
44
|
----------
|
|
140
|
-
path
|
|
45
|
+
path
|
|
141
46
|
directory to store model artifacts.
|
|
142
|
-
freq
|
|
47
|
+
freq
|
|
143
48
|
string representation (compatible with GluonTS frequency strings) for the data provided.
|
|
144
49
|
For example, "1D" for daily data, "1H" for hourly data, etc.
|
|
145
|
-
prediction_length
|
|
50
|
+
prediction_length
|
|
146
51
|
Number of time steps ahead (length of the forecast horizon) the model will be optimized
|
|
147
52
|
to predict. At inference time, this will be the number of time steps the model will
|
|
148
53
|
predict.
|
|
149
|
-
name
|
|
54
|
+
name
|
|
150
55
|
Name of the model. Also, name of subdirectory inside path where model will be saved.
|
|
151
|
-
eval_metric
|
|
56
|
+
eval_metric
|
|
152
57
|
objective function the model intends to optimize, will use WQL by default.
|
|
153
|
-
hyperparameters
|
|
58
|
+
hyperparameters
|
|
154
59
|
various hyperparameters that will be used by model (can be search spaces instead of
|
|
155
60
|
fixed values). See *Other Parameters* in each inheriting model's documentation for
|
|
156
61
|
possible values.
|
|
@@ -167,12 +72,12 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
167
72
|
|
|
168
73
|
def __init__(
|
|
169
74
|
self,
|
|
170
|
-
freq:
|
|
75
|
+
freq: str | None = None,
|
|
171
76
|
prediction_length: int = 1,
|
|
172
|
-
path:
|
|
173
|
-
name:
|
|
174
|
-
eval_metric: str = None,
|
|
175
|
-
hyperparameters:
|
|
77
|
+
path: str | None = None,
|
|
78
|
+
name: str | None = None,
|
|
79
|
+
eval_metric: str | None = None,
|
|
80
|
+
hyperparameters: dict[str, Any] | None = None,
|
|
176
81
|
**kwargs, # noqa
|
|
177
82
|
):
|
|
178
83
|
super().__init__(
|
|
@@ -184,9 +89,9 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
184
89
|
hyperparameters=hyperparameters,
|
|
185
90
|
**kwargs,
|
|
186
91
|
)
|
|
187
|
-
self.gts_predictor:
|
|
188
|
-
self._ohe_generator_known:
|
|
189
|
-
self._ohe_generator_past:
|
|
92
|
+
self.gts_predictor: GluonTSPredictor | None = None
|
|
93
|
+
self._ohe_generator_known: OneHotEncoder | None = None
|
|
94
|
+
self._ohe_generator_past: OneHotEncoder | None = None
|
|
190
95
|
self.callbacks = []
|
|
191
96
|
# Following attributes may be overridden during fit() based on train_data & model parameters
|
|
192
97
|
self.num_feat_static_cat = 0
|
|
@@ -195,30 +100,32 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
195
100
|
self.num_feat_dynamic_real = 0
|
|
196
101
|
self.num_past_feat_dynamic_cat = 0
|
|
197
102
|
self.num_past_feat_dynamic_real = 0
|
|
198
|
-
self.feat_static_cat_cardinality:
|
|
199
|
-
self.feat_dynamic_cat_cardinality:
|
|
200
|
-
self.past_feat_dynamic_cat_cardinality:
|
|
103
|
+
self.feat_static_cat_cardinality: list[int] = []
|
|
104
|
+
self.feat_dynamic_cat_cardinality: list[int] = []
|
|
105
|
+
self.past_feat_dynamic_cat_cardinality: list[int] = []
|
|
201
106
|
self.negative_data = True
|
|
202
107
|
|
|
203
|
-
def save(self, path: str = None, verbose: bool = True) -> str:
|
|
108
|
+
def save(self, path: str | None = None, verbose: bool = True) -> str:
|
|
204
109
|
# we flush callbacks instance variable if it has been set. it can keep weak references which breaks training
|
|
205
110
|
self.callbacks = []
|
|
206
111
|
# The GluonTS predictor is serialized using custom logic
|
|
207
112
|
predictor = self.gts_predictor
|
|
208
113
|
self.gts_predictor = None
|
|
209
|
-
|
|
114
|
+
saved_path = Path(super().save(path=path, verbose=verbose))
|
|
210
115
|
|
|
211
116
|
with disable_root_logger():
|
|
212
117
|
if predictor:
|
|
213
|
-
Path.mkdir(
|
|
214
|
-
predictor.serialize(
|
|
118
|
+
Path.mkdir(saved_path / self.gluonts_model_path, exist_ok=True)
|
|
119
|
+
predictor.serialize(saved_path / self.gluonts_model_path)
|
|
215
120
|
|
|
216
121
|
self.gts_predictor = predictor
|
|
217
122
|
|
|
218
|
-
return str(
|
|
123
|
+
return str(saved_path)
|
|
219
124
|
|
|
220
125
|
@classmethod
|
|
221
|
-
def load(
|
|
126
|
+
def load(
|
|
127
|
+
cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
|
|
128
|
+
) -> "AbstractGluonTSModel":
|
|
222
129
|
from gluonts.torch.model.predictor import PyTorchPredictor
|
|
223
130
|
|
|
224
131
|
with warning_filter():
|
|
@@ -235,31 +142,33 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
235
142
|
def _get_hpo_backend(self):
|
|
236
143
|
return RAY_BACKEND
|
|
237
144
|
|
|
238
|
-
def
|
|
239
|
-
"""Update GluonTS specific
|
|
240
|
-
model_params = self.
|
|
145
|
+
def _deferred_init_hyperparameters(self, dataset: TimeSeriesDataFrame) -> None:
|
|
146
|
+
"""Update GluonTS specific hyperparameters with information available only at training time."""
|
|
147
|
+
model_params = self.get_hyperparameters()
|
|
241
148
|
disable_static_features = model_params.get("disable_static_features", False)
|
|
242
149
|
if not disable_static_features:
|
|
243
|
-
self.num_feat_static_cat = len(self.
|
|
244
|
-
self.num_feat_static_real = len(self.
|
|
150
|
+
self.num_feat_static_cat = len(self.covariate_metadata.static_features_cat)
|
|
151
|
+
self.num_feat_static_real = len(self.covariate_metadata.static_features_real)
|
|
245
152
|
if self.num_feat_static_cat > 0:
|
|
246
|
-
|
|
247
|
-
|
|
153
|
+
assert dataset.static_features is not None, (
|
|
154
|
+
"Static features must be provided if num_feat_static_cat > 0"
|
|
155
|
+
)
|
|
156
|
+
self.feat_static_cat_cardinality = list(self.covariate_metadata.static_cat_cardinality.values())
|
|
248
157
|
|
|
249
158
|
disable_known_covariates = model_params.get("disable_known_covariates", False)
|
|
250
159
|
if not disable_known_covariates and self.supports_known_covariates:
|
|
251
|
-
self.num_feat_dynamic_cat = len(self.
|
|
252
|
-
self.num_feat_dynamic_real = len(self.
|
|
160
|
+
self.num_feat_dynamic_cat = len(self.covariate_metadata.known_covariates_cat)
|
|
161
|
+
self.num_feat_dynamic_real = len(self.covariate_metadata.known_covariates_real)
|
|
253
162
|
if self.num_feat_dynamic_cat > 0:
|
|
254
|
-
feat_dynamic_cat = dataset[self.metadata.known_covariates_cat]
|
|
255
163
|
if self.supports_cat_covariates:
|
|
256
|
-
self.feat_dynamic_cat_cardinality =
|
|
164
|
+
self.feat_dynamic_cat_cardinality = list(self.covariate_metadata.known_cat_cardinality.values())
|
|
257
165
|
else:
|
|
166
|
+
feat_dynamic_cat = dataset[self.covariate_metadata.known_covariates_cat]
|
|
258
167
|
# If model doesn't support categorical covariates, convert them to real via one hot encoding
|
|
259
168
|
self._ohe_generator_known = OneHotEncoder(
|
|
260
169
|
max_levels=model_params.get("max_cat_cardinality", 100),
|
|
261
170
|
sparse=False,
|
|
262
|
-
dtype="float32",
|
|
171
|
+
dtype="float32", # type: ignore
|
|
263
172
|
)
|
|
264
173
|
feat_dynamic_cat_ohe = self._ohe_generator_known.fit_transform(pd.DataFrame(feat_dynamic_cat))
|
|
265
174
|
self.num_feat_dynamic_cat = 0
|
|
@@ -267,18 +176,20 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
267
176
|
|
|
268
177
|
disable_past_covariates = model_params.get("disable_past_covariates", False)
|
|
269
178
|
if not disable_past_covariates and self.supports_past_covariates:
|
|
270
|
-
self.num_past_feat_dynamic_cat = len(self.
|
|
271
|
-
self.num_past_feat_dynamic_real = len(self.
|
|
179
|
+
self.num_past_feat_dynamic_cat = len(self.covariate_metadata.past_covariates_cat)
|
|
180
|
+
self.num_past_feat_dynamic_real = len(self.covariate_metadata.past_covariates_real)
|
|
272
181
|
if self.num_past_feat_dynamic_cat > 0:
|
|
273
|
-
past_feat_dynamic_cat = dataset[self.metadata.past_covariates_cat]
|
|
274
182
|
if self.supports_cat_covariates:
|
|
275
|
-
self.past_feat_dynamic_cat_cardinality =
|
|
183
|
+
self.past_feat_dynamic_cat_cardinality = list(
|
|
184
|
+
self.covariate_metadata.past_cat_cardinality.values()
|
|
185
|
+
)
|
|
276
186
|
else:
|
|
187
|
+
past_feat_dynamic_cat = dataset[self.covariate_metadata.past_covariates_cat]
|
|
277
188
|
# If model doesn't support categorical covariates, convert them to real via one hot encoding
|
|
278
189
|
self._ohe_generator_past = OneHotEncoder(
|
|
279
190
|
max_levels=model_params.get("max_cat_cardinality", 100),
|
|
280
191
|
sparse=False,
|
|
281
|
-
dtype="float32",
|
|
192
|
+
dtype="float32", # type: ignore
|
|
282
193
|
)
|
|
283
194
|
past_feat_dynamic_cat_ohe = self._ohe_generator_past.fit_transform(
|
|
284
195
|
pd.DataFrame(past_feat_dynamic_cat)
|
|
@@ -288,7 +199,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
288
199
|
|
|
289
200
|
self.negative_data = (dataset[self.target] < 0).any()
|
|
290
201
|
|
|
291
|
-
def
|
|
202
|
+
def _get_default_hyperparameters(self):
|
|
292
203
|
"""Gets default parameters for GluonTS estimator initialization that are available after
|
|
293
204
|
AbstractTimeSeriesModel initialization (i.e., before deferred initialization). Models may
|
|
294
205
|
override this method to update default parameters.
|
|
@@ -306,7 +217,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
306
217
|
"covariate_scaler": "global",
|
|
307
218
|
}
|
|
308
219
|
|
|
309
|
-
def
|
|
220
|
+
def get_hyperparameters(self) -> dict:
|
|
310
221
|
"""Gets params that are passed to the inner model."""
|
|
311
222
|
# for backward compatibility with the old GluonTS MXNet API
|
|
312
223
|
parameter_name_aliases = {
|
|
@@ -314,7 +225,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
314
225
|
"learning_rate": "lr",
|
|
315
226
|
}
|
|
316
227
|
|
|
317
|
-
init_args = super().
|
|
228
|
+
init_args = super().get_hyperparameters()
|
|
318
229
|
for alias, actual in parameter_name_aliases.items():
|
|
319
230
|
if alias in init_args:
|
|
320
231
|
if actual in init_args:
|
|
@@ -322,12 +233,12 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
322
233
|
else:
|
|
323
234
|
init_args[actual] = init_args.pop(alias)
|
|
324
235
|
|
|
325
|
-
return self.
|
|
236
|
+
return self._get_default_hyperparameters() | init_args
|
|
326
237
|
|
|
327
|
-
def _get_estimator_init_args(self) ->
|
|
328
|
-
"""Get GluonTS specific constructor arguments for estimator objects, an alias to `self.
|
|
238
|
+
def _get_estimator_init_args(self) -> dict[str, Any]:
|
|
239
|
+
"""Get GluonTS specific constructor arguments for estimator objects, an alias to `self.get_hyperparameters`
|
|
329
240
|
for better readability."""
|
|
330
|
-
return self.
|
|
241
|
+
return self.get_hyperparameters()
|
|
331
242
|
|
|
332
243
|
def _get_estimator_class(self) -> Type[GluonTSEstimator]:
|
|
333
244
|
raise NotImplementedError
|
|
@@ -367,25 +278,39 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
367
278
|
|
|
368
279
|
return torch.cuda.is_available()
|
|
369
280
|
|
|
370
|
-
def get_minimum_resources(self, is_gpu_available: bool = False) ->
|
|
371
|
-
minimum_resources = {"num_cpus": 1}
|
|
281
|
+
def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
|
|
282
|
+
minimum_resources: dict[str, int | float] = {"num_cpus": 1}
|
|
372
283
|
# if GPU is available, we train with 1 GPU per trial
|
|
373
284
|
if is_gpu_available:
|
|
374
285
|
minimum_resources["num_gpus"] = 1
|
|
375
286
|
return minimum_resources
|
|
376
287
|
|
|
288
|
+
@overload
|
|
289
|
+
def _to_gluonts_dataset(self, time_series_df: None, known_covariates=None) -> None: ...
|
|
290
|
+
@overload
|
|
291
|
+
def _to_gluonts_dataset(self, time_series_df: TimeSeriesDataFrame, known_covariates=None) -> GluonTSDataset: ...
|
|
377
292
|
def _to_gluonts_dataset(
|
|
378
|
-
self, time_series_df:
|
|
379
|
-
) ->
|
|
293
|
+
self, time_series_df: TimeSeriesDataFrame | None, known_covariates: TimeSeriesDataFrame | None = None
|
|
294
|
+
) -> GluonTSDataset | None:
|
|
380
295
|
if time_series_df is not None:
|
|
381
296
|
# TODO: Preprocess real-valued features with StdScaler?
|
|
382
297
|
if self.num_feat_static_cat > 0:
|
|
383
|
-
|
|
298
|
+
assert time_series_df.static_features is not None, (
|
|
299
|
+
"Static features must be provided if num_feat_static_cat > 0"
|
|
300
|
+
)
|
|
301
|
+
feat_static_cat = time_series_df.static_features[
|
|
302
|
+
self.covariate_metadata.static_features_cat
|
|
303
|
+
].to_numpy()
|
|
384
304
|
else:
|
|
385
305
|
feat_static_cat = None
|
|
386
306
|
|
|
387
307
|
if self.num_feat_static_real > 0:
|
|
388
|
-
|
|
308
|
+
assert time_series_df.static_features is not None, (
|
|
309
|
+
"Static features must be provided if num_feat_static_real > 0"
|
|
310
|
+
)
|
|
311
|
+
feat_static_real = time_series_df.static_features[
|
|
312
|
+
self.covariate_metadata.static_features_real
|
|
313
|
+
].to_numpy()
|
|
389
314
|
else:
|
|
390
315
|
feat_static_real = None
|
|
391
316
|
|
|
@@ -393,31 +318,33 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
393
318
|
# Convert TSDF -> DF to avoid overhead / input validation
|
|
394
319
|
df = pd.DataFrame(time_series_df)
|
|
395
320
|
if known_covariates is not None:
|
|
396
|
-
known_covariates = pd.DataFrame(known_covariates)
|
|
321
|
+
known_covariates = pd.DataFrame(known_covariates) # type: ignore
|
|
397
322
|
if self.num_feat_dynamic_cat > 0:
|
|
398
|
-
feat_dynamic_cat = df[self.
|
|
323
|
+
feat_dynamic_cat = df[self.covariate_metadata.known_covariates_cat].to_numpy()
|
|
399
324
|
if known_covariates is not None:
|
|
400
325
|
feat_dynamic_cat = np.concatenate(
|
|
401
|
-
[feat_dynamic_cat, known_covariates[self.
|
|
326
|
+
[feat_dynamic_cat, known_covariates[self.covariate_metadata.known_covariates_cat].to_numpy()]
|
|
402
327
|
)
|
|
403
328
|
assert len(feat_dynamic_cat) == expected_known_covariates_len
|
|
404
329
|
else:
|
|
405
330
|
feat_dynamic_cat = None
|
|
406
331
|
|
|
407
332
|
if self.num_feat_dynamic_real > 0:
|
|
408
|
-
feat_dynamic_real = df[self.
|
|
333
|
+
feat_dynamic_real = df[self.covariate_metadata.known_covariates_real].to_numpy()
|
|
409
334
|
# Append future values of known covariates
|
|
410
335
|
if known_covariates is not None:
|
|
411
336
|
feat_dynamic_real = np.concatenate(
|
|
412
|
-
[feat_dynamic_real, known_covariates[self.
|
|
337
|
+
[feat_dynamic_real, known_covariates[self.covariate_metadata.known_covariates_real].to_numpy()]
|
|
413
338
|
)
|
|
414
339
|
assert len(feat_dynamic_real) == expected_known_covariates_len
|
|
415
340
|
# Categorical covariates are one-hot-encoded as real
|
|
416
341
|
if self._ohe_generator_known is not None:
|
|
417
|
-
feat_dynamic_cat_ohe = self._ohe_generator_known.transform(
|
|
342
|
+
feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform(
|
|
343
|
+
df[self.covariate_metadata.known_covariates_cat]
|
|
344
|
+
) # type: ignore
|
|
418
345
|
if known_covariates is not None:
|
|
419
|
-
future_dynamic_cat_ohe = self._ohe_generator_known.transform(
|
|
420
|
-
known_covariates[self.
|
|
346
|
+
future_dynamic_cat_ohe: np.ndarray = self._ohe_generator_known.transform( # type: ignore
|
|
347
|
+
known_covariates[self.covariate_metadata.known_covariates_cat]
|
|
421
348
|
)
|
|
422
349
|
feat_dynamic_cat_ohe = np.concatenate([feat_dynamic_cat_ohe, future_dynamic_cat_ohe])
|
|
423
350
|
assert len(feat_dynamic_cat_ohe) == expected_known_covariates_len
|
|
@@ -426,15 +353,15 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
426
353
|
feat_dynamic_real = None
|
|
427
354
|
|
|
428
355
|
if self.num_past_feat_dynamic_cat > 0:
|
|
429
|
-
past_feat_dynamic_cat = df[self.
|
|
356
|
+
past_feat_dynamic_cat = df[self.covariate_metadata.past_covariates_cat].to_numpy()
|
|
430
357
|
else:
|
|
431
358
|
past_feat_dynamic_cat = None
|
|
432
359
|
|
|
433
360
|
if self.num_past_feat_dynamic_real > 0:
|
|
434
|
-
past_feat_dynamic_real = df[self.
|
|
361
|
+
past_feat_dynamic_real = df[self.covariate_metadata.past_covariates_real].to_numpy()
|
|
435
362
|
if self._ohe_generator_past is not None:
|
|
436
|
-
past_feat_dynamic_cat_ohe = self._ohe_generator_past.transform(
|
|
437
|
-
df[self.
|
|
363
|
+
past_feat_dynamic_cat_ohe: np.ndarray = self._ohe_generator_past.transform( # type: ignore
|
|
364
|
+
df[self.covariate_metadata.past_covariates_cat]
|
|
438
365
|
)
|
|
439
366
|
past_feat_dynamic_real = np.concatenate(
|
|
440
367
|
[past_feat_dynamic_real, past_feat_dynamic_cat_ohe], axis=1
|
|
@@ -442,8 +369,9 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
442
369
|
else:
|
|
443
370
|
past_feat_dynamic_real = None
|
|
444
371
|
|
|
372
|
+
assert self.freq is not None
|
|
445
373
|
return SimpleGluonTSDataset(
|
|
446
|
-
target_df=time_series_df[[self.target]],
|
|
374
|
+
target_df=time_series_df[[self.target]], # type: ignore
|
|
447
375
|
freq=self.freq,
|
|
448
376
|
target_column=self.target,
|
|
449
377
|
feat_static_cat=feat_static_cat,
|
|
@@ -461,14 +389,16 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
461
389
|
def _fit(
|
|
462
390
|
self,
|
|
463
391
|
train_data: TimeSeriesDataFrame,
|
|
464
|
-
val_data:
|
|
465
|
-
time_limit:
|
|
392
|
+
val_data: TimeSeriesDataFrame | None = None,
|
|
393
|
+
time_limit: float | None = None,
|
|
394
|
+
num_cpus: int | None = None,
|
|
395
|
+
num_gpus: int | None = None,
|
|
396
|
+
verbosity: int = 2,
|
|
466
397
|
**kwargs,
|
|
467
398
|
) -> None:
|
|
468
399
|
# necessary to initialize the loggers
|
|
469
400
|
import lightning.pytorch # noqa
|
|
470
401
|
|
|
471
|
-
verbosity = kwargs.get("verbosity", 2)
|
|
472
402
|
for logger_name in logging.root.manager.loggerDict:
|
|
473
403
|
if "lightning" in logger_name:
|
|
474
404
|
pl_logger = logging.getLogger(logger_name)
|
|
@@ -489,18 +419,18 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
489
419
|
time_limit=time_limit,
|
|
490
420
|
early_stopping_patience=None if val_data is None else init_args["early_stopping_patience"],
|
|
491
421
|
)
|
|
492
|
-
self.
|
|
422
|
+
self._deferred_init_hyperparameters(train_data)
|
|
493
423
|
|
|
494
424
|
estimator = self._get_estimator()
|
|
495
|
-
with warning_filter(), disable_root_logger(), gluonts.core.settings.let(
|
|
425
|
+
with warning_filter(), disable_root_logger(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
|
|
496
426
|
self.gts_predictor = estimator.train(
|
|
497
427
|
self._to_gluonts_dataset(train_data),
|
|
498
428
|
validation_data=self._to_gluonts_dataset(val_data),
|
|
499
|
-
cache_data=True,
|
|
429
|
+
cache_data=True, # type: ignore
|
|
500
430
|
)
|
|
501
431
|
# Increase batch size during prediction to speed up inference
|
|
502
432
|
if init_args["predict_batch_size"] is not None:
|
|
503
|
-
self.gts_predictor.batch_size = init_args["predict_batch_size"]
|
|
433
|
+
self.gts_predictor.batch_size = init_args["predict_batch_size"] # type: ignore
|
|
504
434
|
|
|
505
435
|
lightning_logs_dir = Path(self.path) / "lightning_logs"
|
|
506
436
|
if not keep_lightning_logs and lightning_logs_dir.exists() and lightning_logs_dir.is_dir():
|
|
@@ -509,9 +439,9 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
509
439
|
|
|
510
440
|
def _get_callbacks(
|
|
511
441
|
self,
|
|
512
|
-
time_limit:
|
|
513
|
-
early_stopping_patience:
|
|
514
|
-
) ->
|
|
442
|
+
time_limit: float | None,
|
|
443
|
+
early_stopping_patience: int | None = None,
|
|
444
|
+
) -> list[Callable]:
|
|
515
445
|
"""Retrieve a list of callback objects for the GluonTS trainer"""
|
|
516
446
|
from lightning.pytorch.callbacks import EarlyStopping, Timer
|
|
517
447
|
|
|
@@ -525,14 +455,14 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
525
455
|
def _predict(
|
|
526
456
|
self,
|
|
527
457
|
data: TimeSeriesDataFrame,
|
|
528
|
-
known_covariates:
|
|
458
|
+
known_covariates: TimeSeriesDataFrame | None = None,
|
|
529
459
|
**kwargs,
|
|
530
460
|
) -> TimeSeriesDataFrame:
|
|
531
461
|
if self.gts_predictor is None:
|
|
532
462
|
raise ValueError("Please fit the model before predicting.")
|
|
533
463
|
|
|
534
|
-
with warning_filter(), gluonts.core.settings.let(
|
|
535
|
-
predicted_targets = self._predict_gluonts_forecasts(data, known_covariates=known_covariates
|
|
464
|
+
with warning_filter(), gluonts.core.settings.let(gluonts_env, use_tqdm=False):
|
|
465
|
+
predicted_targets = self._predict_gluonts_forecasts(data, known_covariates=known_covariates)
|
|
536
466
|
df = self._gluonts_forecasts_to_data_frame(
|
|
537
467
|
predicted_targets,
|
|
538
468
|
forecast_index=self.get_forecast_horizon_index(data),
|
|
@@ -540,16 +470,21 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
540
470
|
return df
|
|
541
471
|
|
|
542
472
|
def _predict_gluonts_forecasts(
|
|
543
|
-
self,
|
|
544
|
-
|
|
473
|
+
self,
|
|
474
|
+
data: TimeSeriesDataFrame,
|
|
475
|
+
known_covariates: TimeSeriesDataFrame | None = None,
|
|
476
|
+
num_samples: int | None = None,
|
|
477
|
+
) -> list[Forecast]:
|
|
478
|
+
assert self.gts_predictor is not None, "GluonTS models must be fit before predicting."
|
|
545
479
|
gts_data = self._to_gluonts_dataset(data, known_covariates=known_covariates)
|
|
480
|
+
return list(
|
|
481
|
+
self.gts_predictor.predict(
|
|
482
|
+
dataset=gts_data,
|
|
483
|
+
num_samples=num_samples or self.default_num_samples,
|
|
484
|
+
)
|
|
485
|
+
)
|
|
546
486
|
|
|
547
|
-
|
|
548
|
-
predictor_kwargs["num_samples"] = kwargs.get("num_samples", self.default_num_samples)
|
|
549
|
-
|
|
550
|
-
return list(self.gts_predictor.predict(**predictor_kwargs))
|
|
551
|
-
|
|
552
|
-
def _stack_quantile_forecasts(self, forecasts: List[QuantileForecast], item_ids: pd.Index) -> pd.DataFrame:
|
|
487
|
+
def _stack_quantile_forecasts(self, forecasts: list[QuantileForecast], item_ids: pd.Index) -> pd.DataFrame:
|
|
553
488
|
# GluonTS always saves item_id as a string
|
|
554
489
|
item_id_to_forecast = {str(f.item_id): f for f in forecasts}
|
|
555
490
|
result_dfs = []
|
|
@@ -562,7 +497,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
562
497
|
columns_order = ["mean"] + [str(q) for q in self.quantile_levels]
|
|
563
498
|
return forecast_df[columns_order]
|
|
564
499
|
|
|
565
|
-
def _stack_sample_forecasts(self, forecasts:
|
|
500
|
+
def _stack_sample_forecasts(self, forecasts: list[SampleForecast], item_ids: pd.Index) -> pd.DataFrame:
|
|
566
501
|
item_id_to_forecast = {str(f.item_id): f for f in forecasts}
|
|
567
502
|
samples_per_item = []
|
|
568
503
|
for item_id in item_ids:
|
|
@@ -574,17 +509,25 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
574
509
|
forecast_array = np.concatenate([mean, quantiles], axis=1)
|
|
575
510
|
return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])
|
|
576
511
|
|
|
577
|
-
def _stack_distribution_forecasts(
|
|
512
|
+
def _stack_distribution_forecasts(
|
|
513
|
+
self, forecasts: list["DistributionForecast"], item_ids: pd.Index
|
|
514
|
+
) -> pd.DataFrame:
|
|
578
515
|
import torch
|
|
579
516
|
from gluonts.torch.distributions import AffineTransformed
|
|
580
517
|
from torch.distributions import Distribution
|
|
581
518
|
|
|
582
519
|
# Sort forecasts in the same order as in the dataset
|
|
583
520
|
item_id_to_forecast = {str(f.item_id): f for f in forecasts}
|
|
584
|
-
|
|
521
|
+
dist_forecasts = [item_id_to_forecast[str(item_id)] for item_id in item_ids]
|
|
522
|
+
|
|
523
|
+
assert all(isinstance(f.distribution, AffineTransformed) for f in dist_forecasts), (
|
|
524
|
+
"Expected forecast.distribution to be an instance of AffineTransformed"
|
|
525
|
+
)
|
|
585
526
|
|
|
586
|
-
def stack_distributions(distributions:
|
|
527
|
+
def stack_distributions(distributions: list[Distribution]) -> Distribution:
|
|
587
528
|
"""Stack multiple torch.Distribution objects into a single distribution"""
|
|
529
|
+
last_dist: Distribution = distributions[-1]
|
|
530
|
+
|
|
588
531
|
params_per_dist = []
|
|
589
532
|
for dist in distributions:
|
|
590
533
|
params = {name: getattr(dist, name) for name in dist.arg_constraints.keys()}
|
|
@@ -593,22 +536,19 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
593
536
|
assert len(set(tuple(p.keys()) for p in params_per_dist)) == 1
|
|
594
537
|
|
|
595
538
|
stacked_params = {}
|
|
596
|
-
for key in
|
|
539
|
+
for key in last_dist.arg_constraints.keys():
|
|
597
540
|
stacked_params[key] = torch.cat([p[key] for p in params_per_dist])
|
|
598
|
-
return
|
|
599
|
-
|
|
600
|
-
if not isinstance(forecasts[0].distribution, AffineTransformed):
|
|
601
|
-
raise AssertionError("Expected forecast.distribution to be an instance of AffineTransformed")
|
|
541
|
+
return last_dist.__class__(**stacked_params)
|
|
602
542
|
|
|
603
543
|
# We stack all forecast distribution into a single Distribution object.
|
|
604
544
|
# This dramatically speeds up the quantiles calculation.
|
|
605
|
-
stacked_base_dist = stack_distributions([f.distribution.base_dist for f in
|
|
545
|
+
stacked_base_dist = stack_distributions([f.distribution.base_dist for f in dist_forecasts]) # type: ignore
|
|
606
546
|
|
|
607
|
-
stacked_loc = torch.cat([f.distribution.loc for f in
|
|
547
|
+
stacked_loc = torch.cat([f.distribution.loc for f in dist_forecasts]) # type: ignore
|
|
608
548
|
if stacked_loc.shape != stacked_base_dist.batch_shape:
|
|
609
549
|
stacked_loc = stacked_loc.repeat_interleave(self.prediction_length)
|
|
610
550
|
|
|
611
|
-
stacked_scale = torch.cat([f.distribution.scale for f in
|
|
551
|
+
stacked_scale = torch.cat([f.distribution.scale for f in dist_forecasts]) # type: ignore
|
|
612
552
|
if stacked_scale.shape != stacked_base_dist.batch_shape:
|
|
613
553
|
stacked_scale = stacked_scale.repeat_interleave(self.prediction_length)
|
|
614
554
|
|
|
@@ -616,24 +556,24 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
|
|
|
616
556
|
|
|
617
557
|
mean_prediction = stacked_dist.mean.cpu().detach().numpy()
|
|
618
558
|
quantiles = torch.tensor(self.quantile_levels, device=stacked_dist.mean.device).reshape(-1, 1)
|
|
619
|
-
quantile_predictions = stacked_dist.icdf(quantiles).cpu().detach().numpy()
|
|
559
|
+
quantile_predictions = stacked_dist.icdf(quantiles).cpu().detach().numpy() # type: ignore
|
|
620
560
|
forecast_array = np.vstack([mean_prediction, quantile_predictions]).T
|
|
621
561
|
return pd.DataFrame(forecast_array, columns=["mean"] + [str(q) for q in self.quantile_levels])
|
|
622
562
|
|
|
623
563
|
def _gluonts_forecasts_to_data_frame(
|
|
624
564
|
self,
|
|
625
|
-
forecasts:
|
|
565
|
+
forecasts: list[Forecast],
|
|
626
566
|
forecast_index: pd.MultiIndex,
|
|
627
567
|
) -> TimeSeriesDataFrame:
|
|
628
568
|
from gluonts.torch.model.forecast import DistributionForecast
|
|
629
569
|
|
|
630
|
-
item_ids = forecast_index.unique(level=ITEMID)
|
|
570
|
+
item_ids = forecast_index.unique(level=TimeSeriesDataFrame.ITEMID)
|
|
631
571
|
if isinstance(forecasts[0], SampleForecast):
|
|
632
|
-
forecast_df = self._stack_sample_forecasts(forecasts, item_ids)
|
|
572
|
+
forecast_df = self._stack_sample_forecasts(cast(list[SampleForecast], forecasts), item_ids)
|
|
633
573
|
elif isinstance(forecasts[0], QuantileForecast):
|
|
634
|
-
forecast_df = self._stack_quantile_forecasts(forecasts, item_ids)
|
|
574
|
+
forecast_df = self._stack_quantile_forecasts(cast(list[QuantileForecast], forecasts), item_ids)
|
|
635
575
|
elif isinstance(forecasts[0], DistributionForecast):
|
|
636
|
-
forecast_df = self._stack_distribution_forecasts(forecasts, item_ids)
|
|
576
|
+
forecast_df = self._stack_distribution_forecasts(cast(list[DistributionForecast], forecasts), item_ids)
|
|
637
577
|
else:
|
|
638
578
|
raise ValueError(f"Unrecognized forecast type {type(forecasts[0])}")
|
|
639
579
|
|