autogluon.timeseries 1.0.1b20240304__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.

Files changed (108)
  1. autogluon/timeseries/configs/__init__.py +3 -2
  2. autogluon/timeseries/configs/hyperparameter_presets.py +62 -0
  3. autogluon/timeseries/configs/predictor_presets.py +84 -0
  4. autogluon/timeseries/dataset/ts_dataframe.py +339 -186
  5. autogluon/timeseries/learner.py +192 -60
  6. autogluon/timeseries/metrics/__init__.py +55 -11
  7. autogluon/timeseries/metrics/abstract.py +96 -25
  8. autogluon/timeseries/metrics/point.py +186 -39
  9. autogluon/timeseries/metrics/quantile.py +47 -20
  10. autogluon/timeseries/metrics/utils.py +6 -6
  11. autogluon/timeseries/models/__init__.py +13 -7
  12. autogluon/timeseries/models/abstract/__init__.py +2 -2
  13. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +533 -273
  14. autogluon/timeseries/models/abstract/model_trial.py +10 -10
  15. autogluon/timeseries/models/abstract/tunable.py +189 -0
  16. autogluon/timeseries/models/autogluon_tabular/__init__.py +2 -0
  17. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +369 -215
  18. autogluon/timeseries/models/autogluon_tabular/per_step.py +513 -0
  19. autogluon/timeseries/models/autogluon_tabular/transforms.py +67 -0
  20. autogluon/timeseries/models/autogluon_tabular/utils.py +3 -51
  21. autogluon/timeseries/models/chronos/__init__.py +4 -0
  22. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  23. autogluon/timeseries/models/chronos/model.py +738 -0
  24. autogluon/timeseries/models/chronos/utils.py +369 -0
  25. autogluon/timeseries/models/ensemble/__init__.py +35 -2
  26. autogluon/timeseries/models/ensemble/{abstract_timeseries_ensemble.py → abstract.py} +50 -26
  27. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  28. autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
  29. autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  31. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  32. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
  33. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  34. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  35. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  36. autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
  37. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  38. autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
  39. autogluon/timeseries/models/ensemble/weighted/basic.py +78 -0
  40. autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
  41. autogluon/timeseries/models/gluonts/__init__.py +3 -1
  42. autogluon/timeseries/models/gluonts/abstract.py +583 -0
  43. autogluon/timeseries/models/gluonts/dataset.py +109 -0
  44. autogluon/timeseries/models/gluonts/{torch/models.py → models.py} +185 -44
  45. autogluon/timeseries/models/local/__init__.py +1 -10
  46. autogluon/timeseries/models/local/abstract_local_model.py +150 -97
  47. autogluon/timeseries/models/local/naive.py +31 -23
  48. autogluon/timeseries/models/local/npts.py +6 -2
  49. autogluon/timeseries/models/local/statsforecast.py +99 -112
  50. autogluon/timeseries/models/multi_window/multi_window_model.py +99 -40
  51. autogluon/timeseries/models/registry.py +64 -0
  52. autogluon/timeseries/models/toto/__init__.py +3 -0
  53. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  58. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  59. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  60. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  61. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  62. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  63. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  64. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  65. autogluon/timeseries/models/toto/dataloader.py +108 -0
  66. autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
  67. autogluon/timeseries/models/toto/model.py +236 -0
  68. autogluon/timeseries/predictor.py +826 -305
  69. autogluon/timeseries/regressor.py +253 -0
  70. autogluon/timeseries/splitter.py +10 -31
  71. autogluon/timeseries/trainer/__init__.py +2 -3
  72. autogluon/timeseries/trainer/ensemble_composer.py +439 -0
  73. autogluon/timeseries/trainer/model_set_builder.py +256 -0
  74. autogluon/timeseries/trainer/prediction_cache.py +149 -0
  75. autogluon/timeseries/trainer/trainer.py +1298 -0
  76. autogluon/timeseries/trainer/utils.py +17 -0
  77. autogluon/timeseries/transforms/__init__.py +2 -0
  78. autogluon/timeseries/transforms/covariate_scaler.py +164 -0
  79. autogluon/timeseries/transforms/target_scaler.py +149 -0
  80. autogluon/timeseries/utils/constants.py +10 -0
  81. autogluon/timeseries/utils/datetime/base.py +38 -20
  82. autogluon/timeseries/utils/datetime/lags.py +18 -16
  83. autogluon/timeseries/utils/datetime/seasonality.py +14 -14
  84. autogluon/timeseries/utils/datetime/time_features.py +17 -14
  85. autogluon/timeseries/utils/features.py +317 -53
  86. autogluon/timeseries/utils/forecast.py +31 -17
  87. autogluon/timeseries/utils/timer.py +173 -0
  88. autogluon/timeseries/utils/warning_filters.py +44 -6
  89. autogluon/timeseries/version.py +2 -1
  90. autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
  91. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +71 -47
  92. autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
  93. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
  94. autogluon/timeseries/configs/presets_configs.py +0 -11
  95. autogluon/timeseries/evaluator.py +0 -6
  96. autogluon/timeseries/models/ensemble/greedy_ensemble.py +0 -170
  97. autogluon/timeseries/models/gluonts/abstract_gluonts.py +0 -550
  98. autogluon/timeseries/models/gluonts/torch/__init__.py +0 -0
  99. autogluon/timeseries/models/presets.py +0 -325
  100. autogluon/timeseries/trainer/abstract_trainer.py +0 -1144
  101. autogluon/timeseries/trainer/auto_trainer.py +0 -74
  102. autogluon.timeseries-1.0.1b20240304-py3.8-nspkg.pth +0 -1
  103. autogluon.timeseries-1.0.1b20240304.dist-info/RECORD +0 -58
  104. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
  105. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
  106. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
  107. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
  108. {autogluon.timeseries-1.0.1b20240304.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
--- a/autogluon/timeseries/models/multi_window/multi_window_model.py
+++ b/autogluon/timeseries/models/multi_window/multi_window_model.py
@@ -4,12 +4,13 @@ import logging
 import math
 import os
 import time
-from typing import Dict, Optional, Type, Union
+from typing import Any, Type

 import numpy as np
+from typing_extensions import Self

 import autogluon.core as ag
-from autogluon.timeseries.dataset.ts_dataframe import TimeSeriesDataFrame
+from autogluon.timeseries.dataset import TimeSeriesDataFrame
 from autogluon.timeseries.models.abstract import AbstractTimeSeriesModel
 from autogluon.timeseries.models.local.abstract_local_model import AbstractLocalModel
 from autogluon.timeseries.splitter import AbstractWindowSplitter, ExpandingWindowSplitter
@@ -25,17 +26,20 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):

     Parameters
     ----------
-    model_base : Union[AbstractTimeSeriesModel, Type[AbstractTimeSeriesModel]]
+    model_base
         The base model to repeatedly train. If a AbstractTimeSeriesModel class, then also provide model_base_kwargs
         which will be used to initialize the model via model_base(**model_base_kwargs).
-    model_base_kwargs : Optional[Dict[str, any]], default = None
+    model_base_kwargs
         kwargs used to initialize model_base if model_base is a class.
     """

+    # TODO: Remove the MultiWindowBacktestingModel class, move the logic to TimeSeriesTrainer
+    default_max_time_limit_ratio = 1.0
+
     def __init__(
         self,
-        model_base: Union[AbstractTimeSeriesModel, Type[AbstractTimeSeriesModel]],
-        model_base_kwargs: Optional[Dict[str, any]] = None,
+        model_base: AbstractTimeSeriesModel | Type[AbstractTimeSeriesModel],
+        model_base_kwargs: dict[str, Any] | None = None,
         **kwargs,
     ):
         if inspect.isclass(model_base) and issubclass(model_base, AbstractTimeSeriesModel):
@@ -54,10 +58,22 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         self.model_base_type = type(self.model_base)
         self.info_per_val_window = []

-        self.most_recent_model: AbstractTimeSeriesModel = None
-        self.most_recent_model_folder: Optional[str] = None
+        self.most_recent_model: AbstractTimeSeriesModel | None = None
+        self.most_recent_model_folder: str | None = None
         super().__init__(**kwargs)

+    @property
+    def supports_static_features(self) -> bool:
+        return self.model_base.supports_static_features
+
+    @property
+    def supports_known_covariates(self) -> bool:
+        return self.model_base.supports_known_covariates
+
+    @property
+    def supports_past_covariates(self) -> bool:
+        return self.model_base.supports_past_covariates
+
     def _get_model_base(self):
         return self.model_base

@@ -67,16 +83,19 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
     def _is_gpu_available(self) -> bool:
         return self._get_model_base()._is_gpu_available()

-    def get_minimum_resources(self, is_gpu_available: bool = False) -> bool:
+    def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
         return self._get_model_base().get_minimum_resources(is_gpu_available)

     def _fit(
         self,
         train_data: TimeSeriesDataFrame,
-        val_data: Optional[TimeSeriesDataFrame] = None,
-        time_limit: Optional[int] = None,
-        val_splitter: AbstractWindowSplitter = None,
-        refit_every_n_windows: Optional[int] = 1,
+        val_data: TimeSeriesDataFrame | None = None,
+        time_limit: float | None = None,
+        num_cpus: int | None = None,
+        num_gpus: int | None = None,
+        verbosity: int = 2,
+        val_splitter: AbstractWindowSplitter | None = None,
+        refit_every_n_windows: int | None = 1,
         **kwargs,
     ):
         # TODO: use incremental training for GluonTS models?
@@ -90,13 +109,17 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         if refit_every_n_windows is None:
             refit_every_n_windows = val_splitter.num_val_windows + 1  # only fit model for the first window

-        oof_predictions_per_window = []
+        oof_predictions_per_window: list[TimeSeriesDataFrame] = []
         global_fit_start_time = time.time()
+        model: AbstractTimeSeriesModel | None = None

         for window_index, (train_fold, val_fold) in enumerate(val_splitter.split(train_data)):
             logger.debug(f"\tWindow {window_index}")
+
             # refit_this_window is always True for the 0th window
             refit_this_window = window_index % refit_every_n_windows == 0
+            assert window_index != 0 or refit_this_window
+
             if time_limit is None:
                 time_left_for_window = None
             else:
@@ -110,8 +133,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
                 num_refits_remaining = math.ceil(
                     (val_splitter.num_val_windows - window_index) / refit_every_n_windows
                 )
-                # Reserve 10% of the remaining time for prediction, use 90% of time for training
-                time_left_for_window = 0.9 * time_left / num_refits_remaining
+                time_left_for_window = time_left / num_refits_remaining

             if refit_this_window:
                 model = self.get_child_model(window_index)
@@ -120,11 +142,21 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
                     train_data=train_fold,
                     val_data=val_fold,
                     time_limit=time_left_for_window,
+                    verbosity=verbosity,
                     **kwargs,
                 )
                 model.fit_time = time.time() - model_fit_start_time
                 most_recent_refit_window = f"W{window_index}"
-            model.score_and_cache_oof(val_fold, store_val_score=True, store_predict_time=True)
+
+            if time_limit is None:
+                time_left_for_prediction = None
+            else:
+                time_left_for_prediction = time_limit - (time.time() - global_fit_start_time)
+
+            assert model is not None
+            model.score_and_cache_oof(
+                val_fold, store_val_score=True, store_predict_time=True, time_limit=time_left_for_prediction
+            )

             oof_predictions_per_window.append(model.get_oof_predictions()[0])

@@ -146,11 +178,14 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):

         # Only the model trained on most recent data is saved & used for prediction
         self.most_recent_model = model
-        self.most_recent_model_folder = most_recent_refit_window
+        assert self.most_recent_model is not None
+
+        self.most_recent_model_folder = most_recent_refit_window  # type: ignore
         self.predict_time = self.most_recent_model.predict_time
-        self.fit_time = time.time() - global_fit_start_time - self.predict_time
-        self._oof_predictions = oof_predictions_per_window
-        self.val_score = np.mean([info["val_score"] for info in self.info_per_val_window])
+        self.fit_time = time.time() - global_fit_start_time - self.predict_time  # type: ignore
+        self.cache_oof_predictions(oof_predictions_per_window)
+
+        self.val_score = float(np.mean([info["val_score"] for info in self.info_per_val_window]))

     def get_info(self) -> dict:
         info = super().get_info()
@@ -165,7 +200,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
     def _predict(
         self,
         data: TimeSeriesDataFrame,
-        known_covariates: Optional[TimeSeriesDataFrame] = None,
+        known_covariates: TimeSeriesDataFrame | None = None,
         **kwargs,
     ) -> TimeSeriesDataFrame:
         if self.most_recent_model is None:
@@ -177,23 +212,36 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
         val_data: TimeSeriesDataFrame,
         store_val_score: bool = False,
         store_predict_time: bool = False,
+        **predict_kwargs,
     ) -> None:
-        # self.val_score, self.predict_time, self._oof_predictions already saved during _fit()
-        assert self._oof_predictions is not None
-        if store_val_score:
-            assert self.val_score is not None
+        if self._oof_predictions is None or self.most_recent_model is None:
+            raise ValueError(f"{self.name} must be fit before calling score_and_cache_oof")
+
+        # Score on val_data using the most recent model
+        past_data, known_covariates = val_data.get_model_inputs_for_scoring(
+            prediction_length=self.prediction_length, known_covariates_names=self.covariate_metadata.known_covariates
+        )
+        predict_start_time = time.time()
+        val_predictions = self.most_recent_model.predict(
+            past_data, known_covariates=known_covariates, **predict_kwargs
+        )
+
+        self._oof_predictions.append(val_predictions)
+
         if store_predict_time:
-            assert self.predict_time is not None
+            self.predict_time = time.time() - predict_start_time

-    def get_user_params(self) -> dict:
-        return self.model_base.get_user_params()
+        if store_val_score:
+            self.val_score = self._score_with_predictions(val_data, val_predictions)

     def _get_search_space(self):
         return self.model_base._get_search_space()

-    def _initialize(self, **kwargs) -> None:
-        super()._initialize(**kwargs)
-        self.model_base.initialize(**kwargs)
+    def _initialize_transforms_and_regressor(self) -> None:
+        # Do not initialize the target_scaler and covariate_regressor in the multi window model!
+        self.target_scaler = None
+        self.covariate_scaler = None
+        self.covariate_regressor = None

     def _get_hpo_train_fn_kwargs(self, **train_fn_kwargs) -> dict:
         train_fn_kwargs["is_bagged_model"] = True
@@ -201,7 +249,7 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
             train_fn_kwargs["init_params"]["model_base_kwargs"] = self.get_params()
         return train_fn_kwargs

-    def save(self, path: str = None, verbose=True) -> str:
+    def save(self, path: str | None = None, verbose: bool = True) -> str:
         most_recent_model = self.most_recent_model
         self.most_recent_model = None
         save_path = super().save(path, verbose)
@@ -212,30 +260,41 @@ class MultiWindowBacktestingModel(AbstractTimeSeriesModel):
             most_recent_model.save()
         return save_path

+    def persist(self) -> Self:
+        if self.most_recent_model is None:
+            raise ValueError(f"{self.name} must be fit before persisting")
+        self.most_recent_model.persist()
+        return self
+
     @classmethod
     def load(
         cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
     ) -> AbstractTimeSeriesModel:
         model = super().load(path=path, reset_paths=reset_paths, load_oof=load_oof, verbose=verbose)
-        most_recent_model_path = os.path.join(model.path, model.most_recent_model_folder)
-        model.most_recent_model = model.model_base_type.load(
-            most_recent_model_path,
-            reset_paths=reset_paths,
-            verbose=verbose,
-        )
+        if model.most_recent_model_folder is not None:
+            most_recent_model_path = os.path.join(model.path, model.most_recent_model_folder)
+            model.most_recent_model = model.model_base_type.load(
+                most_recent_model_path,
+                reset_paths=reset_paths,
+                verbose=verbose,
+            )
         return model

     def convert_to_refit_full_template(self) -> AbstractTimeSeriesModel:
         # refit_model is an instance of base model type, not MultiWindowBacktestingModel
+        assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
         refit_model = self.most_recent_model.convert_to_refit_full_template()
         refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
         return refit_model

     def convert_to_refit_full_via_copy(self) -> AbstractTimeSeriesModel:
         # refit_model is an instance of base model type, not MultiWindowBacktestingModel
+        assert self.most_recent_model is not None, "Most recent model is None. Model must be fit first."
         refit_model = self.most_recent_model.convert_to_refit_full_via_copy()
         refit_model.rename(self.name + ag.constants.REFIT_FULL_SUFFIX)
         return refit_model

     def _more_tags(self) -> dict:
-        return self.most_recent_model._get_tags()
+        tags = self.model_base._get_tags()
+        tags["can_use_val_data"] = False
+        return tags
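
Aside: the _fit change above drops the hard-coded 0.9 factor (which reserved 10% of the remaining budget for prediction) and instead forwards an explicit time_limit to score_and_cache_oof. A minimal, self-contained sketch of the resulting refit schedule and per-window budget, with invented numbers (the real method receives these values from the trainer and val_splitter):

import math
import time

# Hypothetical settings: 3 validation windows, refit on every 2nd window.
num_val_windows = 3
refit_every_n_windows = 2
time_limit = 60.0  # total budget in seconds
global_fit_start_time = time.time()

for window_index in range(num_val_windows):
    # refit_this_window is always True for the 0th window
    refit_this_window = window_index % refit_every_n_windows == 0
    if refit_this_window:
        time_left = time_limit - (time.time() - global_fit_start_time)
        num_refits_remaining = math.ceil((num_val_windows - window_index) / refit_every_n_windows)
        # New behavior: split the remaining time evenly across the remaining
        # refits, with no 0.9 safety factor carved out for prediction.
        time_left_for_window = time_left / num_refits_remaining
        print(f"window {window_index}: refit with budget {time_left_for_window:.1f}s")
    else:
        print(f"window {window_index}: reuse model from previous refit")
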
--- /dev/null
+++ b/autogluon/timeseries/models/registry.py
@@ -0,0 +1,64 @@
+from abc import ABCMeta
+from dataclasses import dataclass
+from inspect import isabstract
+
+
+@dataclass
+class ModelRecord:
+    model_class: type
+    ag_priority: int
+
+
+class ModelRegistry(ABCMeta):
+    """Registry metaclass for time series models. Ensures that TimeSeriesModel classes
+    which implement this metaclass are automatically registered, in order to centralize
+    access to model types.
+
+    See, https://github.com/faif/python-patterns.
+    """
+
+    REGISTRY: dict[str, ModelRecord] = {}
+
+    def __new__(cls, name, bases, attrs):
+        new_cls = super().__new__(cls, name, bases, attrs)
+
+        if name is not None and not isabstract(new_cls):
+            record = ModelRecord(
+                model_class=new_cls,
+                ag_priority=getattr(new_cls, "ag_priority", 0),
+            )
+            cls._add(name.removesuffix("Model"), record)
+
+            # if the class provides additional aliases, register them too
+            if aliases := attrs.get("ag_model_aliases"):
+                for alias in aliases:
+                    cls._add(alias, record)
+
+        return new_cls
+
+    @classmethod
+    def _add(cls, alias: str, record: ModelRecord) -> None:
+        if alias in cls.REGISTRY:
+            raise ValueError(f"You are trying to define a new model with {alias}, but this model already exists.")
+        cls.REGISTRY[alias] = record
+
+    @classmethod
+    def _get_model_record(cls, alias: str | type) -> ModelRecord:
+        if isinstance(alias, type):
+            alias = alias.__name__
+        alias = alias.removesuffix("Model")
+        if alias not in cls.REGISTRY:
+            raise ValueError(f"Unknown model: {alias}, available models are: {cls.available_aliases()}")
+        return cls.REGISTRY[alias]
+
+    @classmethod
+    def get_model_class(cls, alias: str | type) -> type:
+        return cls._get_model_record(alias).model_class
+
+    @classmethod
+    def get_model_priority(cls, alias: str | type) -> int:
+        return cls._get_model_record(alias).ag_priority
+
+    @classmethod
+    def available_aliases(cls) -> list[str]:
+        return sorted(cls.REGISTRY.keys())
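
For orientation, a small usage sketch of the registry above; MyNaiveModel, its priority, and its alias are invented for illustration and are not models shipped with the package:

# Assumes ModelRegistry as defined in the diff above is importable.
# Any non-abstract class created with ModelRegistry as its metaclass is
# registered under its class name minus the "Model" suffix, plus any aliases.
class MyNaiveModel(metaclass=ModelRegistry):
    ag_priority = 50
    ag_model_aliases = ["MyNaive2"]

assert ModelRegistry.get_model_class("MyNaive") is MyNaiveModel
assert ModelRegistry.get_model_priority("MyNaive2") == 50
print(ModelRegistry.available_aliases())  # includes "MyNaive" and "MyNaive2"

Lookup by class also works, since _get_model_record falls back to the class name: ModelRegistry.get_model_class(MyNaiveModel) returns the same record.
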
--- /dev/null
+++ b/autogluon/timeseries/models/toto/__init__.py
@@ -0,0 +1,3 @@
+from .model import TotoModel
+
+__all__ = ["TotoModel"]
--- /dev/null
+++ b/autogluon/timeseries/models/toto/_internal/__init__.py
@@ -0,0 +1,9 @@
+from .backbone import TotoBackbone
+from .dataset import MaskedTimeseries
+from .forecaster import TotoForecaster
+
+__all__ = [
+    "MaskedTimeseries",
+    "TotoBackbone",
+    "TotoForecaster",
+]
--- /dev/null
+++ b/autogluon/timeseries/models/toto/_internal/backbone/__init__.py
@@ -0,0 +1,3 @@
+from .backbone import TotoBackbone
+
+__all__ = ["TotoBackbone"]
--- /dev/null
+++ b/autogluon/timeseries/models/toto/_internal/backbone/attention.py
@@ -0,0 +1,196 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+import logging
+from enum import Enum
+
+import torch
+from einops import rearrange
+from torch.nn.functional import scaled_dot_product_attention
+
+from .rope import TimeAwareRotaryEmbedding
+
+log = logging.getLogger(__name__)
+
+
+class AttentionAxis(Enum):
+    TIME = 1
+    SPACE = 2
+
+
+class BaseMultiheadAttention(torch.nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float,
+        rotary_emb: TimeAwareRotaryEmbedding | None,
+        use_memory_efficient_attention: bool,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads."
+        self.head_dim = embed_dim // num_heads
+        self.rotary_emb = rotary_emb
+
+        # We allocate a single tensor for the q, k, and v projection matrices,
+        # multiply them with the inputs, and then split the projected tensors into q, k, and v using unbind.
+        # This reduces overhead a bit vs. having multiple separate Linear layers,
+        # which need to be initialized, tracked by the optimizer, etc.
+        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
+        self.dropout = dropout
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.wO = torch.nn.Linear(embed_dim, embed_dim)
+
+        assert not self.use_memory_efficient_attention, (
+            "xformers is not available, so use_memory_efficient_attention must be False"
+        )
+
+        if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE):
+            raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.")
+
+    def rearrange_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
+        pattern = (
+            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
+            if self.attention_axis == AttentionAxis.TIME
+            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
+        )
+
+        return rearrange(inputs, pattern)
+
+    def get_qkv(
+        self,
+        inputs: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        pattern: str = ""
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"
+
+        assert pattern
+        qkv = self.wQKV(inputs.contiguous())
+        return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0)
+
+    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
+        # Apply the rotary embeddings
+        seq_pos_offset = 0
+        if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME:
+            if kv_cache is not None:
+                seq_pos_offset = kv_cache.seq_len(layer_idx)
+
+            # We need to permute because rotary embeddings expect the sequence dimension to be the second-to-last dimension
+            q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset)
+
+        if kv_cache is not None and self.attention_axis == AttentionAxis.TIME:
+            # First, we append the current input key and value tensors to the cache.
+            # This concatenates the current key and value tensors to the existing key and value tensors
+            kv_cache.append(layer_idx, (k, v))
+            # Then, we retrieve the key and value tensors from the cache.
+            # This includes all the key and value tensors from previous time steps
+            # as well as the current time step.
+            k, v = kv_cache[layer_idx]
+
+        q = q.contiguous()
+        k = k.contiguous().to(q.dtype)  # Ensure k is the same dtype as q; this is necessary when using mixed precision
+        v = v.contiguous().to(q.dtype)  # Ensure v is the same dtype as q; this is necessary when using mixed precision
+
+        return q, k, v, seq_pos_offset
+
+    def rearrange_output(self, output: torch.Tensor, batch: int, variate: int, seq_len: int) -> torch.Tensor:
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"
+
+        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)  # type: ignore
+
+    def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate):
+        # Determine dimension ranges for attention
+        # Ensure the last query vector index is used from the cache
+        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
+        kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
+        if self.attention_axis == AttentionAxis.TIME:
+            attention_mask = (
+                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                is_causal=(attention_mask is None and seq_pos_offset == 0),
+            )
+        elif self.attention_axis == AttentionAxis.SPACE:
+            # We don't use causal masking for space-wise attention
+            attention_mask = (
+                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False)
+        else:
+            raise ValueError("Invalid attention axis")
+
+    def forward(
+        self,
+        layer_idx: int,
+        inputs: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+        kv_cache=None,
+    ) -> torch.Tensor:
+        batch_size, variate, seq_len, _ = inputs.shape
+        dropout = self.dropout if self.training else 0.0
+
+        rearranged_inputs = self.rearrange_inputs(inputs)
+        q, k, v = self.get_qkv(rearranged_inputs)
+
+        q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx)
+
+        output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate)
+
+        output = self.rearrange_output(output, batch_size, variate, seq_len)
+        return self.wO(output)
+
+
+class TimeWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes standard multihead causal attention over the time axis.
+    It does this by flattening out the variates along the batch dimension.
+    It also applies rotary position embeddings to the query and key matrices
+    in order to incorporate relative positional information.
+    """
+
+    attention_axis = AttentionAxis.TIME
+
+
+class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes bidirectional multihead attention over the space axis (i.e. across variates within
+    a multi-variate time series). This is done by flattening out the time axis along the batch dimension.
+    This allows the model to attend to different variates at the same time point. By alternating
+    between time-wise and space-wise attention, the model can learn both temporal and cross-variate
+    dependencies in the data.
+
+    Unlike with time-wise attention, we don't apply rotary embeddings here
+    because we want cross-variate attention to be invariant to the order of the variates.
+    """
+
+    attention_axis = AttentionAxis.SPACE
+
+
+MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
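
The two subclasses differ only in which axis gets folded into the batch dimension before attention runs. A standalone sketch of that folding with invented shapes (the real forward pass additionally projects to Q/K/V, applies rotary embeddings on the time axis, and un-folds the result):

import torch
from einops import rearrange

x = torch.randn(2, 3, 5, 8)  # (batch=2, variate=3, seq_len=5, embed_dim=8)

# Time-wise: variates fold into the batch, so causal attention runs along
# seq_len independently for each variate.
time_view = rearrange(x, "batch variate seq_len d -> (batch variate) seq_len d")
print(time_view.shape)  # torch.Size([6, 5, 8])

# Space-wise: time steps fold into the batch, so bidirectional attention runs
# across variates at each time point.
space_view = rearrange(x, "batch variate seq_len d -> (batch seq_len) variate d")
print(space_view.shape)  # torch.Size([10, 3, 8])
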