autogluon.timeseries 1.4.1b20251115__py3-none-any.whl → 1.5.0b20251221__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- autogluon/timeseries/configs/hyperparameter_presets.py +13 -28
- autogluon/timeseries/configs/predictor_presets.py +23 -39
- autogluon/timeseries/dataset/ts_dataframe.py +32 -34
- autogluon/timeseries/learner.py +67 -33
- autogluon/timeseries/metrics/__init__.py +4 -4
- autogluon/timeseries/metrics/abstract.py +8 -8
- autogluon/timeseries/metrics/point.py +9 -9
- autogluon/timeseries/metrics/quantile.py +4 -4
- autogluon/timeseries/models/__init__.py +2 -1
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -50
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +8 -8
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +30 -26
- autogluon/timeseries/models/autogluon_tabular/per_step.py +13 -11
- autogluon/timeseries/models/autogluon_tabular/transforms.py +2 -2
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +395 -0
- autogluon/timeseries/models/chronos/model.py +30 -25
- autogluon/timeseries/models/chronos/utils.py +5 -5
- autogluon/timeseries/models/ensemble/__init__.py +17 -10
- autogluon/timeseries/models/ensemble/abstract.py +13 -9
- autogluon/timeseries/models/ensemble/array_based/__init__.py +2 -2
- autogluon/timeseries/models/ensemble/array_based/abstract.py +24 -31
- autogluon/timeseries/models/ensemble/array_based/models.py +146 -11
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +2 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +6 -5
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +44 -83
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +21 -55
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +7 -3
- autogluon/timeseries/models/ensemble/weighted/basic.py +26 -13
- autogluon/timeseries/models/ensemble/weighted/greedy.py +21 -144
- autogluon/timeseries/models/gluonts/abstract.py +30 -29
- autogluon/timeseries/models/gluonts/dataset.py +9 -9
- autogluon/timeseries/models/gluonts/models.py +0 -7
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +13 -16
- autogluon/timeseries/models/local/naive.py +2 -2
- autogluon/timeseries/models/local/npts.py +7 -1
- autogluon/timeseries/models/local/statsforecast.py +13 -13
- autogluon/timeseries/models/multi_window/multi_window_model.py +38 -23
- autogluon/timeseries/models/registry.py +3 -4
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +3 -4
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +6 -6
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +4 -9
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +2 -3
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +10 -10
- autogluon/timeseries/models/toto/_internal/dataset.py +2 -2
- autogluon/timeseries/models/toto/_internal/forecaster.py +8 -8
- autogluon/timeseries/models/toto/dataloader.py +4 -4
- autogluon/timeseries/models/toto/hf_pretrained_model.py +97 -16
- autogluon/timeseries/models/toto/model.py +30 -17
- autogluon/timeseries/predictor.py +531 -136
- autogluon/timeseries/regressor.py +18 -23
- autogluon/timeseries/splitter.py +2 -2
- autogluon/timeseries/trainer/ensemble_composer.py +323 -129
- autogluon/timeseries/trainer/model_set_builder.py +9 -9
- autogluon/timeseries/trainer/prediction_cache.py +16 -16
- autogluon/timeseries/trainer/trainer.py +235 -145
- autogluon/timeseries/trainer/utils.py +3 -4
- autogluon/timeseries/transforms/covariate_scaler.py +7 -7
- autogluon/timeseries/transforms/target_scaler.py +8 -8
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +1 -3
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/features.py +22 -9
- autogluon/timeseries/utils/forecast.py +1 -2
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/version.py +1 -1
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/METADATA +23 -21
- autogluon_timeseries-1.5.0b20251221.dist-info/RECORD +103 -0
- autogluon_timeseries-1.4.1b20251115.dist-info/RECORD +0 -96
- /autogluon.timeseries-1.4.1b20251115-py3.9-nspkg.pth → /autogluon.timeseries-1.5.0b20251221-py3.11-nspkg.pth +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/WHEEL +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/licenses/LICENSE +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/licenses/NOTICE +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/namespace_packages.txt +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/top_level.txt +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/zip-safe +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import Any,
|
|
2
|
+
from typing import Any, Type
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import pandas as pd
|
|
@@ -19,7 +19,7 @@ class AbstractStatsForecastModel(AbstractLocalModel):
|
|
|
19
19
|
local_model_args["season_length"] = seasonal_period
|
|
20
20
|
return local_model_args
|
|
21
21
|
|
|
22
|
-
def _get_model_type(self, variant:
|
|
22
|
+
def _get_model_type(self, variant: str | None = None) -> Type:
|
|
23
23
|
raise NotImplementedError
|
|
24
24
|
|
|
25
25
|
def _get_local_model(self, local_model_args: dict):
|
|
@@ -162,7 +162,7 @@ class AutoARIMAModel(AbstractProbabilisticStatsForecastModel):
|
|
|
162
162
|
local_model_args.setdefault("allowmean", True)
|
|
163
163
|
return local_model_args
|
|
164
164
|
|
|
165
|
-
def _get_model_type(self, variant:
|
|
165
|
+
def _get_model_type(self, variant: str | None = None):
|
|
166
166
|
from statsforecast.models import AutoARIMA
|
|
167
167
|
|
|
168
168
|
return AutoARIMA
|
|
@@ -232,7 +232,7 @@ class ARIMAModel(AbstractProbabilisticStatsForecastModel):
|
|
|
232
232
|
local_model_args.setdefault("order", (1, 1, 1))
|
|
233
233
|
return local_model_args
|
|
234
234
|
|
|
235
|
-
def _get_model_type(self, variant:
|
|
235
|
+
def _get_model_type(self, variant: str | None = None):
|
|
236
236
|
from statsforecast.models import ARIMA
|
|
237
237
|
|
|
238
238
|
return ARIMA
|
|
@@ -269,7 +269,7 @@ class AutoETSModel(AbstractProbabilisticStatsForecastModel):
|
|
|
269
269
|
This significantly speeds up fitting and usually leads to no change in accuracy.
|
|
270
270
|
"""
|
|
271
271
|
|
|
272
|
-
ag_priority =
|
|
272
|
+
ag_priority = 60
|
|
273
273
|
init_time_in_seconds = 0 # C++ models require no compilation
|
|
274
274
|
allowed_local_model_args = [
|
|
275
275
|
"damped",
|
|
@@ -277,7 +277,7 @@ class AutoETSModel(AbstractProbabilisticStatsForecastModel):
|
|
|
277
277
|
"seasonal_period",
|
|
278
278
|
]
|
|
279
279
|
|
|
280
|
-
def _get_model_type(self, variant:
|
|
280
|
+
def _get_model_type(self, variant: str | None = None):
|
|
281
281
|
from statsforecast.models import AutoETS
|
|
282
282
|
|
|
283
283
|
return AutoETS
|
|
@@ -380,7 +380,7 @@ class DynamicOptimizedThetaModel(AbstractProbabilisticStatsForecastModel):
|
|
|
380
380
|
"seasonal_period",
|
|
381
381
|
]
|
|
382
382
|
|
|
383
|
-
def _get_model_type(self, variant:
|
|
383
|
+
def _get_model_type(self, variant: str | None = None):
|
|
384
384
|
from statsforecast.models import DynamicOptimizedTheta
|
|
385
385
|
|
|
386
386
|
return DynamicOptimizedTheta
|
|
@@ -425,7 +425,7 @@ class ThetaModel(AbstractProbabilisticStatsForecastModel):
|
|
|
425
425
|
"seasonal_period",
|
|
426
426
|
]
|
|
427
427
|
|
|
428
|
-
def _get_model_type(self, variant:
|
|
428
|
+
def _get_model_type(self, variant: str | None = None):
|
|
429
429
|
from statsforecast.models import Theta
|
|
430
430
|
|
|
431
431
|
return Theta
|
|
@@ -546,7 +546,7 @@ class AutoCESModel(AbstractProbabilisticStatsForecastModel):
|
|
|
546
546
|
"seasonal_period",
|
|
547
547
|
]
|
|
548
548
|
|
|
549
|
-
def _get_model_type(self, variant:
|
|
549
|
+
def _get_model_type(self, variant: str | None = None):
|
|
550
550
|
from statsforecast.models import AutoCES
|
|
551
551
|
|
|
552
552
|
return AutoCES
|
|
@@ -610,7 +610,7 @@ class ADIDAModel(AbstractStatsForecastIntermittentDemandModel):
|
|
|
610
610
|
|
|
611
611
|
ag_priority = 10
|
|
612
612
|
|
|
613
|
-
def _get_model_type(self, variant:
|
|
613
|
+
def _get_model_type(self, variant: str | None = None):
|
|
614
614
|
from statsforecast.models import ADIDA
|
|
615
615
|
|
|
616
616
|
return ADIDA
|
|
@@ -652,7 +652,7 @@ class CrostonModel(AbstractStatsForecastIntermittentDemandModel):
|
|
|
652
652
|
"variant",
|
|
653
653
|
]
|
|
654
654
|
|
|
655
|
-
def _get_model_type(self, variant:
|
|
655
|
+
def _get_model_type(self, variant: str | None = None):
|
|
656
656
|
from statsforecast.models import CrostonClassic, CrostonOptimized, CrostonSBA
|
|
657
657
|
|
|
658
658
|
model_variants = {
|
|
@@ -702,7 +702,7 @@ class IMAPAModel(AbstractStatsForecastIntermittentDemandModel):
|
|
|
702
702
|
|
|
703
703
|
ag_priority = 10
|
|
704
704
|
|
|
705
|
-
def _get_model_type(self, variant:
|
|
705
|
+
def _get_model_type(self, variant: str | None = None):
|
|
706
706
|
from statsforecast.models import IMAPA
|
|
707
707
|
|
|
708
708
|
return IMAPA
|
|
@@ -726,7 +726,7 @@ class ZeroModel(AbstractStatsForecastIntermittentDemandModel):
|
|
|
726
726
|
|
|
727
727
|
ag_priority = 100
|
|
728
728
|
|
|
729
|
-
def _get_model_type(self, variant:
|
|
729
|
+
def _get_model_type(self, variant: str | None = None):
|
|
730
730
|
# ZeroModel does not depend on a StatsForecast implementation
|
|
731
731
|
raise NotImplementedError
|
|
732
732
|
|
|
@@ -4,7 +4,7 @@ import logging
|
|
|
4
4
|
import math
|
|
5
5
|
import os
|
|
6
6
|
import time
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Type
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
from typing_extensions import Self
|
|
@@ -38,8 +38,8 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
38
38
|
|
|
39
39
|
def __init__(
|
|
40
40
|
self,
|
|
41
|
-
model_base:
|
|
42
|
-
model_base_kwargs:
|
|
41
|
+
model_base: AbstractTimeSeriesModel | Type[AbstractTimeSeriesModel],
|
|
42
|
+
model_base_kwargs: dict[str, Any] | None = None,
|
|
43
43
|
**kwargs,
|
|
44
44
|
):
|
|
45
45
|
if inspect.isclass(model_base) and issubclass(model_base, AbstractTimeSeriesModel):
|
|
@@ -58,8 +58,8 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
58
58
|
self.model_base_type = type(self.model_base)
|
|
59
59
|
self.info_per_val_window = []
|
|
60
60
|
|
|
61
|
-
self.most_recent_model:
|
|
62
|
-
self.most_recent_model_folder:
|
|
61
|
+
self.most_recent_model: AbstractTimeSeriesModel | None = None
|
|
62
|
+
self.most_recent_model_folder: str | None = None
|
|
63
63
|
super().__init__(**kwargs)
|
|
64
64
|
|
|
65
65
|
@property
|
|
@@ -83,19 +83,19 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
83
83
|
def _is_gpu_available(self) -> bool:
|
|
84
84
|
return self._get_model_base()._is_gpu_available()
|
|
85
85
|
|
|
86
|
-
def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str,
|
|
86
|
+
def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
|
|
87
87
|
return self._get_model_base().get_minimum_resources(is_gpu_available)
|
|
88
88
|
|
|
89
89
|
def _fit(
|
|
90
90
|
self,
|
|
91
91
|
train_data: TimeSeriesDataFrame,
|
|
92
|
-
val_data:
|
|
93
|
-
time_limit:
|
|
94
|
-
num_cpus:
|
|
95
|
-
num_gpus:
|
|
92
|
+
val_data: TimeSeriesDataFrame | None = None,
|
|
93
|
+
time_limit: float | None = None,
|
|
94
|
+
num_cpus: int | None = None,
|
|
95
|
+
num_gpus: int | None = None,
|
|
96
96
|
verbosity: int = 2,
|
|
97
|
-
val_splitter:
|
|
98
|
-
refit_every_n_windows:
|
|
97
|
+
val_splitter: AbstractWindowSplitter | None = None,
|
|
98
|
+
refit_every_n_windows: int | None = 1,
|
|
99
99
|
**kwargs,
|
|
100
100
|
):
|
|
101
101
|
# TODO: use incremental training for GluonTS models?
|
|
@@ -109,9 +109,9 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
109
109
|
if refit_every_n_windows is None:
|
|
110
110
|
refit_every_n_windows = val_splitter.num_val_windows + 1 # only fit model for the first window
|
|
111
111
|
|
|
112
|
-
oof_predictions_per_window = []
|
|
112
|
+
oof_predictions_per_window: list[TimeSeriesDataFrame] = []
|
|
113
113
|
global_fit_start_time = time.time()
|
|
114
|
-
model:
|
|
114
|
+
model: AbstractTimeSeriesModel | None = None
|
|
115
115
|
|
|
116
116
|
for window_index, (train_fold, val_fold) in enumerate(val_splitter.split(train_data)):
|
|
117
117
|
logger.debug(f"\tWindow {window_index}")
|
|
@@ -142,6 +142,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
142
142
|
train_data=train_fold,
|
|
143
143
|
val_data=val_fold,
|
|
144
144
|
time_limit=time_left_for_window,
|
|
145
|
+
verbosity=verbosity,
|
|
145
146
|
**kwargs,
|
|
146
147
|
)
|
|
147
148
|
model.fit_time = time.time() - model_fit_start_time
|
|
@@ -182,8 +183,9 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
182
183
|
self.most_recent_model_folder = most_recent_refit_window # type: ignore
|
|
183
184
|
self.predict_time = self.most_recent_model.predict_time
|
|
184
185
|
self.fit_time = time.time() - global_fit_start_time - self.predict_time # type: ignore
|
|
185
|
-
self.
|
|
186
|
-
|
|
186
|
+
self.cache_oof_predictions(oof_predictions_per_window)
|
|
187
|
+
|
|
188
|
+
self.val_score = float(np.mean([info["val_score"] for info in self.info_per_val_window]))
|
|
187
189
|
|
|
188
190
|
def get_info(self) -> dict:
|
|
189
191
|
info = super().get_info()
|
|
@@ -198,7 +200,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
198
200
|
def _predict(
|
|
199
201
|
self,
|
|
200
202
|
data: TimeSeriesDataFrame,
|
|
201
|
-
known_covariates:
|
|
203
|
+
known_covariates: TimeSeriesDataFrame | None = None,
|
|
202
204
|
**kwargs,
|
|
203
205
|
) -> TimeSeriesDataFrame:
|
|
204
206
|
if self.most_recent_model is None:
|
|
@@ -212,12 +214,25 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
212
214
|
store_predict_time: bool = False,
|
|
213
215
|
**predict_kwargs,
|
|
214
216
|
) -> None:
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
217
|
+
if self._oof_predictions is None or self.most_recent_model is None:
|
|
218
|
+
raise ValueError(f"{self.name} must be fit before calling score_and_cache_oof")
|
|
219
|
+
|
|
220
|
+
# Score on val_data using the most recent model
|
|
221
|
+
past_data, known_covariates = val_data.get_model_inputs_for_scoring(
|
|
222
|
+
prediction_length=self.prediction_length, known_covariates_names=self.covariate_metadata.known_covariates
|
|
223
|
+
)
|
|
224
|
+
predict_start_time = time.time()
|
|
225
|
+
val_predictions = self.most_recent_model.predict(
|
|
226
|
+
past_data, known_covariates=known_covariates, **predict_kwargs
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
self._oof_predictions.append(val_predictions)
|
|
230
|
+
|
|
219
231
|
if store_predict_time:
|
|
220
|
-
|
|
232
|
+
self.predict_time = time.time() - predict_start_time
|
|
233
|
+
|
|
234
|
+
if store_val_score:
|
|
235
|
+
self.val_score = self._score_with_predictions(val_data, val_predictions)
|
|
221
236
|
|
|
222
237
|
def _get_search_space(self):
|
|
223
238
|
return self.model_base._get_search_space()
|
|
@@ -234,7 +249,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
|
|
|
234
249
|
train_fn_kwargs["init_params"]["model_base_kwargs"] = self.get_params()
|
|
235
250
|
return train_fn_kwargs
|
|
236
251
|
|
|
237
|
-
def save(self, path:
|
|
252
|
+
def save(self, path: str | None = None, verbose: bool = True) -> str:
|
|
238
253
|
most_recent_model = self.most_recent_model
|
|
239
254
|
self.most_recent_model = None
|
|
240
255
|
save_path = super().save(path, verbose)
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from abc import ABCMeta
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from inspect import isabstract
|
|
4
|
-
from typing import Union
|
|
5
4
|
|
|
6
5
|
|
|
7
6
|
@dataclass
|
|
@@ -44,7 +43,7 @@ class ModelRegistry(ABCMeta):
|
|
|
44
43
|
cls.REGISTRY[alias] = record
|
|
45
44
|
|
|
46
45
|
@classmethod
|
|
47
|
-
def _get_model_record(cls, alias:
|
|
46
|
+
def _get_model_record(cls, alias: str | type) -> ModelRecord:
|
|
48
47
|
if isinstance(alias, type):
|
|
49
48
|
alias = alias.__name__
|
|
50
49
|
alias = alias.removesuffix("Model")
|
|
@@ -53,11 +52,11 @@ class ModelRegistry(ABCMeta):
|
|
|
53
52
|
return cls.REGISTRY[alias]
|
|
54
53
|
|
|
55
54
|
@classmethod
|
|
56
|
-
def get_model_class(cls, alias:
|
|
55
|
+
def get_model_class(cls, alias: str | type) -> type:
|
|
57
56
|
return cls._get_model_record(alias).model_class
|
|
58
57
|
|
|
59
58
|
@classmethod
|
|
60
|
-
def get_model_priority(cls, alias:
|
|
59
|
+
def get_model_priority(cls, alias: str | type) -> int:
|
|
61
60
|
return cls._get_model_record(alias).ag_priority
|
|
62
61
|
|
|
63
62
|
@classmethod
|
|
@@ -5,7 +5,6 @@
|
|
|
5
5
|
|
|
6
6
|
import logging
|
|
7
7
|
from enum import Enum
|
|
8
|
-
from typing import Optional, Union
|
|
9
8
|
|
|
10
9
|
import torch
|
|
11
10
|
from einops import rearrange
|
|
@@ -27,7 +26,7 @@ class BaseMultiheadAttention(torch.nn.Module):
|
|
|
27
26
|
embed_dim: int,
|
|
28
27
|
num_heads: int,
|
|
29
28
|
dropout: float,
|
|
30
|
-
rotary_emb:
|
|
29
|
+
rotary_emb: TimeAwareRotaryEmbedding | None,
|
|
31
30
|
use_memory_efficient_attention: bool,
|
|
32
31
|
):
|
|
33
32
|
super().__init__()
|
|
@@ -151,7 +150,7 @@ class BaseMultiheadAttention(torch.nn.Module):
|
|
|
151
150
|
self,
|
|
152
151
|
layer_idx: int,
|
|
153
152
|
inputs: torch.Tensor,
|
|
154
|
-
attention_mask:
|
|
153
|
+
attention_mask: torch.Tensor | None = None,
|
|
155
154
|
kv_cache=None,
|
|
156
155
|
) -> torch.Tensor:
|
|
157
156
|
batch_size, variate, seq_len, _ = inputs.shape
|
|
@@ -194,4 +193,4 @@ class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
|
|
|
194
193
|
attention_axis = AttentionAxis.SPACE
|
|
195
194
|
|
|
196
195
|
|
|
197
|
-
MultiHeadAttention =
|
|
196
|
+
MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# Copyright 2025 Datadog, Inc.
|
|
5
5
|
|
|
6
6
|
import math
|
|
7
|
-
from typing import NamedTuple
|
|
7
|
+
from typing import NamedTuple
|
|
8
8
|
|
|
9
9
|
import torch
|
|
10
10
|
|
|
@@ -131,7 +131,7 @@ class TotoBackbone(torch.nn.Module):
|
|
|
131
131
|
scaler_cls: str,
|
|
132
132
|
output_distribution_classes: list[str],
|
|
133
133
|
spacewise_first: bool = True,
|
|
134
|
-
output_distribution_kwargs:
|
|
134
|
+
output_distribution_kwargs: dict | None = None,
|
|
135
135
|
use_memory_efficient_attention: bool = True,
|
|
136
136
|
stabilize_with_global: bool = True,
|
|
137
137
|
scale_factor_exponent: float = 10.0,
|
|
@@ -192,8 +192,8 @@ class TotoBackbone(torch.nn.Module):
|
|
|
192
192
|
inputs: torch.Tensor,
|
|
193
193
|
input_padding_mask: torch.Tensor,
|
|
194
194
|
id_mask: torch.Tensor,
|
|
195
|
-
kv_cache:
|
|
196
|
-
scaling_prefix_length:
|
|
195
|
+
kv_cache: KVCache | None = None,
|
|
196
|
+
scaling_prefix_length: int | None = None,
|
|
197
197
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
198
198
|
scaled_inputs: torch.Tensor
|
|
199
199
|
loc: torch.Tensor
|
|
@@ -244,8 +244,8 @@ class TotoBackbone(torch.nn.Module):
|
|
|
244
244
|
inputs: torch.Tensor,
|
|
245
245
|
input_padding_mask: torch.Tensor,
|
|
246
246
|
id_mask: torch.Tensor,
|
|
247
|
-
kv_cache:
|
|
248
|
-
scaling_prefix_length:
|
|
247
|
+
kv_cache: KVCache | None = None,
|
|
248
|
+
scaling_prefix_length: int | None = None,
|
|
249
249
|
) -> TotoOutput:
|
|
250
250
|
flattened, loc, scale = self.backbone(
|
|
251
251
|
inputs,
|
|
@@ -3,16 +3,11 @@
|
|
|
3
3
|
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
4
|
# Copyright 2025 Datadog, Inc.
|
|
5
5
|
|
|
6
|
-
from typing import Optional
|
|
7
6
|
|
|
8
7
|
import torch
|
|
9
8
|
from einops import rearrange
|
|
10
|
-
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
|
11
|
-
from rotary_embedding_torch.rotary_embedding_torch import default
|
|
12
9
|
|
|
13
|
-
|
|
14
|
-
def exists(val):
|
|
15
|
-
return val is not None
|
|
10
|
+
from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb, default
|
|
16
11
|
|
|
17
12
|
|
|
18
13
|
class TimeAwareRotaryEmbedding(RotaryEmbedding):
|
|
@@ -41,8 +36,8 @@ class TimeAwareRotaryEmbedding(RotaryEmbedding):
|
|
|
41
36
|
self,
|
|
42
37
|
q: torch.Tensor,
|
|
43
38
|
k: torch.Tensor,
|
|
44
|
-
seq_dim:
|
|
45
|
-
seq_pos:
|
|
39
|
+
seq_dim: int | None = None,
|
|
40
|
+
seq_pos: torch.Tensor | None = None,
|
|
46
41
|
seq_pos_offset: int = 0,
|
|
47
42
|
):
|
|
48
43
|
"""
|
|
@@ -78,7 +73,7 @@ class TimeAwareRotaryEmbedding(RotaryEmbedding):
|
|
|
78
73
|
|
|
79
74
|
return rotated_q, rotated_k
|
|
80
75
|
|
|
81
|
-
def get_scale(self, t: torch.Tensor, seq_len:
|
|
76
|
+
def get_scale(self, t: torch.Tensor, seq_len: int | None = None, offset=0):
|
|
82
77
|
"""
|
|
83
78
|
This method is adapted closely from the base class, but it knows how to handle
|
|
84
79
|
when `t` has more than 1 dim (as is the case when we're using time-aware RoPE, and have a different
|