autogluon.timeseries 1.2.1b20250224__py3-none-any.whl → 1.4.1b20251215__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- autogluon/timeseries/configs/__init__.py +3 -2
- autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
- autogluon/timeseries/configs/predictor_presets.py +106 -0
- autogluon/timeseries/dataset/ts_dataframe.py +256 -141
- autogluon/timeseries/learner.py +86 -52
- autogluon/timeseries/metrics/__init__.py +42 -8
- autogluon/timeseries/metrics/abstract.py +89 -19
- autogluon/timeseries/metrics/point.py +142 -53
- autogluon/timeseries/metrics/quantile.py +46 -21
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +8 -2
- autogluon/timeseries/models/abstract/__init__.py +2 -2
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +361 -592
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +189 -0
- autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +282 -194
- autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
- autogluon/timeseries/models/autogluon_tabular/transforms.py +25 -18
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +219 -138
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +81 -50
- autogluon/timeseries/models/ensemble/__init__.py +37 -2
- autogluon/timeseries/models/ensemble/abstract.py +107 -0
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
- autogluon/timeseries/models/ensemble/weighted/basic.py +91 -0
- autogluon/timeseries/models/ensemble/weighted/greedy.py +62 -0
- autogluon/timeseries/models/gluonts/__init__.py +1 -1
- autogluon/timeseries/models/gluonts/{abstract_gluonts.py → abstract.py} +148 -208
- autogluon/timeseries/models/gluonts/dataset.py +109 -0
- autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +38 -22
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +71 -74
- autogluon/timeseries/models/local/naive.py +13 -9
- autogluon/timeseries/models/local/npts.py +9 -2
- autogluon/timeseries/models/local/statsforecast.py +52 -36
- autogluon/timeseries/models/multi_window/multi_window_model.py +65 -45
- autogluon/timeseries/models/registry.py +64 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
- autogluon/timeseries/models/toto/model.py +249 -0
- autogluon/timeseries/predictor.py +685 -297
- autogluon/timeseries/regressor.py +94 -44
- autogluon/timeseries/splitter.py +8 -32
- autogluon/timeseries/trainer/__init__.py +3 -0
- autogluon/timeseries/trainer/ensemble_composer.py +444 -0
- autogluon/timeseries/trainer/model_set_builder.py +256 -0
- autogluon/timeseries/trainer/prediction_cache.py +149 -0
- autogluon/timeseries/{trainer.py → trainer/trainer.py} +387 -390
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/__init__.py +2 -13
- autogluon/timeseries/transforms/covariate_scaler.py +34 -40
- autogluon/timeseries/transforms/target_scaler.py +37 -20
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +3 -5
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/datetime/time_features.py +2 -2
- autogluon/timeseries/utils/features.py +70 -47
- autogluon/timeseries/utils/forecast.py +19 -14
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/utils/warning_filters.py +4 -2
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.4.1b20251215-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/METADATA +49 -36
- autogluon_timeseries-1.4.1b20251215.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/WHEEL +1 -1
- autogluon/timeseries/configs/presets_configs.py +0 -79
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -11
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -585
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -518
- autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -78
- autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
- autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
- autogluon/timeseries/models/presets.py +0 -360
- autogluon.timeseries-1.2.1b20250224-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.2.1b20250224.dist-info/RECORD +0 -68
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/zip-safe +0 -0
|
@@ -4,12 +4,13 @@ import logging
|
|
|
4
4
|
import math
|
|
5
5
|
import os
|
|
6
6
|
import time
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Type
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
|
+
from typing_extensions import Self
|
|
10
11
|
|
|
11
12
|
import autogluon.core as ag
|
|
12
|
-
from autogluon.timeseries.dataset
|
|
13
|
+
from autogluon.timeseries.dataset import TimeSeriesDataFrame
|
|
13
14
|
from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
|
|
14
15
|
from autogluon.timeseries.models.local.abstract_local_model import AbstractLocalModel
|
|
15
16
|
from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
|
|
@@ -25,10 +26,10 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
25
26
|
|
|
26
27
|
Parameters
|
|
27
28
|
----------
|
|
28
|
-
model_base
|
|
29
|
+
model_base
|
|
29
30
|
The base model to repeatedly train. If a AbstractTimeSeriesModel class, then also provide model_base_kwargs
|
|
30
31
|
which will be used to initialize the model via model_base(**model_base_kwargs).
|
|
31
|
-
model_base_kwargs
|
|
32
|
+
model_base_kwargs
|
|
32
33
|
kwargs used to initialize model_base if model_base is a class.
|
|
33
34
|
"""
|
|
34
35
|
|
|
@@ -37,8 +38,8 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
37
38
|
|
|
38
39
|
def __init__(
|
|
39
40
|
self,
|
|
40
|
-
model_base:
|
|
41
|
-
model_base_kwargs:
|
|
41
|
+
model_base: AbstractTimeSeriesModel | Type[AbstractTimeSeriesModel],
|
|
42
|
+
model_base_kwargs: dict[str, Any] | None = None,
|
|
42
43
|
**kwargs,
|
|
43
44
|
):
|
|
44
45
|
if inspect.isclass(model_base) and issubclass(model_base, AbstractTimeSeriesModel):
|
|
@@ -57,8 +58,8 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
57
58
|
self.model_base_type = type(self.model_base)
|
|
58
59
|
self.info_per_val_window = []
|
|
59
60
|
|
|
60
|
-
self.most_recent_model:
|
|
61
|
-
self.most_recent_model_folder:
|
|
61
|
+
self.most_recent_model: AbstractTimeSeriesModel | None = None
|
|
62
|
+
self.most_recent_model_folder: str | None = None
|
|
62
63
|
super().__init__(**kwargs)
|
|
63
64
|
|
|
64
65
|
@property
|
|
@@ -73,10 +74,6 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
73
74
|
def supports_past_covariates(self) -> bool:
|
|
74
75
|
return self.model_base.supports_past_covariates
|
|
75
76
|
|
|
76
|
-
@property
|
|
77
|
-
def supports_cat_covariates(self) -> bool:
|
|
78
|
-
return self.model_base.supports_cat_covariates
|
|
79
|
-
|
|
80
77
|
def _get_model_base(self):
|
|
81
78
|
return self.model_base
|
|
82
79
|
|
|
@@ -86,16 +83,19 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
86
83
|
def _is_gpu_available(self) -> bool:
|
|
87
84
|
return self._get_model_base()._is_gpu_available()
|
|
88
85
|
|
|
89
|
-
def get_minimum_resources(self, is_gpu_available: bool = False) ->
|
|
86
|
+
def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
|
|
90
87
|
return self._get_model_base().get_minimum_resources(is_gpu_available)
|
|
91
88
|
|
|
92
89
|
def _fit(
|
|
93
90
|
self,
|
|
94
91
|
train_data: TimeSeriesDataFrame,
|
|
95
|
-
val_data:
|
|
96
|
-
time_limit:
|
|
97
|
-
|
|
98
|
-
|
|
92
|
+
val_data: TimeSeriesDataFrame | None = None,
|
|
93
|
+
time_limit: float | None = None,
|
|
94
|
+
num_cpus: int | None = None,
|
|
95
|
+
num_gpus: int | None = None,
|
|
96
|
+
verbosity: int = 2,
|
|
97
|
+
val_splitter: AbstractWindowSplitter | None = None,
|
|
98
|
+
refit_every_n_windows: int | None = 1,
|
|
99
99
|
**kwargs,
|
|
100
100
|
):
|
|
101
101
|
# TODO: use incremental training for GluonTS models?
|
|
@@ -109,13 +109,17 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
109
109
|
if refit_every_n_windows is None:
|
|
110
110
|
refit_every_n_windows = val_splitter.num_val_windows + 1 # only fit model for the first window
|
|
111
111
|
|
|
112
|
-
oof_predictions_per_window = []
|
|
112
|
+
oof_predictions_per_window: list[TimeSeriesDataFrame] = []
|
|
113
113
|
global_fit_start_time = time.time()
|
|
114
|
+
model: AbstractTimeSeriesModel | None = None
|
|
114
115
|
|
|
115
116
|
for window_index, (train_fold, val_fold) in enumerate(val_splitter.split(train_data)):
|
|
116
117
|
logger.debug(f"\tWindow {window_index}")
|
|
118
|
+
|
|
117
119
|
# refit_this_window is always True for the 0th window
|
|
118
120
|
refit_this_window = window_index % refit_every_n_windows == 0
|
|
121
|
+
assert window_index != 0 or refit_this_window
|
|
122
|
+
|
|
119
123
|
if time_limit is None:
|
|
120
124
|
time_left_for_window = None
|
|
121
125
|
else:
|
|
@@ -138,6 +142,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
138
142
|
train_data=train_fold,
|
|
139
143
|
val_data=val_fold,
|
|
140
144
|
time_limit=time_left_for_window,
|
|
145
|
+
verbosity=verbosity,
|
|
141
146
|
**kwargs,
|
|
142
147
|
)
|
|
143
148
|
model.fit_time = time.time() - model_fit_start_time
|
|
@@ -148,6 +153,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
148
153
|
else:
|
|
149
154
|
time_left_for_prediction = time_limit - (time.time() - global_fit_start_time)
|
|
150
155
|
|
|
156
|
+
assert model is not None
|
|
151
157
|
model.score_and_cache_oof(
|
|
152
158
|
val_fold, store_val_score=True, store_predict_time=True, time_limit=time_left_for_prediction
|
|
153
159
|
)
|
|
@@ -172,11 +178,14 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
172
178
|
|
|
173
179
|
# Only the model trained on most recent data is saved & used for prediction
|
|
174
180
|
self.most_recent_model = model
|
|
175
|
-
self.
|
|
181
|
+
assert self.most_recent_model is not None
|
|
182
|
+
|
|
183
|
+
self.most_recent_model_folder = most_recent_refit_window # type: ignore
|
|
176
184
|
self.predict_time = self.most_recent_model.predict_time
|
|
177
|
-
self.fit_time = time.time() - global_fit_start_time - self.predict_time
|
|
178
|
-
self.
|
|
179
|
-
|
|
185
|
+
self.fit_time = time.time() - global_fit_start_time - self.predict_time # type: ignore
|
|
186
|
+
self.cache_oof_predictions(oof_predictions_per_window)
|
|
187
|
+
|
|
188
|
+
self.val_score = float(np.mean([info["val_score"] for info in self.info_per_val_window]))
|
|
180
189
|
|
|
181
190
|
def get_info(self) -> dict:
|
|
182
191
|
info = super().get_info()
|
|
@@ -191,7 +200,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
191
200
|
def _predict(
|
|
192
201
|
self,
|
|
193
202
|
data: TimeSeriesDataFrame,
|
|
194
|
-
known_covariates:
|
|
203
|
+
known_covariates: TimeSeriesDataFrame | None = None,
|
|
195
204
|
**kwargs,
|
|
196
205
|
) -> TimeSeriesDataFrame:
|
|
197
206
|
if self.most_recent_model is None:
|
|
@@ -205,27 +214,34 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
205
214
|
store_predict_time: bool = False,
|
|
206
215
|
**predict_kwargs,
|
|
207
216
|
) -> None:
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
217
|
+
if self._oof_predictions is None or self.most_recent_model is None:
|
|
218
|
+
raise ValueError(f"{self.name} must be fit before calling score_and_cache_oof")
|
|
219
|
+
|
|
220
|
+
# Score on val_data using the most recent model
|
|
221
|
+
past_data, known_covariates = val_data.get_model_inputs_for_scoring(
|
|
222
|
+
prediction_length=self.prediction_length, known_covariates_names=self.covariate_metadata.known_covariates
|
|
223
|
+
)
|
|
224
|
+
predict_start_time = time.time()
|
|
225
|
+
val_predictions = self.most_recent_model.predict(
|
|
226
|
+
past_data, known_covariates=known_covariates, **predict_kwargs
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
self._oof_predictions.append(val_predictions)
|
|
230
|
+
|
|
212
231
|
if store_predict_time:
|
|
213
|
-
|
|
232
|
+
self.predict_time = time.time() - predict_start_time
|
|
214
233
|
|
|
215
|
-
|
|
216
|
-
|
|
234
|
+
if store_val_score:
|
|
235
|
+
self.val_score = self._score_with_predictions(val_data, val_predictions)
|
|
217
236
|
|
|
218
237
|
def _get_search_space(self):
|
|
219
238
|
return self.model_base._get_search_space()
|
|
220
239
|
|
|
221
|
-
def
|
|
240
|
+
def _initialize_transforms_and_regressor(self) -> None:
|
|
222
241
|
# Do not initialize the target_scaler and covariate_regressor in the multi window model!
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
super().initialize(**kwargs)
|
|
227
|
-
self.model_base.initialize(**kwargs)
|
|
228
|
-
return kwargs
|
|
242
|
+
self.target_scaler = None
|
|
243
|
+
self.covariate_scaler = None
|
|
244
|
+
self.covariate_regressor = None
|
|
229
245
|
|
|
230
246
|
def _get_hpo_train_fn_kwargs(self, **train_fn_kwargs) -> dict:
|
|
231
247
|
train_fn_kwargs["is_bagged_model"] = True
|
|
@@ -233,7 +249,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
233
249
|
train_fn_kwargs["init_params"]["model_base_kwargs"] = self.get_params()
|
|
234
250
|
return train_fn_kwargs
|
|
235
251
|
|
|
236
|
-
def save(self, path: str = None, verbose=True) -> str:
|
|
252
|
+
def save(self, path: str | None = None, verbose: bool = True) -> str:
|
|
237
253
|
most_recent_model = self.most_recent_model
|
|
238
254
|
self.most_recent_model = None
|
|
239
255
|
save_path = super().save(path, verbose)
|
|
@@ -244,32 +260,36 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
244
260
|
most_recent_model.save()
|
|
245
261
|
return save_path
|
|
246
262
|
|
|
247
|
-
def persist(self):
|
|
263
|
+
def persist(self) -> Self:
|
|
248
264
|
if self.most_recent_model is None:
|
|
249
265
|
raise ValueError(f"{self.name} must be fit before persisting")
|
|
250
266
|
self.most_recent_model.persist()
|
|
267
|
+
return self
|
|
251
268
|
|
|
252
269
|
@classmethod
|
|
253
270
|
def load(
|
|
254
271
|
cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
|
|
255
272
|
) -> AbstractTimeSeriesModel:
|
|
256
273
|
model = super().load(path=path, reset_paths=reset_paths, load_oof=load_oof, verbose=verbose)
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
274
|
+
if model.most_recent_model_folder is not None:
|
|
275
|
+
most_recent_model_path = os.path.join(model.path, model.most_recent_model_folder)
|
|
276
|
+
model.most_recent_model = model.model_base_type.load(
|
|
277
|
+
most_recent_model_path,
|
|
278
|
+
reset_paths=reset_paths,
|
|
279
|
+
verbose=verbose,
|
|
280
|
+
)
|
|
263
281
|
return model
|
|
264
282
|
|
|
265
283
|
def convert_to_refit_full_template(self) -> AbstractTimeSeriesModel:
|
|
266
284
|
# refit_model is an instance of base model type, not MultiWindowBacktestingModel
|
|
285
|
+
assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
|
|
267
286
|
refit_model = self.most_recent_model.convert_to_refit_full_template()
|
|
268
287
|
refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
|
|
269
288
|
return refit_model
|
|
270
289
|
|
|
271
290
|
def convert_to_refit_full_via_copy(self) -> AbstractTimeSeriesModel:
|
|
272
291
|
# refit_model is an instance of base model type, not MultiWindowBacktestingModel
|
|
292
|
+
assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
|
|
273
293
|
refit_model = self.most_recent_model.convert_to_refit_full_via_copy()
|
|
274
294
|
refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
|
|
275
295
|
return refit_model
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from abc import ABCMeta
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from inspect import isabstract
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
|
|
7
|
+
class ModelRecord:
|
|
8
|
+
model_class: type
|
|
9
|
+
ag_priority: int
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ModelRegistry(ABCMeta):
|
|
13
|
+
"""Registry metaclass for time series models. Ensures that TimeSeriesModel classes
|
|
14
|
+
which implement this metaclass are automatically registered, in order to centralize
|
|
15
|
+
access to model types.
|
|
16
|
+
|
|
17
|
+
See, https://github.com/faif/python-patterns.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
REGISTRY: dict[str, ModelRecord] = {}
|
|
21
|
+
|
|
22
|
+
def __new__(cls, name, bases, attrs):
|
|
23
|
+
new_cls = super().__new__(cls, name, bases, attrs)
|
|
24
|
+
|
|
25
|
+
if name is not None and not isabstract(new_cls):
|
|
26
|
+
record = ModelRecord(
|
|
27
|
+
model_class=new_cls,
|
|
28
|
+
ag_priority=getattr(new_cls, "ag_priority", 0),
|
|
29
|
+
)
|
|
30
|
+
cls._add(name.removesuffix("Model"), record)
|
|
31
|
+
|
|
32
|
+
# if the class provides additional aliases, register them too
|
|
33
|
+
if aliases := attrs.get("ag_model_aliases"):
|
|
34
|
+
for alias in aliases:
|
|
35
|
+
cls._add(alias, record)
|
|
36
|
+
|
|
37
|
+
return new_cls
|
|
38
|
+
|
|
39
|
+
@classmethod
|
|
40
|
+
def _add(cls, alias: str, record: ModelRecord) -> None:
|
|
41
|
+
if alias in cls.REGISTRY:
|
|
42
|
+
raise ValueError(f"You are trying to define a new model with {alias}, but this model already exists.")
|
|
43
|
+
cls.REGISTRY[alias] = record
|
|
44
|
+
|
|
45
|
+
@classmethod
|
|
46
|
+
def _get_model_record(cls, alias: str | type) -> ModelRecord:
|
|
47
|
+
if isinstance(alias, type):
|
|
48
|
+
alias = alias.__name__
|
|
49
|
+
alias = alias.removesuffix("Model")
|
|
50
|
+
if alias not in cls.REGISTRY:
|
|
51
|
+
raise ValueError(f"Unknown model: {alias}, available models are: {cls.available_aliases()}")
|
|
52
|
+
return cls.REGISTRY[alias]
|
|
53
|
+
|
|
54
|
+
@classmethod
|
|
55
|
+
def get_model_class(cls, alias: str | type) -> type:
|
|
56
|
+
return cls._get_model_record(alias).model_class
|
|
57
|
+
|
|
58
|
+
@classmethod
|
|
59
|
+
def get_model_priority(cls, alias: str | type) -> int:
|
|
60
|
+
return cls._get_model_record(alias).ag_priority
|
|
61
|
+
|
|
62
|
+
@classmethod
|
|
63
|
+
def available_aliases(cls) -> list[str]:
|
|
64
|
+
return sorted(cls.REGISTRY.keys())
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from enum import Enum
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from einops import rearrange
|
|
11
|
+
from torch.nn.functional import scaled_dot_product_attention
|
|
12
|
+
|
|
13
|
+
from .rope import TimeAwareRotaryEmbedding
|
|
14
|
+
|
|
15
|
+
# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(name=__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AttentionAxis(Enum):
    """Axis along which a multi-head attention layer attends."""

    # Attend across the sequence positions of each series (variates folded into the batch).
    TIME = 1
    # Attend across the variates at each time step (time folded into the batch).
    SPACE = 2
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseMultiheadAttention(torch.nn.Module):
    """Shared implementation of multi-head attention along one axis of a 4D input.

    Subclasses choose the axis by setting the class attribute ``attention_axis`` to
    ``AttentionAxis.TIME`` or ``AttentionAxis.SPACE``; the constructor rejects anything else.
    ``forward`` takes inputs laid out as ``(batch, variate, seq_len, embed_dim)`` (per the
    ``rearrange`` patterns below), folds the non-attended axis into the batch dimension and
    runs ``torch.nn.functional.scaled_dot_product_attention``.

    Parameters
    ----------
    embed_dim
        Size of the last input dimension; must be divisible by ``num_heads``.
    num_heads
        Number of attention heads; each head has ``embed_dim // num_heads`` channels.
    dropout
        Attention dropout probability, applied only while the module is in training mode.
    rotary_emb
        Optional rotary embedding applied to queries and keys; only used for time-wise attention.
    use_memory_efficient_attention
        Must be False: xformers is not available. The memory-efficient tensor layouts are
        still listed in the pattern branches below but are unreachable.

    Raises
    ------
    ValueError
        If ``embed_dim`` is not divisible by ``num_heads``, if ``use_memory_efficient_attention``
        is truthy, or if the subclass does not define a valid ``attention_axis``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float,
        rotary_emb: TimeAwareRotaryEmbedding | None,
        use_memory_efficient_attention: bool,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Explicit checks (not ``assert``) so argument validation survives ``python -O``.
        if embed_dim % num_heads != 0:
            raise ValueError("Embedding dimension must be divisible by number of heads.")
        self.head_dim = embed_dim // num_heads
        self.rotary_emb = rotary_emb

        # We allocate a single tensor for the q, k, and v projection matrices,
        # multiply them with the inputs, and then split the projected tensors into q, k, and v using unbind.
        # This reduces overhead a bit vs. having multiple separate Linear layers,
        # which need to be initialized, tracked by the optimizer, etc.
        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
        self.dropout = dropout
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.wO = torch.nn.Linear(embed_dim, embed_dim)

        # Only the torch SDPA path is supported. If this check were skipped, the memory-efficient
        # layouts (sequence axis before heads) would be handed to scaled_dot_product_attention,
        # which expects the sequence axis second-to-last, and results would be silently wrong.
        if self.use_memory_efficient_attention:
            raise ValueError("xformers is not available, so use_memory_efficient_attention must be False")

        if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE):
            raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.")

    def rearrange_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
        """Fold the non-attended axis into the batch dimension.

        ``(batch, variate, seq_len, embed_dim)`` becomes ``(batch * variate, seq_len, embed_dim)``
        for time-wise attention and ``(batch * seq_len, variate, embed_dim)`` for space-wise attention.
        """
        pattern = (
            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
            if self.attention_axis == AttentionAxis.TIME
            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
        )

        return rearrange(inputs, pattern)

    def get_qkv(
        self,
        inputs: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """Project the folded ``inputs`` with ``wQKV`` and split the result into q, k and v.

        The projected last dimension is parsed in einops order ``(qkv head_dim n_heads)`` with
        ``qkv=3``. On the torch SDPA path each returned tensor is
        ``(folded_batch, n_heads, length, head_dim)``, where ``length`` is ``seq_len`` for time-wise
        and ``variate`` for space-wise attention.
        """
        pattern: str = ""
        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"

        assert pattern
        qkv = self.wQKV(inputs.contiguous())
        return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0)

    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
        """Apply rotary embeddings and the KV cache (time-wise attention only).

        The position offset is the number of entries already held in ``kv_cache`` for
        ``layer_idx``. It is read whenever a cache is used, regardless of ``rotary_emb``, so that
        ``run_attention`` selects the correct mask rows and does not apply a causal mask that
        ignores the cached keys. Rotary embeddings, when configured, are shifted by that offset.
        With a cache, the new keys/values are appended and the full cached keys/values are
        returned in their place. ``kv_cache`` is expected to provide ``seq_len(layer_idx)``,
        ``append(layer_idx, (k, v))`` and indexing by ``layer_idx``.

        Returns
        -------
        tuple
            ``(q, k, v, seq_pos_offset)``: contiguous tensors with ``k``/``v`` cast to ``q.dtype``,
            and the cache length before this call (0 when no cache is used).
        """
        is_time_axis = self.attention_axis == AttentionAxis.TIME

        seq_pos_offset = 0
        if kv_cache is not None and is_time_axis:
            seq_pos_offset = kv_cache.seq_len(layer_idx)

        # Rotary embeddings encode position along the time axis only.
        if self.rotary_emb is not None and is_time_axis:
            q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset)

        if kv_cache is not None and is_time_axis:
            # Append the current keys/values, then attend over everything cached so far
            # (previous time steps plus the current one).
            kv_cache.append(layer_idx, (k, v))
            k, v = kv_cache[layer_idx]

        q = q.contiguous()
        k = k.contiguous().to(q.dtype)  # Ensure k is the same dtype as q; this is necessary when using mixed precision
        v = v.contiguous().to(q.dtype)  # Ensure v is the same dtype as q; this is necessary when using mixed precision

        return q, k, v, seq_pos_offset

    def rearrange_output(self, output: torch.Tensor, batch: int, variate: int, seq_len: int) -> torch.Tensor:
        """Unfold attention output back to ``(batch, variate, seq_len, n_heads * head_dim)``.

        Raises
        ------
        ValueError
            If ``attention_axis`` is not a known axis.
        """
        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"
        else:
            # Same error as run_attention, instead of an UnboundLocalError on ``pattern``.
            raise ValueError("Invalid attention axis")

        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)

    def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate):
        """Compute scaled dot-product attention for already-projected q, k and v.

        Time-wise:
            A tensor ``attention_mask`` is cropped to the rows of the current queries
            (``seq_pos_offset`` to ``seq_pos_offset + seq_len``) and to the columns of all
            available keys. Without a mask, causal masking is used only when ``seq_pos_offset == 0``.
        Space-wise:
            The mask is cropped to a square over the variates, and no causal masking is applied.

        ``variate`` is currently unused. The mask must be broadcastable to SDPA's expected mask
        shape — presumably ``(..., queries, keys)``; confirm against callers.
        """
        # Determine dimension ranges for attention
        # Ensure the last query vector index is used from the cache
        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
        kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
        if self.attention_axis == AttentionAxis.TIME:
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout,
                is_causal=(attention_mask is None and seq_pos_offset == 0),
            )
        elif self.attention_axis == AttentionAxis.SPACE:
            # We don't use causal masking for space-wise attention
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False)
        else:
            raise ValueError("Invalid attention axis")

    def forward(
        self,
        layer_idx: int,
        inputs: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        kv_cache=None,
    ) -> torch.Tensor:
        """Run attention along ``attention_axis``.

        Parameters
        ----------
        layer_idx
            Index used to address this layer's entries in ``kv_cache``.
        inputs
            Tensor of shape ``(batch, variate, seq_len, embed_dim)``.
        attention_mask
            Optional mask, cropped in ``run_attention``.
        kv_cache
            Optional key/value cache; only used for time-wise attention.

        Returns
        -------
        torch.Tensor
            Output projected by ``wO``, with the same shape as ``inputs``.
        """
        batch_size, variate, seq_len, _ = inputs.shape
        dropout = self.dropout if self.training else 0.0

        rearranged_inputs = self.rearrange_inputs(inputs)
        q, k, v = self.get_qkv(rearranged_inputs)

        q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx)

        output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate)

        output = self.rearrange_output(output, batch_size, variate, seq_len)
        return self.wO(output)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class TimeWiseMultiheadAttention(BaseMultiheadAttention):
    """Multi-head self-attention along the time axis.

    Variates are folded into the batch dimension, so each series attends only over its own
    sequence positions. When the layer is built with a ``rotary_emb``, rotary position embeddings
    are applied to queries and keys to inject relative position information. Without an explicit
    attention mask, causal masking is applied when there is no KV-cache offset.
    """

    # Routes BaseMultiheadAttention through its time-axis layouts and masking.
    attention_axis = AttentionAxis.TIME
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
    """Bidirectional multi-head attention across the variates of a multivariate series.

    Time steps are folded into the batch dimension, so at every time point the variates attend
    to each other. Alternating this layer with time-wise attention lets a model pick up both
    temporal and cross-variate dependencies.

    No causal mask and no rotary embeddings are used on this axis, so no positional ordering is
    imposed on the variates.
    """

    # Routes BaseMultiheadAttention through its space-axis layouts and (non-causal) masking.
    attention_axis = AttentionAxis.SPACE
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# Type alias covering both attention layer variants defined in this module
# (a ``types.UnionType``, usable in annotations and ``isinstance`` checks).
MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
|