autogluon.timeseries 1.0.1b20240405__tar.gz → 1.0.1b20240407__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/PKG-INFO +1 -1
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/learner.py +70 -1
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/abstract/abstract_timeseries_model.py +14 -4
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/autogluon_tabular/mlforecast.py +7 -1
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/chronos/model.py +2 -1
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/gluonts/abstract_gluonts.py +213 -63
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/gluonts/torch/models.py +13 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/multi_window/multi_window_model.py +12 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/predictor.py +146 -12
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/trainer/abstract_trainer.py +161 -8
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/features.py +118 -2
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/version.py +1 -1
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon.timeseries.egg-info/PKG-INFO +1 -1
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon.timeseries.egg-info/requires.txt +4 -4
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/setup.cfg +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/setup.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/configs/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/configs/presets_configs.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/dataset/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/dataset/ts_dataframe.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/evaluator.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/metrics/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/metrics/abstract.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/metrics/point.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/metrics/quantile.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/metrics/utils.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/abstract/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/abstract/model_trial.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/autogluon_tabular/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/autogluon_tabular/utils.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/chronos/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/chronos/pipeline.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/ensemble/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/gluonts/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/local/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/local/abstract_local_model.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/local/naive.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/local/npts.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/local/statsforecast.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/multi_window/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/presets.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/splitter.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/trainer/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/trainer/auto_trainer.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/datetime/__init__.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/datetime/base.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/datetime/lags.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/datetime/seasonality.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/datetime/time_features.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/forecast.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/utils/warning_filters.py +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon.timeseries.egg-info/SOURCES.txt +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon.timeseries.egg-info/dependency_links.txt +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon.timeseries.egg-info/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon.timeseries.egg-info/top_level.txt +0 -0
- {autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon.timeseries.egg-info/zip-safe +0 -0
{autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/learner.py

@@ -198,7 +198,7 @@ class TimeSeriesLearner(AbstractLearner):
 
     def evaluate(
         self,
-        data:
+        data: TimeSeriesDataFrame,
         model: Optional[str] = None,
         metrics: Optional[Union[str, TimeSeriesScorer, List[Union[str, TimeSeriesScorer]]]] = None,
         use_cache: bool = True,
@@ -206,6 +206,75 @@ class TimeSeriesLearner(AbstractLearner):
         data = self.feature_generator.transform(data)
         return self.load_trainer().evaluate(data=data, model=model, metrics=metrics, use_cache=use_cache)
 
+    def get_feature_importance(
+        self,
+        data: Optional[TimeSeriesDataFrame] = None,
+        model: Optional[str] = None,
+        metric: Optional[Union[str, TimeSeriesScorer]] = None,
+        features: Optional[List[str]] = None,
+        time_limit: Optional[float] = None,
+        method: Literal["naive", "permutation"] = "permutation",
+        subsample_size: int = 50,
+        num_iterations: int = 1,
+        random_seed: Optional[int] = None,
+        relative_scores: bool = False,
+        include_confidence_band: bool = True,
+        confidence_level: float = 0.99,
+    ) -> pd.DataFrame:
+        trainer = self.load_trainer()
+        if data is None:
+            data = trainer.load_val_data() or trainer.load_train_data()
+
+        # if features are provided in the dataframe, check that they are valid features in the covariate metadata
+        provided_static_columns = [] if data.static_features is None else data.static_features.columns
+        unused_features = [
+            f
+            for f in set(provided_static_columns).union(set(data.columns) - {self.target})
+            if f not in self.feature_generator.covariate_metadata.all_features
+        ]
+
+        if features is None:
+            features = self.feature_generator.covariate_metadata.all_features
+        else:
+            if len(features) == 0:
+                raise ValueError(
+                    "No features provided to compute feature importance. At least some valid features should be provided."
+                )
+            for fn in features:
+                if fn not in self.feature_generator.covariate_metadata.all_features and fn not in unused_features:
+                    raise ValueError(f"Feature {fn} not found in covariate metadata or the dataset.")
+
+        if len(set(features)) < len(features):
+            logger.warning(
+                "Duplicate feature names provided to compute feature importance. This will lead to unexpected behavior. "
+                "Please provide unique feature names across both static features and covariates."
+            )
+
+        data = self.feature_generator.transform(data)
+
+        importance_df = trainer.get_feature_importance(
+            data=data,
+            features=features,
+            model=model,
+            metric=metric,
+            time_limit=time_limit,
+            method=method,
+            subsample_size=subsample_size,
+            num_iterations=num_iterations,
+            random_seed=random_seed,
+            relative_scores=relative_scores,
+            include_confidence_band=include_confidence_band,
+            confidence_level=confidence_level,
+        )
+
+        for feature in set(features).union(unused_features):
+            if feature not in importance_df.index:
+                importance_df.loc[feature] = (
+                    [0, 0, 0] if not include_confidence_band else [0, 0, 0, float("nan"), float("nan")]
+                )
+
+        return importance_df
+
     def leaderboard(self, data: Optional[TimeSeriesDataFrame] = None, use_cache: bool = True) -> pd.DataFrame:
         if data is not None:
             data = self.feature_generator.transform(data)
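
The hunk above adds permutation-based feature importance at the learner level. A minimal usage sketch follows; the synthetic dataset, the `presets`/`time_limit` values, and reaching the learner through the predictor's private `_learner` attribute are illustrative assumptions, not part of this diff (predictor.py also changes in this release, but its hunks are not expanded here).

import numpy as np
import pandas as pd

from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

# Hypothetical toy dataset with one known covariate ("promo").
frames = []
for item in range(3):
    rng = np.random.default_rng(item)
    frames.append(
        pd.DataFrame(
            {
                "item_id": item,
                "timestamp": pd.date_range("2024-01-01", periods=60, freq="D"),
                "target": rng.normal(size=60).cumsum(),
                "promo": rng.integers(0, 2, size=60),
            }
        )
    )
data = TimeSeriesDataFrame.from_data_frame(pd.concat(frames))

predictor = TimeSeriesPredictor(prediction_length=7, known_covariates_names=["promo"])
predictor.fit(data, presets="fast_training", time_limit=120)

# Same signature as the method added to TimeSeriesLearner above; `_learner` is internal API.
importance = predictor._learner.get_feature_importance(
    data=data,
    method="permutation",
    subsample_size=50,
    include_confidence_band=True,
)
print(importance)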

{autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/abstract/abstract_timeseries_model.py

@@ -3,7 +3,7 @@ import os
 import re
 import time
 from contextlib import nullcontext
-from typing import
+from typing import Dict, List, Optional, Union
 
 from autogluon.common import space
 from autogluon.common.loaders import load_pkl
@@ -74,6 +74,10 @@ class AbstractTimeSeriesModel(AbstractModel):
     _preprocess_nonadaptive = None
     _preprocess_set_features = None
 
+    supports_known_covariates: bool = False
+    supports_past_covariates: bool = False
+    supports_static_features: bool = False
+
     def __init__(
         self,
         freq: Optional[str] = None,
@@ -296,6 +300,7 @@ class AbstractTimeSeriesModel(AbstractModel):
             of input items.
         """
         data = self.preprocess(data, is_train=False)
+        known_covariates = self.preprocess_known_covariates(known_covariates)
         predictions = self._predict(data=data, known_covariates=known_covariates, **kwargs)
         logger.debug(f"Predicting with model {self.name}")
         # "0.5" might be missing from the quantiles if self is a wrapper (MultiWindowBacktestingModel or ensemble)
@@ -358,7 +363,7 @@ class AbstractTimeSeriesModel(AbstractModel):
             time steps of each time series.
         """
         past_data, known_covariates = data.get_model_inputs_for_scoring(
-            prediction_length=self.prediction_length, known_covariates_names=self.metadata.
+            prediction_length=self.prediction_length, known_covariates_names=self.metadata.known_covariates
         )
         predictions = self.predict(past_data, known_covariates=known_covariates)
         return self._score_with_predictions(data=data, predictions=predictions, metric=metric)
@@ -371,7 +376,7 @@ class AbstractTimeSeriesModel(AbstractModel):
     ) -> None:
         """Compute val_score, predict_time and cache out-of-fold (OOF) predictions."""
         past_data, known_covariates = val_data.get_model_inputs_for_scoring(
-            prediction_length=self.prediction_length, known_covariates_names=self.metadata.
+            prediction_length=self.prediction_length, known_covariates_names=self.metadata.known_covariates
         )
         predict_start_time = time.time()
         oof_predictions = self.predict(past_data, known_covariates=known_covariates)
@@ -494,9 +499,14 @@ class AbstractTimeSeriesModel(AbstractModel):
 
         return hpo_models, analysis
 
-    def preprocess(self, data: TimeSeriesDataFrame, is_train: bool = False, **kwargs) ->
+    def preprocess(self, data: TimeSeriesDataFrame, is_train: bool = False, **kwargs) -> TimeSeriesDataFrame:
         return data
 
+    def preprocess_known_covariates(
+        self, known_covariates: Optional[TimeSeriesDataFrame]
+    ) -> Optional[TimeSeriesDataFrame]:
+        return known_covariates
+
     def get_memory_size(self, **kwargs) -> Optional[int]:
         return None
 
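
These hunks introduce per-model capability flags (`supports_known_covariates`, `supports_past_covariates`, `supports_static_features`) and a `preprocess_known_covariates` hook that `predict()` now applies to the future covariate frame before calling `_predict()`. The sketch below shows how a custom model subclass might use the new hook; the subclass name, the forward-fill behaviour, and the omission of `_fit()`/`_predict()` are illustrative assumptions only.

from typing import Optional

from autogluon.timeseries import TimeSeriesDataFrame
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel


class MyCovariateAwareModel(AbstractTimeSeriesModel):
    """Hypothetical custom model opting into the new covariate capability flags."""

    # New class-level flags (all default to False in the base class).
    supports_known_covariates = True
    supports_past_covariates = False
    supports_static_features = True

    def preprocess_known_covariates(
        self, known_covariates: Optional[TimeSeriesDataFrame]
    ) -> Optional[TimeSeriesDataFrame]:
        # Called by predict() before _predict(); the base implementation is a no-op.
        # Here we forward-fill gaps in the future covariates as an example.
        if known_covariates is not None:
            known_covariates = known_covariates.fill_missing_values(method="ffill")
        return known_covariates

    # _fit() and _predict() are omitted here; a real model must implement them.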

{autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/autogluon_tabular/mlforecast.py

@@ -242,7 +242,7 @@ class AbstractMLForecastModel(AbstractTimeSeriesModel):
         Each row contains unique_id, ds, y, and (optionally) known covariates & static features.
         """
         # TODO: Add support for past_covariates
-        selected_columns = self.metadata.
+        selected_columns = self.metadata.known_covariates.copy()
         column_name_mapping = {ITEMID: MLF_ITEMID, TIMESTAMP: MLF_TIMESTAMP}
         if include_target:
             selected_columns += [self.target]
@@ -425,6 +425,9 @@ class DirectTabularModel(AbstractMLForecastModel):
         end of each time series).
     """
 
+    supports_known_covariates = True
+    supports_static_features = True
+
     @property
     def is_quantile_model(self) -> bool:
         return self.eval_metric.needs_quantile
@@ -576,6 +579,9 @@ class RecursiveTabularModel(AbstractMLForecastModel):
         end of each time series).
     """
 
+    supports_known_covariates = True
+    supports_static_features = True
+
     def _get_model_params(self) -> dict:
         model_params = super()._get_model_params()
         model_params.setdefault("scaler", "standard")

{autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/chronos/model.py

@@ -181,7 +181,8 @@ class ChronosModel(AbstractTimeSeriesModel):
             )
             self.context_length = self.maximum_context_length
 
-
+        # we truncate the name to avoid long path errors on Windows
+        model_path_safe = str(model_path_input).replace("/", "__").replace(os.path.sep, "__")[-50:]
         name = (name if name is not None else "Chronos") + f"[{model_path_safe}]"
 
         super().__init__(
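
The new `model_path_safe` line sanitizes the model path before embedding it in the model name. A quick illustration of what the expression does (the example path is an assumption):

import os

model_path_input = "amazon/chronos-t5-small"  # hypothetical HuggingFace-style model path
model_path_safe = str(model_path_input).replace("/", "__").replace(os.path.sep, "__")[-50:]
print(model_path_safe)  # "amazon__chronos-t5-small"
# The resulting model name, e.g. "Chronos[amazon__chronos-t5-small]", contains no path
# separators and is capped at 50 characters, avoiding long-path errors on Windows.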

{autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/gluonts/abstract_gluonts.py

@@ -3,7 +3,7 @@ import os
 import shutil
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterator, List, Optional, Type, Union
+from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Type, Union
 
 import gluonts
 import gluonts.core.settings
@@ -16,9 +16,14 @@ from gluonts.model.estimator import Estimator as GluonTSEstimator
 from gluonts.model.forecast import Forecast, QuantileForecast, SampleForecast
 from gluonts.model.predictor import Predictor as GluonTSPredictor
 from pandas.tseries.frequencies import to_offset
+from sklearn.compose import ColumnTransformer
+from sklearn.preprocessing import QuantileTransformer, StandardScaler
 
 from autogluon.common.loaders import load_pkl
 from autogluon.core.hpo.constants import RAY_BACKEND
+from autogluon.tabular.models.tabular_nn.utils.categorical_encoders import (
+    OneHotMergeRaresHandleUnknownEncoder as OneHotEncoder,
+)
 from autogluon.timeseries.dataset.ts_dataframe import ITEMID, TIMESTAMP, TimeSeriesDataFrame
 from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
 from autogluon.timeseries.utils.datetime import norm_freq_str
@@ -42,21 +47,25 @@ class SimpleGluonTSDataset(GluonTSDataset):
         self,
         target_df: TimeSeriesDataFrame,
         target_column: str = "target",
-        feat_static_cat: Optional[
-        feat_static_real: Optional[
-
-
+        feat_static_cat: Optional[np.ndarray] = None,
+        feat_static_real: Optional[np.ndarray] = None,
+        feat_dynamic_cat: Optional[np.ndarray] = None,
+        feat_dynamic_real: Optional[np.ndarray] = None,
+        past_feat_dynamic_cat: Optional[np.ndarray] = None,
+        past_feat_dynamic_real: Optional[np.ndarray] = None,
         includes_future: bool = False,
         prediction_length: int = None,
     ):
         assert target_df is not None
         assert target_df.freq, "Initializing GluonTS data sets without freq is not allowed"
         # Convert TimeSeriesDataFrame to pd.Series for faster processing
-        self.target_array =
-        self.feat_static_cat = self.
-        self.feat_static_real = self.
-        self.
-        self.
+        self.target_array = target_df[target_column].to_numpy(np.float32)
+        self.feat_static_cat = self._astype(feat_static_cat, dtype=np.int64)
+        self.feat_static_real = self._astype(feat_static_real, dtype=np.float32)
+        self.feat_dynamic_cat = self._astype(feat_dynamic_cat, dtype=np.int64)
+        self.feat_dynamic_real = self._astype(feat_dynamic_real, dtype=np.float32)
+        self.past_feat_dynamic_cat = self._astype(past_feat_dynamic_cat, dtype=np.int64)
+        self.past_feat_dynamic_real = self._astype(past_feat_dynamic_real, dtype=np.float32)
         self.freq = self._to_gluonts_freq(target_df.freq)
 
         # Necessary to compute indptr for known_covariates at prediction time
@@ -73,11 +82,11 @@ class SimpleGluonTSDataset(GluonTSDataset):
         assert len(self.item_ids) == len(self.start_timestamps)
 
     @staticmethod
-    def
-        if
+    def _astype(array: Optional[np.ndarray], dtype: np.dtype) -> Optional[np.ndarray]:
+        if array is None:
             return None
         else:
-            return
+            return array.astype(dtype)
 
     @staticmethod
     def _to_gluonts_freq(freq: str) -> str:
@@ -111,12 +120,18 @@ class SimpleGluonTSDataset(GluonTSDataset):
                 ts[FieldName.FEAT_STATIC_CAT] = self.feat_static_cat[j]
             if self.feat_static_real is not None:
                 ts[FieldName.FEAT_STATIC_REAL] = self.feat_static_real[j]
+            if self.past_feat_dynamic_cat is not None:
+                ts[FieldName.PAST_FEAT_DYNAMIC_CAT] = self.past_feat_dynamic_cat[start_idx:end_idx].T
             if self.past_feat_dynamic_real is not None:
                 ts[FieldName.PAST_FEAT_DYNAMIC_REAL] = self.past_feat_dynamic_real[start_idx:end_idx].T
+
+            # Dynamic features that may extend into the future
+            if self.includes_future:
+                start_idx = start_idx + j * self.prediction_length
+                end_idx = end_idx + (j + 1) * self.prediction_length
+            if self.feat_dynamic_cat is not None:
+                ts[FieldName.FEAT_DYNAMIC_CAT] = self.feat_dynamic_cat[start_idx:end_idx].T
             if self.feat_dynamic_real is not None:
-                if self.includes_future:
-                    start_idx = start_idx + j * self.prediction_length
-                    end_idx = end_idx + (j + 1) * self.prediction_length
                 ts[FieldName.FEAT_DYNAMIC_REAL] = self.feat_dynamic_real[start_idx:end_idx].T
             yield ts
 
@@ -148,8 +163,7 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
     gluonts_model_path = "gluon_ts"
     # default number of samples for prediction
    default_num_samples: int = 250
-
-    supports_past_covariates: bool = False
+    supports_cat_covariates: bool = False
 
     def __init__(
         self,
@@ -171,12 +185,20 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
             **kwargs,
         )
         self.gts_predictor: Optional[GluonTSPredictor] = None
+        self._real_column_transformers: Dict[Literal["known", "past", "static"], ColumnTransformer] = {}
+        self._ohe_generator_known: Optional[OneHotEncoder] = None
+        self._ohe_generator_past: Optional[OneHotEncoder] = None
         self.callbacks = []
+        # Following attributes may be overridden during fit() based on train_data & model parameters
         self.num_feat_static_cat = 0
         self.num_feat_static_real = 0
+        self.num_feat_dynamic_cat = 0
         self.num_feat_dynamic_real = 0
+        self.num_past_feat_dynamic_cat = 0
         self.num_past_feat_dynamic_real = 0
         self.feat_static_cat_cardinality: List[int] = []
+        self.feat_dynamic_cat_cardinality: List[int] = []
+        self.past_feat_dynamic_cat_cardinality: List[int] = []
         self.negative_data = True
 
     def save(self, path: str = None, verbose: bool = True) -> str:
@@ -210,42 +232,136 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
     def _get_hpo_backend(self):
         return RAY_BACKEND
 
-    def _deferred_init_params_aux(self,
-        """Update GluonTS specific parameters with information available
-
-
-
-
-
-
-            raise ValueError(
-                "Dataset frequency not provided in the dataset, fit arguments or "
-                "during initialization. Please provide a `freq` string to `fit`."
-            )
+    def _deferred_init_params_aux(self, dataset: TimeSeriesDataFrame) -> None:
+        """Update GluonTS specific parameters with information available only at training time."""
+        self.freq = dataset.freq or self.freq
+        if not self.freq:
+            raise ValueError(
+                "Dataset frequency not provided in the dataset, fit arguments or "
+                "during initialization. Please provide a `freq` string to `fit`."
+            )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        model_params = self._get_model_params()
+        disable_static_features = model_params.get("disable_static_features", False)
+        if not disable_static_features:
+            self.num_feat_static_cat = len(self.metadata.static_features_cat)
+            self.num_feat_static_real = len(self.metadata.static_features_real)
+            if self.num_feat_static_cat > 0:
+                feat_static_cat = dataset.static_features[self.metadata.static_features_cat]
+                self.feat_static_cat_cardinality = feat_static_cat.nunique().tolist()
+
+        disable_known_covariates = model_params.get("disable_known_covariates", False)
+        if not disable_known_covariates and self.supports_known_covariates:
+            self.num_feat_dynamic_cat = len(self.metadata.known_covariates_cat)
+            self.num_feat_dynamic_real = len(self.metadata.known_covariates_real)
+            if self.num_feat_dynamic_cat > 0:
+                feat_dynamic_cat = dataset[self.metadata.known_covariates_cat]
+                if self.supports_cat_covariates:
+                    self.feat_dynamic_cat_cardinality = feat_dynamic_cat.nunique().tolist()
+                else:
+                    # If model doesn't support categorical covariates, convert them to real via one hot encoding
+                    self._ohe_generator_known = OneHotEncoder(
+                        max_levels=model_params.get("max_cat_cardinality", 100),
+                        sparse=False,
+                        dtype="float32",
+                    )
+                    feat_dynamic_cat_ohe = self._ohe_generator_known.fit_transform(pd.DataFrame(feat_dynamic_cat))
+                    self.num_feat_dynamic_cat = 0
+                    self.num_feat_dynamic_real += feat_dynamic_cat_ohe.shape[1]
+
+        disable_past_covariates = model_params.get("disable_past_covariates", False)
+        if not disable_past_covariates and self.supports_past_covariates:
+            self.num_past_feat_dynamic_cat = len(self.metadata.past_covariates_cat)
+            self.num_past_feat_dynamic_real = len(self.metadata.past_covariates_real)
+            if self.num_past_feat_dynamic_cat > 0:
+                past_feat_dynamic_cat = dataset[self.metadata.past_covariates_cat]
+                if self.supports_cat_covariates:
+                    self.past_feat_dynamic_cat_cardinality = past_feat_dynamic_cat.nunique().tolist()
+                else:
+                    # If model doesn't support categorical covariates, convert them to real via one hot encoding
+                    self._ohe_generator_past = OneHotEncoder(
+                        max_levels=model_params.get("max_cat_cardinality", 100),
+                        sparse=False,
+                        dtype="float32",
+                    )
+                    past_feat_dynamic_cat_ohe = self._ohe_generator_past.fit_transform(
+                        pd.DataFrame(past_feat_dynamic_cat)
+                    )
+                    self.num_past_feat_dynamic_cat = 0
+                    self.num_past_feat_dynamic_real += past_feat_dynamic_cat_ohe.shape[1]
+
+        self.negative_data = (dataset[self.target] < 0).any()
 
     @property
     def default_context_length(self) -> int:
         return min(512, max(10, 2 * self.prediction_length))
 
+    def preprocess(self, data: TimeSeriesDataFrame, is_train: bool = False, **kwargs) -> TimeSeriesDataFrame:
+        # Copy data to avoid SettingWithCopyWarning from pandas
+        data = data.copy()
+        if self.supports_known_covariates and len(self.metadata.known_covariates_real) > 0:
+            columns = self.metadata.known_covariates_real
+            if is_train:
+                self._real_column_transformers["known"] = self._get_transformer_for_columns(data, columns=columns)
+            assert "known" in self._real_column_transformers, "Preprocessing pipeline must be fit first"
+            data[columns] = self._real_column_transformers["known"].transform(data[columns])
+
+        if self.supports_past_covariates and len(self.metadata.past_covariates_real) > 0:
+            columns = self.metadata.past_covariates_real
+            if is_train:
+                self._real_column_transformers["past"] = self._get_transformer_for_columns(data, columns=columns)
+            assert "past" in self._real_column_transformers, "Preprocessing pipeline must be fit first"
+            data[columns] = self._real_column_transformers["past"].transform(data[columns])
+
+        if self.supports_static_features and len(self.metadata.static_features_real) > 0:
+            columns = self.metadata.static_features_real
+            if is_train:
+                self._real_column_transformers["static"] = self._get_transformer_for_columns(
+                    data.static_features, columns=columns
+                )
+            assert "static" in self._real_column_transformers, "Preprocessing pipeline must be fit first"
+            data.static_features[columns] = self._real_column_transformers["static"].transform(
+                data.static_features[columns]
+            )
+        return data
+
+    def _get_transformer_for_columns(self, df: pd.DataFrame, columns: List[str]) -> Dict[str, str]:
+        """Passthrough bool features, use QuantileTransform for skewed features, and use StandardScaler for the rest.
+
+        The preprocessing logic is similar to the TORCH_NN model from Tabular.
+        """
+        skew_threshold = self._get_model_params().get("proc.skew_threshold", 0.99)
+        bool_features = []
+        skewed_features = []
+        continuous_features = []
+        for col in columns:
+            if df[col].isin([0, 1]).all():
+                bool_features.append(col)
+            elif np.abs(df[col].skew()) > skew_threshold:
+                skewed_features.append(col)
+            else:
+                continuous_features.append(col)
+        transformers = []
+        logger.debug(
+            f"\tbool_features: {bool_features}, continuous_features: {continuous_features}, skewed_features: {skewed_features}"
+        )
+        if continuous_features:
+            transformers.append(("scaler", StandardScaler(), continuous_features))
+        if skewed_features:
+            transformers.append(("skew", QuantileTransformer(output_distribution="normal"), skewed_features))
+        with warning_filter():
+            column_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough").fit(df[columns])
+        return column_transformer
+
+    def preprocess_known_covariates(
+        self, known_covariates: Optional[TimeSeriesDataFrame]
+    ) -> Optional[TimeSeriesDataFrame]:
+        columns = self.metadata.known_covariates_real
+        if self.supports_known_covariates and len(columns) > 0:
+            assert "known" in self._real_column_transformers, "Preprocessing pipeline must be fit first"
+            known_covariates[columns] = self._real_column_transformers["known"].transform(known_covariates[columns])
+        return known_covariates
+
     def _get_model_params(self) -> dict:
         """Gets params that are passed to the inner model."""
         init_args = super()._get_model_params().copy()
@@ -322,42 +438,76 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
         if time_series_df is not None:
             # TODO: Preprocess real-valued features with StdScaler?
             if self.num_feat_static_cat > 0:
-                feat_static_cat = time_series_df.static_features[self.metadata.static_features_cat]
+                feat_static_cat = time_series_df.static_features[self.metadata.static_features_cat].to_numpy()
             else:
                 feat_static_cat = None
 
             if self.num_feat_static_real > 0:
-                feat_static_real = time_series_df.static_features[self.metadata.static_features_real]
+                feat_static_real = time_series_df.static_features[self.metadata.static_features_real].to_numpy()
             else:
                 feat_static_real = None
 
+            expected_known_covariates_len = len(time_series_df) + self.prediction_length * time_series_df.num_items
+            # Convert TSDF -> DF to avoid overhead / input validation
+            df = pd.DataFrame(time_series_df)
+            if known_covariates is not None:
+                known_covariates = pd.DataFrame(known_covariates)
+            if self.num_feat_dynamic_cat > 0:
+                feat_dynamic_cat = df[self.metadata.known_covariates_cat].to_numpy()
+                if known_covariates is not None:
+                    feat_dynamic_cat = np.concatenate(
+                        [feat_dynamic_cat, known_covariates[self.metadata.known_covariates_cat].to_numpy()]
+                    )
+                    assert len(feat_dynamic_cat) == expected_known_covariates_len
+            else:
+                feat_dynamic_cat = None
+
             if self.num_feat_dynamic_real > 0:
-
-                feat_dynamic_real = pd.DataFrame(time_series_df[self.metadata.known_covariates_real])
+                feat_dynamic_real = df[self.metadata.known_covariates_real].to_numpy()
                 # Append future values of known covariates
                 if known_covariates is not None:
-                    feat_dynamic_real =
-
-
-
-
-
+                    feat_dynamic_real = np.concatenate(
+                        [feat_dynamic_real, known_covariates[self.metadata.known_covariates_real].to_numpy()]
+                    )
+                    assert len(feat_dynamic_real) == expected_known_covariates_len
+                # Categorical covariates are one-hot-encoded as real
+                if self._ohe_generator_known is not None:
+                    feat_dynamic_cat_ohe = self._ohe_generator_known.transform(df[self.metadata.known_covariates_cat])
+                    if known_covariates is not None:
+                        future_dynamic_cat_ohe = self._ohe_generator_known.transform(
+                            known_covariates[self.metadata.known_covariates_cat]
                         )
+                        feat_dynamic_cat_ohe = np.concatenate([feat_dynamic_cat_ohe, future_dynamic_cat_ohe])
+                        assert len(feat_dynamic_cat_ohe) == expected_known_covariates_len
+                    feat_dynamic_real = np.concatenate([feat_dynamic_real, feat_dynamic_cat_ohe], axis=1)
             else:
                 feat_dynamic_real = None
 
+            if self.num_past_feat_dynamic_cat > 0:
+                past_feat_dynamic_cat = df[self.metadata.past_covariates_cat].to_numpy()
+            else:
+                past_feat_dynamic_cat = None
+
             if self.num_past_feat_dynamic_real > 0:
-
-
+                past_feat_dynamic_real = df[self.metadata.past_covariates_real].to_numpy()
+                if self._ohe_generator_past is not None:
+                    past_feat_dynamic_cat_ohe = self._ohe_generator_past.transform(
+                        df[self.metadata.past_covariates_cat]
+                    )
+                    past_feat_dynamic_real = np.concatenate(
+                        [past_feat_dynamic_real, past_feat_dynamic_cat_ohe], axis=1
+                    )
             else:
                 past_feat_dynamic_real = None
 
             return SimpleGluonTSDataset(
-                target_df=time_series_df,
+                target_df=time_series_df[[self.target]],
                 target_column=self.target,
                 feat_static_cat=feat_static_cat,
                 feat_static_real=feat_static_real,
+                feat_dynamic_cat=feat_dynamic_cat,
                 feat_dynamic_real=feat_dynamic_real,
+                past_feat_dynamic_cat=past_feat_dynamic_cat,
                 past_feat_dynamic_real=past_feat_dynamic_real,
                 includes_future=known_covariates is not None,
                 prediction_length=self.prediction_length,
@@ -392,11 +542,11 @@ class AbstractGluonTSModel(AbstractTimeSeriesModel):
         # update auxiliary parameters
         init_args = self._get_estimator_init_args()
         keep_lightning_logs = init_args.pop("keep_lightning_logs", False)
-        callbacks = self._get_callbacks(
+        self.callbacks = self._get_callbacks(
             time_limit=time_limit,
             early_stopping_patience=None if val_data is None else init_args["early_stopping_patience"],
         )
-        self._deferred_init_params_aux(
+        self._deferred_init_params_aux(train_data)
 
         estimator = self._get_estimator()
         with warning_filter(), disable_root_logger(), gluonts.core.settings.let(gluonts.env.env, use_tqdm=False):
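
The new `_get_transformer_for_columns` helper above splits real-valued covariates into boolean columns (passed through), skewed columns (QuantileTransformer), and remaining continuous columns (StandardScaler). The standalone sketch below reproduces that split outside AutoGluon; the column names, the data, and the 0.99 threshold default are illustrative assumptions.

import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import QuantileTransformer, StandardScaler

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "promo_flag": rng.integers(0, 2, size=2000),       # only 0/1 values -> passed through
        "price": rng.normal(100.0, 5.0, size=2000),        # roughly symmetric -> StandardScaler
        "page_views": rng.lognormal(3.0, 1.0, size=2000),  # heavily skewed -> QuantileTransformer
    }
)

skew_threshold = 0.99  # mirrors the "proc.skew_threshold" default in the hunk above
bool_features, skewed_features, continuous_features = [], [], []
for col in df.columns:
    if df[col].isin([0, 1]).all():
        bool_features.append(col)
    elif np.abs(df[col].skew()) > skew_threshold:
        skewed_features.append(col)
    else:
        continuous_features.append(col)

transformers = []
if continuous_features:
    transformers.append(("scaler", StandardScaler(), continuous_features))
if skewed_features:
    transformers.append(("skew", QuantileTransformer(output_distribution="normal"), skewed_features))

# Boolean columns fall through untouched via remainder="passthrough".
column_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough").fit(df)
print(column_transformer.transform(df)[:3])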

{autogluon.timeseries-1.0.1b20240405 → autogluon.timeseries-1.0.1b20240407}/src/autogluon/timeseries/models/gluonts/torch/models.py

@@ -61,6 +61,8 @@ class DeepARModel(AbstractGluonTSModel):
     embedding_dimension : int, optional
         Dimension of the embeddings for categorical features
         (if None, defaults to [min(50, (cat+1)//2) for cat in cardinality])
+    max_cat_cardinality : int, default = 100
+        Maximum number of dimensions to use when one-hot-encoding categorical known_covariates.
     distr_output : gluonts.torch.distributions.DistributionOutput, default = StudentTOutput()
         Distribution to use to evaluate observations and sample predictions
     scaling: bool, default = True
@@ -84,6 +86,7 @@ class DeepARModel(AbstractGluonTSModel):
     """
 
     supports_known_covariates = True
+    supports_static_features = True
 
     def _get_estimator_class(self) -> Type[GluonTSEstimator]:
         from gluonts.torch.model.deepar import DeepAREstimator
@@ -199,6 +202,8 @@ class TemporalFusionTransformerModel(AbstractGluonTSModel):
 
     supports_known_covariates = True
     supports_past_covariates = True
+    supports_cat_covariates = True
+    supports_static_features = True
 
     @property
     def default_context_length(self) -> int:
@@ -219,6 +224,11 @@ class TemporalFusionTransformerModel(AbstractGluonTSModel):
             init_kwargs["static_dims"] = [self.num_feat_static_real]
         if len(self.feat_static_cat_cardinality):
             init_kwargs["static_cardinalities"] = self.feat_static_cat_cardinality
+        if len(self.feat_dynamic_cat_cardinality):
+            init_kwargs["dynamic_cardinalities"] = self.feat_dynamic_cat_cardinality
+        if len(self.past_feat_dynamic_cat_cardinality):
+            init_kwargs["past_dynamic_cardinalities"] = self.past_feat_dynamic_cat_cardinality
+
         init_kwargs.setdefault("time_features", get_time_features_for_frequency(self.freq))
         return init_kwargs
 
@@ -372,6 +382,8 @@ class WaveNetModel(AbstractGluonTSModel):
         If True, logarithm of the scale of the past data will be used as an additional static feature.
     negative_data : bool, default = True
         Flag indicating whether the time series take negative values.
+    max_cat_cardinality : int, default = 100
+        Maximum number of dimensions to use when one-hot-encoding categorical known_covariates.
     max_epochs : int, default = 100
         Number of epochs the model will be trained for
     batch_size : int, default = 64
@@ -393,6 +405,7 @@ class WaveNetModel(AbstractGluonTSModel):
     """
 
     supports_known_covariates = True
+    supports_static_features = True
     default_num_samples: int = 100
 
     def _get_estimator_class(self) -> Type[GluonTSEstimator]:
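
Taken together, these hunks make the per-model capability flags inspectable directly on the model classes. A small sketch follows; the printed values follow from the flags set in this diff, and running it assumes an installed autogluon.timeseries with its GluonTS dependencies.

from autogluon.timeseries.models.gluonts.torch.models import (
    DeepARModel,
    TemporalFusionTransformerModel,
    WaveNetModel,
)

# New class-level capability flags, readable without instantiating a model.
for model_cls in (DeepARModel, TemporalFusionTransformerModel, WaveNetModel):
    print(
        model_cls.__name__,
        "known:", model_cls.supports_known_covariates,
        "past:", model_cls.supports_past_covariates,
        "cat:", model_cls.supports_cat_covariates,
        "static:", model_cls.supports_static_features,
    )
# Only TemporalFusionTransformerModel consumes categorical covariates natively;
# DeepAR and WaveNet route them through the one-hot-encoding path controlled by
# the max_cat_cardinality parameter documented above.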