autogluon.timeseries 1.2.1b20250224__py3-none-any.whl → 1.4.1b20251215__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (108)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +106 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +256 -141
  5. autogluon/timeseries/learner.py +86 -52
  6. autogluon/timeseries/metrics/__init__.py +42 -8
  7. autogluon/timeseries/metrics/abstract.py +89 -19
  8. autogluon/timeseries/metrics/point.py +142 -53
  9. autogluon/timeseries/metrics/quantile.py +46 -21
  10. autogluon/timeseries/metrics/utils.py +4 -4
  11. autogluon/timeseries/models/__init__.py +8 -2
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +361 -592
  14. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +282 -194
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +25 -18
  20. autogluon/timeseries/models/chronos/__init__.py +2 -1
  21. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  22. autogluon/timeseries/models/chronos/model.py +219 -138
  23. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +81 -50
  24. autogluon/timeseries/models/ensemble/__init__.py +37 -2
  25. autogluon/timeseries/models/ensemble/abstract.py +107 -0
  26. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  27. autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
  28. autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  34. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  35. autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
  36. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  37. autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
  38. autogluon/timeseries/models/ensemble/weighted/basic.py +91 -0
  39. autogluon/timeseries/models/ensemble/weighted/greedy.py +62 -0
  40. autogluon/timeseries/models/gluonts/__init__.py +1 -1
  41. autogluon/timeseries/models/gluonts/{abstract_gluonts.py → abstract.py} +148 -208
  42. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  43. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +38 -22
  44. autogluon/timeseries/models/local/__init__.py +0 -7
  45. autogluon/timeseries/models/local/abstract_local_model.py +71 -74
  46. autogluon/timeseries/models/local/naive.py +13 -9
  47. autogluon/timeseries/models/local/npts.py +9 -2
  48. autogluon/timeseries/models/local/statsforecast.py +52 -36
  49. autogluon/timeseries/models/multi_window/multi_window_model.py +65 -45
  50. autogluon/timeseries/models/registry.py +64 -0
  51. autogluon/timeseries/models/toto/__init__.py +3 -0
  52. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  62. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  63. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  64. autogluon/timeseries/models/toto/dataloader.py +108 -0
  65. autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
  66. autogluon/timeseries/models/toto/model.py +249 -0
  67. autogluon/timeseries/predictor.py +685 -297
  68. autogluon/timeseries/regressor.py +94 -44
  69. autogluon/timeseries/splitter.py +8 -32
  70. autogluon/timeseries/trainer/__init__.py +3 -0
  71. autogluon/timeseries/trainer/ensemble_composer.py +444 -0
  72. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  73. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  74. autogluon/timeseries/{trainer.py → trainer/trainer.py} +387 -390
  75. autogluon/timeseries/trainer/utils.py +17 -0
  76. autogluon/timeseries/transforms/__init__.py +2 -13
  77. autogluon/timeseries/transforms/covariate_scaler.py +34 -40
  78. autogluon/timeseries/transforms/target_scaler.py +37 -20
  79. autogluon/timeseries/utils/constants.py +10 -0
  80. autogluon/timeseries/utils/datetime/lags.py +3 -5
  81. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  82. autogluon/timeseries/utils/datetime/time_features.py +2 -2
  83. autogluon/timeseries/utils/features.py +70 -47
  84. autogluon/timeseries/utils/forecast.py +19 -14
  85. autogluon/timeseries/utils/timer.py +173 -0
  86. autogluon/timeseries/utils/warning_filters.py +4 -2
  87. autogluon/timeseries/version.py +1 -1
  88. autogluon.timeseries-1.4.1b20251215-py3.11-nspkg.pth +1 -0
  89. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/METADATA +49 -36
  90. autogluon_timeseries-1.4.1b20251215.dist-info/RECORD +103 -0
  91. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/WHEEL +1 -1
  92. autogluon/timeseries/configs/presets_configs.py +0 -79
  93. autogluon/timeseries/evaluator.py +0 -6
  94. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -11
  95. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  96. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -585
  97. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -518
  98. autogluon/timeseries/models/ensemble/abstract_timeseries_ensemble.py +0 -78
  99. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  100. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  101. autogluon/timeseries/models/presets.py +0 -360
  102. autogluon.timeseries-1.2.1b20250224-py3.9-nspkg.pth +0 -1
  103. autogluon.timeseries-1.2.1b20250224.dist-info/RECORD +0 -68
  104. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.2.1b20250224.dist-info → autogluon_timeseries-1.4.1b20251215.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/multi_window/multi_window_model.py
@@ -4,12 +4,13 @@ import logging
 import math
 import os
 import time
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Type

 import numpy as np
+from typing_extensions import Self

 import autogluon.core as ag
-from autogluon.timeseries.dataset.ts_dataframe import TimeSeriesDataFrame
+from autogluon.timeseries.dataset import TimeSeriesDataFrame
 from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
 from autogluon.timeseries.models.local.abstract_local_model import AbstractLocalModel
 from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
@@ -25,10 +26,10 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):

     Parameters
     ----------
-    model_base : Union[AbstractTimeSeriesModel, Type[AbstractTimeSeriesModel]]
+    model_base
         The base model to repeatedly train. If a AbstractTimeSeriesModel class, then also provide model_base_kwargs
         which will be used to initialize the model via model_base(**model_base_kwargs).
-    model_base_kwargs : Optional[Dict[str, any]], default = None
+    model_base_kwargs
         kwargs used to initialize model_base if model_base is a class.
     """

@@ -37,8 +38,8 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):

     def __init__(
         self,
-        model_base: Union[AbstractTimeSeriesModel, Type[AbstractTimeSeriesModel]],
-        model_base_kwargs: Optional[Dict[str, Any]] = None,
+        model_base: AbstractTimeSeriesModel | Type[AbstractTimeSeriesModel],
+        model_base_kwargs: dict[str, Any] | None = None,
         **kwargs,
     ):
         if inspect.isclass(model_base) and issubclass(model_base, AbstractTimeSeriesModel):
@@ -57,8 +58,8 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         self.model_base_type = type(self.model_base)
         self.info_per_val_window = []

-        self.most_recent_model: Optional[AbstractTimeSeriesModel] = None
-        self.most_recent_model_folder: Optional[str] = None
+        self.most_recent_model: AbstractTimeSeriesModel | None = None
+        self.most_recent_model_folder: str | None = None
         super().__init__(**kwargs)

     @property
@@ -73,10 +74,6 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
     def supports_past_covariates(self) -> bool:
         return self.model_base.supports_past_covariates

-    @property
-    def supports_cat_covariates(self) -> bool:
-        return self.model_base.supports_cat_covariates
-
     def _get_model_base(self):
         return self.model_base

@@ -86,16 +83,19 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
     def _is_gpu_available(self) -> bool:
         return self._get_model_base()._is_gpu_available()

-    def get_minimum_resources(self, is_gpu_available: bool = False) -> bool:
+    def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
         return self._get_model_base().get_minimum_resources(is_gpu_available)

     def _fit(
         self,
         train_data: TimeSeriesDataFrame,
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        time_limit: Optional[int] = None,
-        val_splitter: AbstractWindowSplitter = None,
-        refit_every_n_windows: Optional[int] = 1,
+        val_data: TimeSeriesDataFrame | None = None,
+        time_limit: float | None = None,
+        num_cpus: int | None = None,
+        num_gpus: int | None = None,
+        verbosity: int = 2,
+        val_splitter: AbstractWindowSplitter | None = None,
+        refit_every_n_windows: int | None = 1,
         **kwargs,
     ):
         # TODO: use incremental training for GluonTS models?
@@ -109,13 +109,17 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         if refit_every_n_windows is None:
             refit_every_n_windows = val_splitter.num_val_windows + 1  # only fit model for the first window

-        oof_predictions_per_window = []
+        oof_predictions_per_window: list[TimeSeriesDataFrame] = []
         global_fit_start_time = time.time()
+        model: AbstractTimeSeriesModel | None = None

         for window_index, (train_fold, val_fold) in enumerate(val_splitter.split(train_data)):
             logger.debug(f"\tWindow {window_index}")
+
             # refit_this_window is always True for the 0th window
             refit_this_window = window_index % refit_every_n_windows == 0
+            assert window_index != 0 or refit_this_window
+
             if time_limit is None:
                 time_left_for_window = None
             else:
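
The refit cadence above is easy to misread, so here is a minimal standalone sketch of which windows trigger a refit (the splitter loop is simulated; num_val_windows and refit_every_n_windows mirror the names in the diff):

# Minimal sketch of the refit cadence in _fit. Assumption: 5 validation
# windows. Window 0 always refits; refit_every_n_windows=None is mapped to
# num_val_windows + 1, so only window 0 triggers a fit in that case.
num_val_windows = 5

def refit_windows(refit_every_n_windows):
    if refit_every_n_windows is None:
        refit_every_n_windows = num_val_windows + 1  # only fit for the first window
    return [w for w in range(num_val_windows) if w % refit_every_n_windows == 0]

print(refit_windows(2))     # [0, 2, 4] -- windows 1 and 3 reuse the previous fit
print(refit_windows(None))  # [0]      -- train once, evaluate on every window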
@@ -138,6 +142,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
                     train_data=train_fold,
                     val_data=val_fold,
                     time_limit=time_left_for_window,
+                    verbosity=verbosity,
                     **kwargs,
                 )
                 model.fit_time = time.time() - model_fit_start_time
@@ -148,6 +153,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
             else:
                 time_left_for_prediction = time_limit - (time.time() - global_fit_start_time)

+            assert model is not None
             model.score_and_cache_oof(
                 val_fold, store_val_score=True, store_predict_time=True, time_limit=time_left_for_prediction
             )
@@ -172,11 +178,14 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):

         # Only the model trained on most recent data is saved & used for prediction
         self.most_recent_model = model
-        self.most_recent_model_folder = most_recent_refit_window
+        assert self.most_recent_model is not None
+
+        self.most_recent_model_folder = most_recent_refit_window  # type: ignore
         self.predict_time = self.most_recent_model.predict_time
-        self.fit_time = time.time() - global_fit_start_time - self.predict_time
-        self._oof_predictions = oof_predictions_per_window
-        self.val_score = np.mean([info["val_score"] for info in self.info_per_val_window])
+        self.fit_time = time.time() - global_fit_start_time - self.predict_time  # type: ignore
+        self.cache_oof_predictions(oof_predictions_per_window)
+
+        self.val_score = float(np.mean([info["val_score"] for info in self.info_per_val_window]))

     def get_info(self) -> dict:
         info = super().get_info()
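
The bookkeeping at the end of _fit charges prediction time separately from fit time and averages the per-window validation scores. A worked example with made-up numbers:

import numpy as np

# Made-up numbers: 10.0 s elapsed since global_fit_start_time, of which the
# final OOF prediction took 1.5 s; fit_time is everything except prediction.
total_elapsed, predict_time = 10.0, 1.5
fit_time = total_elapsed - predict_time  # 8.5

# val_score is the plain mean of the per-window scores collected in the loop.
info_per_val_window = [{"val_score": -0.21}, {"val_score": -0.19}, {"val_score": -0.23}]
val_score = float(np.mean([info["val_score"] for info in info_per_val_window]))
print(fit_time, round(val_score, 2))  # 8.5 -0.21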
@@ -191,7 +200,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
     def _predict(
         self,
         data: TimeSeriesDataFrame,
-        known_covariates: Optional[TimeSeriesDataFrame] = None,
+        known_covariates: TimeSeriesDataFrame | None = None,
         **kwargs,
     ) -> TimeSeriesDataFrame:
         if self.most_recent_model is None:
@@ -205,27 +214,34 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         store_predict_time: bool = False,
         **predict_kwargs,
     ) -> None:
-        # self.val_score, self.predict_time, self._oof_predictions already saved during _fit()
-        assert self._oof_predictions is not None
-        if store_val_score:
-            assert self.val_score is not None
+        if self._oof_predictions is None or self.most_recent_model is None:
+            raise ValueError(f"{self.name} must be fit before calling score_and_cache_oof")
+
+        # Score on val_data using the most recent model
+        past_data, known_covariates = val_data.get_model_inputs_for_scoring(
+            prediction_length=self.prediction_length, known_covariates_names=self.covariate_metadata.known_covariates
+        )
+        predict_start_time = time.time()
+        val_predictions = self.most_recent_model.predict(
+            past_data, known_covariates=known_covariates, **predict_kwargs
+        )
+
+        self._oof_predictions.append(val_predictions)
+
         if store_predict_time:
-            assert self.predict_time is not None
+            self.predict_time = time.time() - predict_start_time

-    def get_user_params(self) -> dict:
-        return self.model_base.get_user_params()
+        if store_val_score:
+            self.val_score = self._score_with_predictions(val_data, val_predictions)

     def _get_search_space(self):
         return self.model_base._get_search_space()

-    def _initialize_covariate_regressor_scaler(self, **kwargs) -> None:
+    def _initialize_transforms_and_regressor(self) -> None:
         # Do not initialize the target_scaler and covariate_regressor in the multi window model!
-        pass
-
-    def initialize(self, **kwargs) -> dict:
-        super().initialize(**kwargs)
-        self.model_base.initialize(**kwargs)
-        return kwargs
+        self.target_scaler = None
+        self.covariate_scaler = None
+        self.covariate_regressor = None

     def _get_hpo_train_fn_kwargs(self, **train_fn_kwargs) -> dict:
         train_fn_kwargs["is_bagged_model"] = True
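
The rewritten score_and_cache_oof now genuinely re-scores: get_model_inputs_for_scoring splits val_data into past context and a held-out horizon, the most recent model predicts that horizon, and predict_time and val_score are derived from the prediction. A hand-rolled approximation of the split (the real method lives on TimeSeriesDataFrame and also extracts future known covariates):

import pandas as pd

# Hand-rolled illustration of the hold-out split used for scoring.
# Assumption: one item and no covariate columns, for brevity.
prediction_length = 2
idx = pd.MultiIndex.from_product(
    [["item_1"], pd.date_range("2024-01-01", periods=6)], names=["item_id", "timestamp"]
)
val_data = pd.DataFrame({"target": range(6)}, index=idx)

grouped = val_data.groupby(level="item_id", group_keys=False)
past_data = grouped.apply(lambda df: df.iloc[:-prediction_length])  # context the model sees
held_out = grouped.apply(lambda df: df.iloc[-prediction_length:])   # scored against predictions
print(len(past_data), len(held_out))  # 4 2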
@@ -233,7 +249,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         train_fn_kwargs["init_params"]["model_base_kwargs"] = self.get_params()
         return train_fn_kwargs

-    def save(self, path: str = None, verbose=True) -> str:
+    def save(self, path: str | None = None, verbose: bool = True) -> str:
         most_recent_model = self.most_recent_model
         self.most_recent_model = None
         save_path = super().save(path, verbose)
@@ -244,32 +260,36 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         most_recent_model.save()
         return save_path

-    def persist(self):
+    def persist(self) -> Self:
         if self.most_recent_model is None:
             raise ValueError(f"{self.name} must be fit before persisting")
         self.most_recent_model.persist()
+        return self

     @classmethod
     def load(
         cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
     ) -> AbstractTimeSeriesModel:
         model = super().load(path=path, reset_paths=reset_paths, load_oof=load_oof, verbose=verbose)
-        most_recent_model_path = os.path.join(model.path, model.most_recent_model_folder)
-        model.most_recent_model = model.model_base_type.load(
-            most_recent_model_path,
-            reset_paths=reset_paths,
-            verbose=verbose,
-        )
+        if model.most_recent_model_folder is not None:
+            most_recent_model_path = os.path.join(model.path, model.most_recent_model_folder)
+            model.most_recent_model = model.model_base_type.load(
+                most_recent_model_path,
+                reset_paths=reset_paths,
+                verbose=verbose,
+            )
         return model

     def convert_to_refit_full_template(self) -> AbstractTimeSeriesModel:
         # refit_model is an instance of base model type, not MultiWindowBacktestingModel
+        assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
         refit_model = self.most_recent_model.convert_to_refit_full_template()
         refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
         return refit_model

     def convert_to_refit_full_via_copy(self) -> AbstractTimeSeriesModel:
         # refit_model is an instance of base model type, not MultiWindowBacktestingModel
+        assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
         refit_model = self.most_recent_model.convert_to_refit_full_via_copy()
         refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
         return refit_model
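
The save/load changes keep the wrapper's "detach before pickling" pattern: the sub-model is removed before super().save() pickles the wrapper, saved into its own folder, and re-attached on load (now guarded by the None check). A minimal sketch of the pattern, with hypothetical Wrapper/Sub stand-ins for the real model classes:

import os
import pickle
import tempfile

# Hypothetical stand-ins for MultiWindowBacktestingModel and its base model.
class Sub:
    def save(self, path: str) -> None:
        with open(os.path.join(path, "sub.pkl"), "wb") as f:
            pickle.dump(self, f)

class Wrapper:
    def __init__(self):
        self.most_recent_model = Sub()
        self.most_recent_model_folder = "W0"

    def save(self, path: str) -> None:
        sub, self.most_recent_model = self.most_recent_model, None  # detach
        with open(os.path.join(path, "wrapper.pkl"), "wb") as f:
            pickle.dump(self, f)  # wrapper is pickled without the sub-model
        self.most_recent_model = sub  # re-attach in memory
        subdir = os.path.join(path, self.most_recent_model_folder)
        os.makedirs(subdir, exist_ok=True)
        sub.save(subdir)  # sub-model gets its own folder, mirroring load()

with tempfile.TemporaryDirectory() as d:
    Wrapper().save(d)
    print(sorted(os.listdir(d)))  # ['W0', 'wrapper.pkl']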
autogluon/timeseries/models/registry.py (new file)
@@ -0,0 +1,64 @@
+from abc import ABCMeta
+from dataclasses import dataclass
+from inspect import isabstract
+
+
+@dataclass
+class ModelRecord:
+    model_class: type
+    ag_priority: int
+
+
+class ModelRegistry(ABCMeta):
+    """Registry metaclass for time series models. Ensures that TimeSeriesModel classes
+    which implement this metaclass are automatically registered, in order to centralize
+    access to model types.
+
+    See, https://github.com/faif/python-patterns.
+    """
+
+    REGISTRY: dict[str, ModelRecord] = {}
+
+    def __new__(cls, name, bases, attrs):
+        new_cls = super().__new__(cls, name, bases, attrs)
+
+        if name is not None and not isabstract(new_cls):
+            record = ModelRecord(
+                model_class=new_cls,
+                ag_priority=getattr(new_cls, "ag_priority", 0),
+            )
+            cls._add(name.removesuffix("Model"), record)
+
+            # if the class provides additional aliases, register them too
+            if aliases := attrs.get("ag_model_aliases"):
+                for alias in aliases:
+                    cls._add(alias, record)
+
+        return new_cls
+
+    @classmethod
+    def _add(cls, alias: str, record: ModelRecord) -> None:
+        if alias in cls.REGISTRY:
+            raise ValueError(f"You are trying to define a new model with {alias}, but this model already exists.")
+        cls.REGISTRY[alias] = record
+
+    @classmethod
+    def _get_model_record(cls, alias: str | type) -> ModelRecord:
+        if isinstance(alias, type):
+            alias = alias.__name__
+        alias = alias.removesuffix("Model")
+        if alias not in cls.REGISTRY:
+            raise ValueError(f"Unknown model: {alias}, available models are: {cls.available_aliases()}")
+        return cls.REGISTRY[alias]
+
+    @classmethod
+    def get_model_class(cls, alias: str | type) -> type:
+        return cls._get_model_record(alias).model_class
+
+    @classmethod
+    def get_model_priority(cls, alias: str | type) -> int:
+        return cls._get_model_record(alias).ag_priority
+
+    @classmethod
+    def available_aliases(cls) -> list[str]:
+        return sorted(cls.REGISTRY.keys())
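
As a usage sketch of the metaclass (MyNaiveModel is a made-up class; in the package, model classes pick this metaclass up via their abstract base): defining a class is enough to register it under its suffix-stripped name plus any aliases, and lookups go through the classmethods.

# Usage sketch; assumes the new registry module above is importable.
from autogluon.timeseries.models.registry import ModelRegistry

class MyNaiveModel(metaclass=ModelRegistry):  # made-up example class
    ag_priority = 50
    ag_model_aliases = ["NaiveBaseline"]  # optional extra aliases

# Defining the class was enough: it is registered under the suffix-stripped
# name "MyNaive" and under each alias.
assert ModelRegistry.get_model_class("MyNaive") is MyNaiveModel
assert ModelRegistry.get_model_class("NaiveBaseline") is MyNaiveModel
assert ModelRegistry.get_model_priority("MyNaive") == 50
print(ModelRegistry.available_aliases())  # includes "MyNaive" and "NaiveBaseline"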
autogluon/timeseries/models/toto/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .model import TotoModel
+
+__all__ = ["TotoModel"]
autogluon/timeseries/models/toto/_internal/__init__.py (new file)
@@ -0,0 +1,9 @@
+from .backbone import TotoBackbone
+from .dataset import MaskedTimeseries
+from .forecaster import TotoForecaster
+
+__all__ = [
+    "MaskedTimeseries",
+    "TotoBackbone",
+    "TotoForecaster",
+]
autogluon/timeseries/models/toto/_internal/backbone/__init__.py (new file)
@@ -0,0 +1,3 @@
+from .backbone import TotoBackbone
+
+__all__ = ["TotoBackbone"]
autogluon/timeseries/models/toto/_internal/backbone/attention.py (new file)
@@ -0,0 +1,196 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+import logging
+from enum import Enum
+
+import torch
+from einops import rearrange
+from torch.nn.functional import scaled_dot_product_attention
+
+from .rope import TimeAwareRotaryEmbedding
+
+log = logging.getLogger(__name__)
+
+
+class AttentionAxis(Enum):
+    TIME = 1
+    SPACE = 2
+
+
+class BaseMultiheadAttention(torch.nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float,
+        rotary_emb: TimeAwareRotaryEmbedding | None,
+        use_memory_efficient_attention: bool,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads."
+        self.head_dim = embed_dim // num_heads
+        self.rotary_emb = rotary_emb
+
+        # We allocate a single tensor for the q, k, and v projection matrices,
+        # multiply them with the inputs, and then split the projected tensors into q, k, and v using unbind.
+        # This reduces overhead a bit vs. having multiple separate Linear layers,
+        # which need to be initialized, tracked by the optimizer, etc.
+        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
+        self.dropout = dropout
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.wO = torch.nn.Linear(embed_dim, embed_dim)
+
+        assert not self.use_memory_efficient_attention, (
+            "xformers is not available, so use_memory_efficient_attention must be False"
+        )
+
+        if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE):
+            raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.")
+
+    def rearrange_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
+        pattern = (
+            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
+            if self.attention_axis == AttentionAxis.TIME
+            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
+        )
+
+        return rearrange(inputs, pattern)
+
+    def get_qkv(
+        self,
+        inputs: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        pattern: str = ""
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"
+
+        assert pattern
+        qkv = self.wQKV(inputs.contiguous())
+        return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0)
+
+    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
+        # Apply the rotary embeddings
+        seq_pos_offset = 0
+        if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME:
+            if kv_cache is not None:
+                seq_pos_offset = kv_cache.seq_len(layer_idx)
+
+            # We need to permute because rotary embeddings expect the sequence dimension to be the second-to-last dimension
+            q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset)
+
+        if kv_cache is not None and self.attention_axis == AttentionAxis.TIME:
+            # First, we append the current input key and value tensors to the cache.
+            # This concatenates the current key and value tensors to the existing key and value tensors
+            kv_cache.append(layer_idx, (k, v))
+            # Then, we retrieve the key and value tensors from the cache.
+            # This includes all the key and value tensors from previous time steps
+            # as well as the current time step.
+            k, v = kv_cache[layer_idx]
+
+        q = q.contiguous()
+        k = k.contiguous().to(q.dtype)  # Ensure k is the same dtype as q; this is necessary when using mixed precision
+        v = v.contiguous().to(q.dtype)  # Ensure v is the same dtype as q; this is necessary when using mixed precision
+
+        return q, k, v, seq_pos_offset
+
+    def rearrange_output(self, output: torch.Tensor, batch: int, variate: int, seq_len: int) -> torch.Tensor:
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"
+
+        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)  # type: ignore
+
+    def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate):
+        # Determine dimension ranges for attention
+        # Ensure the last query vector index is used from the cache
+        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
+        kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
+        if self.attention_axis == AttentionAxis.TIME:
+            attention_mask = (
+                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                is_causal=(attention_mask is None and seq_pos_offset == 0),
+            )
+        elif self.attention_axis == AttentionAxis.SPACE:
+            # We don't use causal masking for space-wise attention
+            attention_mask = (
+                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False)
+        else:
+            raise ValueError("Invalid attention axis")
+
+    def forward(
+        self,
+        layer_idx: int,
+        inputs: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        kv_cache=None,
+    ) -> torch.Tensor:
+        batch_size, variate, seq_len, _ = inputs.shape
+        dropout = self.dropout if self.training else 0.0
+
+        rearranged_inputs = self.rearrange_inputs(inputs)
+        q, k, v = self.get_qkv(rearranged_inputs)
+
+        q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx)
+
+        output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate)
+
+        output = self.rearrange_output(output, batch_size, variate, seq_len)
+        return self.wO(output)
+
+
+class TimeWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes standard multihead causal attention over the time axis.
+    It does this by flattening out the variates along the batch dimension.
+    It also applies rotary position embeddings to the query and key matrices
+    in order to incorporate relative positional information.
+    """
+
+    attention_axis = AttentionAxis.TIME
+
+
+class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes bidirectional multihead attention over the space axis (i.e. across variates within
+    a multi-variate time series). This is done by flattening out the time axis along the batch dimension.
+    This allows the model to attend to different variates at the same time point. By alternating
+    between time-wise and space-wise attention, the model can learn both temporal and cross-variate
+    dependencies in the data.
+
+    Unlike with time-wise attention, we don't apply rotary embeddings here
+    because we want cross-variate attention to be invariant to the order of the variates.
+    """
+
+    attention_axis = AttentionAxis.SPACE
+
+
+MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
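
To make the two flattening strategies concrete, a small self-contained demo of the rearrange patterns used above (illustrative shapes only, not part of the package):

import torch
from einops import rearrange

# Illustrative shapes: batch=2, variate=3, seq_len=5, embed_dim=8.
x = torch.randn(2, 3, 5, 8)

# Time-wise attention folds variates into the batch dimension, so attention
# runs over seq_len independently for each (batch, variate) pair:
time_view = rearrange(x, "batch variate seq_len d -> (batch variate) seq_len d")
print(time_view.shape)  # torch.Size([6, 5, 8])

# Space-wise attention folds time into the batch dimension, so attention
# runs across variates at each time step:
space_view = rearrange(x, "batch variate seq_len d -> (batch seq_len) variate d")
print(space_view.shape)  # torch.Size([10, 3, 8])

# The fused projection from get_qkv: one Linear emits q, k, and v, which are
# split apart with a single rearrange + unbind (here: 2 heads of size 4).
wqkv = torch.nn.Linear(8, 8 * 3)
q, k, v = rearrange(
    wqkv(time_view), "b s (qkv hd h) -> qkv b h s hd", qkv=3, hd=4, h=2
).unbind(dim=0)
print(q.shape)  # torch.Size([6, 2, 5, 4])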