autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +84 -0
- autogluon/timeseries/dataset/ts_dataframe.py +339 -186
- autogluon/timeseries/learner.py +192 -60
- autogluon/timeseries/metrics/__init__.py +55 -11
- autogluon/timeseries/metrics/abstract.py +96 -25
- autogluon/timeseries/metrics/point.py +186 -39
- autogluon/timeseries/metrics/quantile.py +47 -20
- autogluon/timeseries/metrics/utils.py +6 -6
- autogluon/timeseries/models/__init__.py +13 -7
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
- autogluon/timeseries/models/abstract/model_trial.py +10 -10
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
- autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
- autogluon/timeseries/models/chronos/__init__.py +4 -0
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +738 -0
- autogluon/timeseries/models/chronos/utils.py +369 -0
- autogluon/timeseries/models/ensemble/__init__.py +35 -2
- autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
- autogluon/timeseries/models/gluonts/__init__.py +3 -1
- autogluon/timeseries/models/gluonts/abstract.py +583 -0
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
- autogluon/timeseries/models/local/__init__.py +1 -10
- autogluon/timeseries/models/local/abstract_local_model.py +150 -97
- autogluon/timeseries/models/local/naive.py +31 -23
- autogluon/timeseries/models/local/npts.py +6 -2
- autogluon/timeseries/models/local/statsforecast.py +99 -112
- autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +826 -305
- autogluon/timeseries/regressor.py +253 -0
- autogluon/timeseries/splitter.py +10 -31
- autogluon/timeseries/trainer/__init__.py +2 -3
- autogluon/timeseries/trainer/ensemble_composer.py +439 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/trainer/trainer.py +1298 -0
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -0
- autogluon/timeseries/transforms/covariate_scaler.py +164 -0
- autogluon/timeseries/transforms/target_scaler.py +149 -0
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/base.py +38 -20
- autogluon/timeseries/utils/datetime/lags.py +18 -16
- autogluon/timeseries/utils/datetime/seasonality.py +14 -14
- autogluon/timeseries/utils/datetime/time_features.py +17 -14
- autogluon/timeseries/utils/features.py +317 -53
- autogluon/timeseries/utils/forecast.py +31 -17
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +44 -6
- autogluon/timeseries/version.py +2 -1
- autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
- autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -11
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -325
- autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
- autogluon/timeseries/trainer/auto_trainer.py +0 -74
- autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
- autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/multi_window/multi_window_model.py
@@ -4,12 +4,13 @@ import logging
 import math
 import os
 import time
-from typing import
+from typing import Any, Type

 import numpy as np
+from typing_extensions import Self

 import autogluon.core as ag
-from autogluon.timeseries.dataset
+from autogluon.timeseries.dataset import TimeSeriesDataFrame
 from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
 from autogluon.timeseries.models.local.abstract_local_model import AbstractLocalModel
 from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
@@ -25,17 +26,20 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):

     Parameters
     ----------
-    model_base
+    model_base
         The base model to repeatedly train. If a AbstractTimeSeriesModel class, then also provide model_base_kwargs
         which will be used to initialize the model via model_base(**model_base_kwargs).
-    model_base_kwargs
+    model_base_kwargs
         kwargs used to initialize model_base if model_base is a class.
     """

+    # TODO: Remove the MultiWindowBacktestingModel class, move the logic to TimeSeriesTrainer
+    default_max_time_limit_ratio = 1.0
+
     def __init__(
         self,
-        model_base:
-        model_base_kwargs:
+        model_base: AbstractTimeSeriesModel | Type[AbstractTimeSeriesModel],
+        model_base_kwargs: dict[str, Any] | None = None,
         **kwargs,
     ):
         if inspect.isclass(model_base) and issubclass(model_base, AbstractTimeSeriesModel):
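The docstring above allows model_base to be either an already-constructed model instance or a model class paired with model_base_kwargs. A minimal, self-contained sketch of that class-or-instance convention (Base and ETS are illustrative stand-ins, not autogluon classes):

import inspect

class Base:
    pass

class ETS(Base):
    def __init__(self, seasonal: str = "add"):
        self.seasonal = seasonal

def resolve_model_base(model_base, model_base_kwargs=None):
    # A class is instantiated with the supplied kwargs ...
    if inspect.isclass(model_base) and issubclass(model_base, Base):
        return model_base(**(model_base_kwargs or {}))
    # ... while an instance is used as-is, so extra kwargs are an error.
    if model_base_kwargs is not None:
        raise ValueError("model_base_kwargs are only valid when model_base is a class")
    return model_base

assert isinstance(resolve_model_base(ETS, {"seasonal": "mul"}), ETS)
assert isinstance(resolve_model_base(ETS()), ETS)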
@@ -54,10 +58,22 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         self.model_base_type = type(self.model_base)
         self.info_per_val_window = []

-        self.most_recent_model: AbstractTimeSeriesModel = None
-        self.most_recent_model_folder:
+        self.most_recent_model: AbstractTimeSeriesModel | None = None
+        self.most_recent_model_folder: str | None = None
         super().__init__(**kwargs)

+    @property
+    def supports_static_features(self) -> bool:
+        return self.model_base.supports_static_features
+
+    @property
+    def supports_known_covariates(self) -> bool:
+        return self.model_base.supports_known_covariates
+
+    @property
+    def supports_past_covariates(self) -> bool:
+        return self.model_base.supports_past_covariates
+
     def _get_model_base(self):
         return self.model_base

@@ -67,16 +83,19 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
     def _is_gpu_available(self) -> bool:
         return self._get_model_base()._is_gpu_available()

-    def get_minimum_resources(self, is_gpu_available: bool = False) ->
+    def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
         return self._get_model_base().get_minimum_resources(is_gpu_available)

     def _fit(
         self,
         train_data: TimeSeriesDataFrame,
-        val_data:
-        time_limit:
-
-
+        val_data: TimeSeriesDataFrame | None = None,
+        time_limit: float | None = None,
+        num_cpus: int | None = None,
+        num_gpus: int | None = None,
+        verbosity: int = 2,
+        val_splitter: AbstractWindowSplitter | None = None,
+        refit_every_n_windows: int | None = 1,
         **kwargs,
     ):
         # TODO: use incremental training for GluonTS models?
@@ -90,13 +109,17 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         if refit_every_n_windows is None:
             refit_every_n_windows = val_splitter.num_val_windows + 1  # only fit model for the first window

-        oof_predictions_per_window = []
+        oof_predictions_per_window: list[TimeSeriesDataFrame] = []
         global_fit_start_time = time.time()
+        model: AbstractTimeSeriesModel | None = None

         for window_index, (train_fold, val_fold) in enumerate(val_splitter.split(train_data)):
             logger.debug(f"\tWindow {window_index}")
+
             # refit_this_window is always True for the 0th window
             refit_this_window = window_index % refit_every_n_windows == 0
+            assert window_index != 0 or refit_this_window
+
             if time_limit is None:
                 time_left_for_window = None
             else:
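The two guards above pin down the refit schedule: window 0 always refits, and refit_every_n_windows=None is rewritten to num_val_windows + 1 so that only window 0 refits. A small illustration with made-up window counts:

num_val_windows = 5

def refit_windows(refit_every_n_windows):
    if refit_every_n_windows is None:
        refit_every_n_windows = num_val_windows + 1  # only fit model for the first window
    return [w for w in range(num_val_windows) if w % refit_every_n_windows == 0]

print(refit_windows(1))     # [0, 1, 2, 3, 4]  refit on every window
print(refit_windows(2))     # [0, 2, 4]        refit every other window
print(refit_windows(None))  # [0]              train once, evaluate on all windows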
@@ -110,8 +133,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
                 num_refits_remaining = math.ceil(
                     (val_splitter.num_val_windows - window_index) / refit_every_n_windows
                 )
-
-                time_left_for_window = 0.9 * time_left / num_refits_remaining
+                time_left_for_window = time_left / num_refits_remaining

             if refit_this_window:
                 model = self.get_child_model(window_index)
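The change above drops the old 0.9 safety factor and splits the remaining budget evenly across the refits that are still to come. A worked example with illustrative numbers:

import math

num_val_windows, refit_every_n_windows = 3, 1
time_limit, elapsed, window_index = 120.0, 30.0, 1  # window 0 already consumed 30 s

time_left = time_limit - elapsed
num_refits_remaining = math.ceil((num_val_windows - window_index) / refit_every_n_windows)
time_left_for_window = time_left / num_refits_remaining
print(num_refits_remaining, time_left_for_window)  # 2 45.0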
@@ -120,11 +142,21 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
                     train_data=train_fold,
                     val_data=val_fold,
                     time_limit=time_left_for_window,
+                    verbosity=verbosity,
                     **kwargs,
                 )
                 model.fit_time = time.time() - model_fit_start_time
                 most_recent_refit_window = f"W{window_index}"
-
+
+            if time_limit is None:
+                time_left_for_prediction = None
+            else:
+                time_left_for_prediction = time_limit - (time.time() - global_fit_start_time)
+
+            assert model is not None
+            model.score_and_cache_oof(
+                val_fold, store_val_score=True, store_predict_time=True, time_limit=time_left_for_prediction
+            )

             oof_predictions_per_window.append(model.get_oof_predictions()[0])
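Taken together, the loop now follows a refit-or-reuse pattern, with scoring charged against whatever remains of the global budget. A condensed, autogluon-free skeleton of that control flow (fit and score are placeholder callables):

import time

def backtest(folds, fit, score, time_limit=None, refit_every_n_windows=1):
    start, model, oof = time.time(), None, []
    for window_index, (train_fold, val_fold) in enumerate(folds):
        if window_index % refit_every_n_windows == 0:  # always true for window 0
            model = fit(train_fold)
        assert model is not None
        time_left = None if time_limit is None else time_limit - (time.time() - start)
        oof.append(score(model, val_fold, time_left))
    return model, oof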
@@ -146,11 +178,14 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):

         # Only the model trained on most recent data is saved & used for prediction
         self.most_recent_model = model
-        self.
+        assert self.most_recent_model is not None
+
+        self.most_recent_model_folder = most_recent_refit_window  # type: ignore
         self.predict_time = self.most_recent_model.predict_time
-        self.fit_time = time.time() - global_fit_start_time - self.predict_time
-        self.
-
+        self.fit_time = time.time() - global_fit_start_time - self.predict_time  # type: ignore
+        self.cache_oof_predictions(oof_predictions_per_window)
+
+        self.val_score = float(np.mean([info["val_score"] for info in self.info_per_val_window]))

     def get_info(self) -> dict:
         info = super().get_info()
@@ -165,7 +200,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
     def _predict(
         self,
         data: TimeSeriesDataFrame,
-        known_covariates:
+        known_covariates: TimeSeriesDataFrame | None = None,
         **kwargs,
     ) -> TimeSeriesDataFrame:
         if self.most_recent_model is None:
@@ -177,23 +212,36 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         val_data: TimeSeriesDataFrame,
         store_val_score: bool = False,
         store_predict_time: bool = False,
+        **predict_kwargs,
     ) -> None:
-
-
-
-
+        if self._oof_predictions is None or self.most_recent_model is None:
+            raise ValueError(f"{self.name} must be fit before calling score_and_cache_oof")
+
+        # Score on val_data using the most recent model
+        past_data, known_covariates = val_data.get_model_inputs_for_scoring(
+            prediction_length=self.prediction_length, known_covariates_names=self.covariate_metadata.known_covariates
+        )
+        predict_start_time = time.time()
+        val_predictions = self.most_recent_model.predict(
+            past_data, known_covariates=known_covariates, **predict_kwargs
+        )
+
+        self._oof_predictions.append(val_predictions)
+
         if store_predict_time:
-
+            self.predict_time = time.time() - predict_start_time

-
-
+        if store_val_score:
+            self.val_score = self._score_with_predictions(val_data, val_predictions)

     def _get_search_space(self):
         return self.model_base._get_search_space()

-    def
-
-        self.
+    def _initialize_transforms_and_regressor(self) -> None:
+        # Do not initialize the target_scaler and covariate_regressor in the multi window model!
+        self.target_scaler = None
+        self.covariate_scaler = None
+        self.covariate_regressor = None

     def _get_hpo_train_fn_kwargs(self, **train_fn_kwargs) -> dict:
         train_fn_kwargs["is_bagged_model"] = True
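score_and_cache_oof relies on val_data.get_model_inputs_for_scoring to hold out the final prediction_length steps of each series as the forecast target while exposing the known covariates over that horizon. A conceptual sketch of that split on a plain pandas frame (split_for_scoring is a hypothetical helper, not the TimeSeriesDataFrame method):

import pandas as pd

def split_for_scoring(df, prediction_length, known_covariates_names):
    # Past data: everything except the last prediction_length rows per item.
    past = df.groupby(level="item_id", group_keys=False).apply(lambda g: g.iloc[:-prediction_length])
    # Known covariates: their values over the held-out forecast horizon.
    future = df.groupby(level="item_id", group_keys=False).apply(lambda g: g.iloc[-prediction_length:])
    return past, future[known_covariates_names]

idx = pd.MultiIndex.from_product(
    [["A", "B"], pd.date_range("2024-01-01", periods=5, freq="D")], names=["item_id", "timestamp"]
)
data = pd.DataFrame({"target": range(10), "promo": [0, 1] * 5}, index=idx)
past_data, known_covariates = split_for_scoring(data, prediction_length=2, known_covariates_names=["promo"])
print(past_data.groupby(level="item_id").size().tolist())  # [3, 3]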
@@ -201,7 +249,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
             train_fn_kwargs["init_params"]["model_base_kwargs"] = self.get_params()
         return train_fn_kwargs

-    def save(self, path: str = None, verbose=True) -> str:
+    def save(self, path: str | None = None, verbose: bool = True) -> str:
         most_recent_model = self.most_recent_model
         self.most_recent_model = None
         save_path = super().save(path, verbose)
@@ -212,30 +260,41 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
             most_recent_model.save()
         return save_path

+    def persist(self) -> Self:
+        if self.most_recent_model is None:
+            raise ValueError(f"{self.name} must be fit before persisting")
+        self.most_recent_model.persist()
+        return self
+
     @classmethod
     def load(
         cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
     ) -> AbstractTimeSeriesModel:
         model = super().load(path=path, reset_paths=reset_paths, load_oof=load_oof, verbose=verbose)
-
-
-
-
-
-
+        if model.most_recent_model_folder is not None:
+            most_recent_model_path = os.path.join(model.path, model.most_recent_model_folder)
+            model.most_recent_model = model.model_base_type.load(
+                most_recent_model_path,
+                reset_paths=reset_paths,
+                verbose=verbose,
+            )
         return model

     def convert_to_refit_full_template(self) -> AbstractTimeSeriesModel:
         # refit_model is an instance of base model type, not MultiWindowBacktestingModel
+        assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
         refit_model = self.most_recent_model.convert_to_refit_full_template()
         refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
         return refit_model

     def convert_to_refit_full_via_copy(self) -> AbstractTimeSeriesModel:
         # refit_model is an instance of base model type, not MultiWindowBacktestingModel
+        assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
         refit_model = self.most_recent_model.convert_to_refit_full_via_copy()
         refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
         return refit_model

     def _more_tags(self) -> dict:
-
+        tags = self.model_base._get_tags()
+        tags["can_use_val_data"] = False
+        return tags
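save and load above implement a detach-and-reattach pattern: the child model is stripped from the wrapper before pickling, saved in its own subfolder named after the refit window, and restored from that subfolder on load. A minimal sketch of the same pattern (the Wrapper class is illustrative, not the autogluon API):

import os
import pickle

class Wrapper:
    def __init__(self):
        self.child = {"weights": [1, 2, 3]}  # stands in for most_recent_model
        self.child_folder = "W2"             # stands in for most_recent_model_folder

    def save(self, path):
        child, self.child = self.child, None  # never pickle the child inline
        os.makedirs(os.path.join(path, self.child_folder), exist_ok=True)
        with open(os.path.join(path, "wrapper.pkl"), "wb") as f:
            pickle.dump(self, f)
        with open(os.path.join(path, self.child_folder, "child.pkl"), "wb") as f:
            pickle.dump(child, f)
        self.child = child  # reattach after saving

    @classmethod
    def load(cls, path):
        with open(os.path.join(path, "wrapper.pkl"), "rb") as f:
            wrapper = pickle.load(f)
        with open(os.path.join(path, wrapper.child_folder, "child.pkl"), "rb") as f:
            wrapper.child = pickle.load(f)
        return wrapper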
autogluon/timeseries/models/registry.py (new file)
@@ -0,0 +1,64 @@
+from abc import ABCMeta
+from dataclasses import dataclass
+from inspect import isabstract
+
+
+@dataclass
+class ModelRecord:
+    model_class: type
+    ag_priority: int
+
+
+class ModelRegistry(ABCMeta):
+    """Registry metaclass for time series models. Ensures that TimeSeriesModel classes
+    which implement this metaclass are automatically registered, in order to centralize
+    access to model types.
+
+    See, https://github.com/faif/python-patterns.
+    """
+
+    REGISTRY: dict[str, ModelRecord] = {}
+
+    def __new__(cls, name, bases, attrs):
+        new_cls = super().__new__(cls, name, bases, attrs)
+
+        if name is not None and not isabstract(new_cls):
+            record = ModelRecord(
+                model_class=new_cls,
+                ag_priority=getattr(new_cls, "ag_priority", 0),
+            )
+            cls._add(name.removesuffix("Model"), record)
+
+            # if the class provides additional aliases, register them too
+            if aliases := attrs.get("ag_model_aliases"):
+                for alias in aliases:
+                    cls._add(alias, record)
+
+        return new_cls
+
+    @classmethod
+    def _add(cls, alias: str, record: ModelRecord) -> None:
+        if alias in cls.REGISTRY:
+            raise ValueError(f"You are trying to define a new model with {alias}, but this model already exists.")
+        cls.REGISTRY[alias] = record
+
+    @classmethod
+    def _get_model_record(cls, alias: str | type) -> ModelRecord:
+        if isinstance(alias, type):
+            alias = alias.__name__
+        alias = alias.removesuffix("Model")
+        if alias not in cls.REGISTRY:
+            raise ValueError(f"Unknown model: {alias}, available models are: {cls.available_aliases()}")
+        return cls.REGISTRY[alias]
+
+    @classmethod
+    def get_model_class(cls, alias: str | type) -> type:
+        return cls._get_model_record(alias).model_class
+
+    @classmethod
+    def get_model_priority(cls, alias: str | type) -> int:
+        return cls._get_model_record(alias).ag_priority
+
+    @classmethod
+    def available_aliases(cls) -> list[str]:
+        return sorted(cls.REGISTRY.keys())
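Any concrete class that uses ModelRegistry as its metaclass registers itself under its class name minus the "Model" suffix, plus any entries in ag_model_aliases; abstract classes are skipped. A usage sketch (MyETSModel is a made-up example; the import path follows the file listing above):

from abc import abstractmethod

from autogluon.timeseries.models.registry import ModelRegistry

class AbstractModel(metaclass=ModelRegistry):
    @abstractmethod
    def fit(self): ...  # abstract, so this class is not registered

class MyETSModel(AbstractModel):
    ag_priority = 80
    ag_model_aliases = ["MyAutoETS"]

    def fit(self):
        return self

print(ModelRegistry.get_model_class("MyETS"))         # <class '...MyETSModel'>
print(ModelRegistry.get_model_priority("MyAutoETS"))  # 80
print(ModelRegistry.available_aliases())              # ['MyAutoETS', 'MyETS']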
autogluon/timeseries/models/toto/_internal/backbone/attention.py (new file)
@@ -0,0 +1,196 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+import logging
+from enum import Enum
+
+import torch
+from einops import rearrange
+from torch.nn.functional import scaled_dot_product_attention
+
+from .rope import TimeAwareRotaryEmbedding
+
+log = logging.getLogger(__name__)
+
+
+class AttentionAxis(Enum):
+    TIME = 1
+    SPACE = 2
+
+
+class BaseMultiheadAttention(torch.nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float,
+        rotary_emb: TimeAwareRotaryEmbedding | None,
+        use_memory_efficient_attention: bool,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads."
+        self.head_dim = embed_dim // num_heads
+        self.rotary_emb = rotary_emb
+
+        # We allocate a single tensor for the q, k, and v projection matrices,
+        # multiply them with the inputs, and then split the projected tensors into q, k, and v using unbind.
+        # This reduces overhead a bit vs. having multiple separate Linear layers,
+        # which need to be initialized, tracked by the optimizer, etc.
+        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
+        self.dropout = dropout
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.wO = torch.nn.Linear(embed_dim, embed_dim)
+
+        assert not self.use_memory_efficient_attention, (
+            "xformers is not available, so use_memory_efficient_attention must be False"
+        )
+
+        if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE):
+            raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.")
+
+    def rearrange_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
+        pattern = (
+            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
+            if self.attention_axis == AttentionAxis.TIME
+            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
+        )
+
+        return rearrange(inputs, pattern)
+
+    def get_qkv(
+        self,
+        inputs: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        pattern: str = ""
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"
+
+        assert pattern
+        qkv = self.wQKV(inputs.contiguous())
+        return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0)
+
+    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
+        # Apply the rotary embeddings
+        seq_pos_offset = 0
+        if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME:
+            if kv_cache is not None:
+                seq_pos_offset = kv_cache.seq_len(layer_idx)
+
+            # We need to permute because rotary embeddings expect the sequence dimension to be the second-to-last dimension
+            q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset)
+
+        if kv_cache is not None and self.attention_axis == AttentionAxis.TIME:
+            # First, we append the current input key and value tensors to the cache.
+            # This concatenates the current key and value tensors to the existing key and value tensors
+            kv_cache.append(layer_idx, (k, v))
+            # Then, we retrieve the key and value tensors from the cache.
+            # This includes all the key and value tensors from previous time steps
+            # as well as the current time step.
+            k, v = kv_cache[layer_idx]
+
+        q = q.contiguous()
+        k = k.contiguous().to(q.dtype)  # Ensure k is the same dtype as q; this is necessary when using mixed precision
+        v = v.contiguous().to(q.dtype)  # Ensure v is the same dtype as q; this is necessary when using mixed precision
+
+        return q, k, v, seq_pos_offset
+
+    def rearrange_output(self, output: torch.Tensor, batch: int, variate: int, seq_len: int) -> torch.Tensor:
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"
+
+        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)  # type: ignore
+
+    def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate):
+        # Determine dimension ranges for attention
+        # Ensure the last query vector index is used from the cache
+        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
+        kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
+        if self.attention_axis == AttentionAxis.TIME:
+            attention_mask = (
+                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                is_causal=(attention_mask is None and seq_pos_offset == 0),
+            )
+        elif self.attention_axis == AttentionAxis.SPACE:
+            # We don't use causal masking for space-wise attention
+            attention_mask = (
+                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False)
+        else:
+            raise ValueError("Invalid attention axis")
+
+    def forward(
+        self,
+        layer_idx: int,
+        inputs: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        kv_cache=None,
+    ) -> torch.Tensor:
+        batch_size, variate, seq_len, _ = inputs.shape
+        dropout = self.dropout if self.training else 0.0
+
+        rearranged_inputs = self.rearrange_inputs(inputs)
+        q, k, v = self.get_qkv(rearranged_inputs)
+
+        q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx)
+
+        output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate)
+
+        output = self.rearrange_output(output, batch_size, variate, seq_len)
+        return self.wO(output)
+
+
+class TimeWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes standard multihead causal attention over the time axis.
+    It does this by flattening out the variates along the batch dimension.
+    It also applies rotary position embeddings to the query and key matrices
+    in order to incorporate relative positional information.
+    """
+
+    attention_axis = AttentionAxis.TIME
+
+
+class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes bidirectional multihead attention over the space axis (i.e. across variates within
+    a multi-variate time series). This is done by flattening out the time axis along the batch dimension.
+    This allows the model to attend to different variates at the same time point. By alternating
+    between time-wise and space-wise attention, the model can learn both temporal and cross-variate
+    dependencies in the data.
+
+    Unlike with time-wise attention, don't apply rotary embeddings here
+    because we want cross-variate attention to be invariant to the order of the variates.
+    """
+
+    attention_axis = AttentionAxis.SPACE
+
+
+MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
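The rearrange patterns above are the heart of these attention layers: time-wise attention folds variates into the batch axis, space-wise attention folds timesteps into the batch axis, and one fused linear layer produces q, k, and v in a single matmul. A standalone shape walkthrough (illustrative sizes, no toto imports):

import torch
from einops import rearrange

batch, variate, seq_len, embed_dim, n_heads = 2, 3, 8, 16, 4
head_dim = embed_dim // n_heads
inputs = torch.randn(batch, variate, seq_len, embed_dim)

# Time-wise layout: each variate becomes its own "batch" entry, attending over time.
time_wise = rearrange(inputs, "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim")
# Space-wise layout: each timestep becomes its own "batch" entry, attending across variates.
space_wise = rearrange(inputs, "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim")
print(time_wise.shape, space_wise.shape)  # torch.Size([6, 8, 16]) torch.Size([16, 3, 16])

# Fused QKV projection, then one rearrange + unbind to get head-separated q, k, v
# (the non-memory-efficient TIME pattern from get_qkv above).
wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
q, k, v = rearrange(
    wQKV(time_wise),
    "bv seq_len (qkv head_dim n_heads) -> qkv bv n_heads seq_len head_dim",
    qkv=3, head_dim=head_dim, n_heads=n_heads,
).unbind(dim=0)
print(q.shape)  # torch.Size([6, 4, 8, 4]), ready for scaled_dot_product_attention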