autogluon.timeseries 1.4.1b20250924__tar.gz → 1.4.1b20251011__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/PKG-INFO +2 -1
  2. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/setup.py +5 -1
  3. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/__init__.py +2 -0
  4. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/__init__.py +3 -0
  5. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  6. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  7. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
  8. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  9. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  10. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  11. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
  12. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
  13. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  14. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  15. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  16. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/dataloader.py +108 -0
  17. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
  18. autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/model.py +234 -0
  19. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/version.py +1 -1
  20. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon.timeseries.egg-info/PKG-INFO +2 -1
  21. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon.timeseries.egg-info/SOURCES.txt +15 -0
  22. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon.timeseries.egg-info/requires.txt +10 -4
  23. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/setup.cfg +0 -0
  24. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/__init__.py +0 -0
  25. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/configs/__init__.py +0 -0
  26. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/configs/hyperparameter_presets.py +0 -0
  27. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/configs/predictor_presets.py +0 -0
  28. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/dataset/__init__.py +0 -0
  29. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/dataset/ts_dataframe.py +0 -0
  30. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/evaluator.py +0 -0
  31. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/learner.py +0 -0
  32. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/metrics/__init__.py +0 -0
  33. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/metrics/abstract.py +0 -0
  34. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/metrics/point.py +0 -0
  35. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/metrics/quantile.py +0 -0
  36. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/metrics/utils.py +0 -0
  37. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/abstract/__init__.py +0 -0
  38. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/abstract/abstract_timeseries_model.py +0 -0
  39. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/abstract/model_trial.py +0 -0
  40. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/abstract/tunable.py +0 -0
  41. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/autogluon_tabular/__init__.py +0 -0
  42. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/autogluon_tabular/mlforecast.py +0 -0
  43. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/autogluon_tabular/per_step.py +0 -0
  44. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/autogluon_tabular/transforms.py +0 -0
  45. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/autogluon_tabular/utils.py +0 -0
  46. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/chronos/__init__.py +0 -0
  47. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/chronos/model.py +0 -0
  48. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -0
  49. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/chronos/pipeline/base.py +0 -0
  50. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -0
  51. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -0
  52. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/chronos/pipeline/utils.py +0 -0
  53. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/ensemble/__init__.py +0 -0
  54. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/ensemble/abstract.py +0 -0
  55. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/ensemble/basic.py +0 -0
  56. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/ensemble/greedy.py +0 -0
  57. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/gluonts/__init__.py +0 -0
  58. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/gluonts/abstract.py +0 -0
  59. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/gluonts/dataset.py +0 -0
  60. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/gluonts/models.py +0 -0
  61. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/local/__init__.py +0 -0
  62. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/local/abstract_local_model.py +0 -0
  63. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/local/naive.py +0 -0
  64. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/local/npts.py +0 -0
  65. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/local/statsforecast.py +0 -0
  66. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/multi_window/__init__.py +0 -0
  67. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/multi_window/multi_window_model.py +0 -0
  68. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/registry.py +0 -0
  69. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/predictor.py +0 -0
  70. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/regressor.py +0 -0
  71. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/splitter.py +0 -0
  72. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/trainer/__init__.py +0 -0
  73. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/trainer/model_set_builder.py +0 -0
  74. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/trainer/prediction_cache.py +0 -0
  75. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/trainer/trainer.py +0 -0
  76. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/transforms/__init__.py +0 -0
  77. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/transforms/covariate_scaler.py +0 -0
  78. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/transforms/target_scaler.py +0 -0
  79. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/__init__.py +0 -0
  80. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/datetime/__init__.py +0 -0
  81. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/datetime/base.py +0 -0
  82. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/datetime/lags.py +0 -0
  83. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/datetime/seasonality.py +0 -0
  84. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/datetime/time_features.py +0 -0
  85. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/features.py +0 -0
  86. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/forecast.py +0 -0
  87. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/utils/warning_filters.py +0 -0
  88. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon.timeseries.egg-info/dependency_links.txt +0 -0
  89. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon.timeseries.egg-info/namespace_packages.txt +0 -0
  90. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon.timeseries.egg-info/top_level.txt +0 -0
  91. {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon.timeseries.egg-info/zip-safe +0 -0
{autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: autogluon.timeseries
-Version: 1.4.1b20250924
+Version: 1.4.1b20251011
 Summary: Fast and Accurate ML in 3 Lines of Code
 Home-page: https://github.com/autogluon/autogluon
 Author: AutoGluon Community
@@ -35,6 +35,7 @@ Classifier: Topic :: Scientific/Engineering :: Image Recognition
 Requires-Python: >=3.9, <3.13
 Description-Content-Type: text/markdown
 Provides-Extra: tests
+Provides-Extra: toto
 Provides-Extra: all
 License-File: ../LICENSE
 License-File: ../NOTICE
{autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/setup.py
@@ -54,10 +54,14 @@ extras_require = {
         "flaky>=3.7,<4",
         "pytest-timeout>=2.1,<3",
     ],
+    "toto": [
+        "einops>=0.7,<1",
+        "rotary-embedding-torch>=0.8,<1",
+    ],
 }
 
 # chronos-openvino and chronos-onnx are deprecated, and will be removed in a future version
-extras_require["all"] = []
+extras_require["all"] = list(set.union(*(set(extras_require[extra]) for extra in ["toto"])))
 install_requires = ag.get_dependency_version_ranges(install_requires)
 
 if __name__ == "__main__":
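The new optional dependency group can be installed with standard extras syntax, e.g. pip install "autogluon.timeseries[toto]" (the extra name comes from the hunk above). As a minimal sketch of what the rewritten extras_require["all"] expression evaluates to, assuming only the extras visible in this hunk:

# Minimal sketch, not the full setup.py: only the extras visible in the hunk above are assumed.
extras_require = {
    "tests": [
        "flaky>=3.7,<4",
        "pytest-timeout>=2.1,<3",
    ],
    "toto": [
        "einops>=0.7,<1",
        "rotary-embedding-torch>=0.8,<1",
    ],
}

# The new expression unions the requirement sets of the listed extras (currently just "toto"),
# so the "all" extra picks up the Toto dependencies without the test-only packages.
extras_require["all"] = list(set.union(*(set(extras_require[extra]) for extra in ["toto"])))
assert sorted(extras_require["all"]) == ["einops>=0.7,<1", "rotary-embedding-torch>=0.8,<1"]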
{autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251011}/src/autogluon/timeseries/models/__init__.py
@@ -28,6 +28,7 @@ from .local import (
     ZeroModel,
 )
 from .registry import ModelRegistry
+from .toto import TotoModel
 
 __all__ = [
     "ADIDAModel",
@@ -56,6 +57,7 @@ __all__ = [
     "TemporalFusionTransformerModel",
     "ThetaModel",
     "TiDEModel",
+    "TotoModel",
     "WaveNetModel",
     "ZeroModel",
 ]
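With the export above, the new model class becomes part of the public surface of autogluon.timeseries.models. A small sketch that only exercises the import path shown in this hunk (how the model is wired into TimeSeriesPredictor presets or hyperparameter keys is handled by the ModelRegistry and is not visible in this diff):

# Sketch only: checks the export added by this diff, assuming the new build is installed.
from autogluon.timeseries.models import TotoModel, __all__ as exported_models

# The diff adds "TotoModel" to __all__, so it is listed alongside the existing models.
assert "TotoModel" in exported_models
print(TotoModel)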
autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/__init__.py
@@ -0,0 +1,3 @@
+from .model import TotoModel
+
+__all__ = ["TotoModel"]
autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/__init__.py
@@ -0,0 +1,9 @@
+from .backbone import TotoBackbone
+from .dataset import MaskedTimeseries
+from .forecaster import TotoForecaster
+
+__all__ = [
+    "MaskedTimeseries",
+    "TotoBackbone",
+    "TotoForecaster",
+]
autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/__init__.py
@@ -0,0 +1,3 @@
+from .backbone import TotoBackbone
+
+__all__ = ["TotoBackbone"]
autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/attention.py
@@ -0,0 +1,197 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+import logging
+from enum import Enum
+from typing import Optional, Union
+
+import torch
+from einops import rearrange
+from torch.nn.functional import scaled_dot_product_attention
+
+from .rope import TimeAwareRotaryEmbedding
+
+log = logging.getLogger(__name__)
+
+
+class AttentionAxis(Enum):
+    TIME = 1
+    SPACE = 2
+
+
+class BaseMultiheadAttention(torch.nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        dropout: float,
+        rotary_emb: Optional[TimeAwareRotaryEmbedding],
+        use_memory_efficient_attention: bool,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads."
+        self.head_dim = embed_dim // num_heads
+        self.rotary_emb = rotary_emb
+
+        # We allocate a single tensor for the q, k, and v projection matrices,
+        # multiply them with the inputs, and then split the projected tensors into q, k, and v using unbind.
+        # This reduces overhead a bit vs. having multiple separate Linear layers,
+        # which need to be initialized, tracked by the optimizer, etc.
+        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
+        self.dropout = dropout
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.wO = torch.nn.Linear(embed_dim, embed_dim)
+
+        assert not self.use_memory_efficient_attention, (
+            "xformers is not available, so use_memory_efficient_attention must be False"
+        )
+
+        if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE):
+            raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.")
+
+    def rearrange_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
+        pattern = (
+            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
+            if self.attention_axis == AttentionAxis.TIME
+            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
+        )
+
+        return rearrange(inputs, pattern)
+
+    def get_qkv(
+        self,
+        inputs: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        pattern: str = ""
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"
+
+        assert pattern
+        qkv = self.wQKV(inputs.contiguous())
+        return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0)
+
+    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
+        # Apply the rotary embeddings
+        seq_pos_offset = 0
+        if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME:
+            if kv_cache is not None:
+                seq_pos_offset = kv_cache.seq_len(layer_idx)
+
+            # We need to permute because rotary embeddings expect the sequence dimension to be the second-to-last dimension
+            q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset)
+
+        if kv_cache is not None and self.attention_axis == AttentionAxis.TIME:
+            # First, we append the current input key and value tensors to the cache.
+            # This concatenates the current key and value tensors to the existing key and value tensors
+            kv_cache.append(layer_idx, (k, v))
+            # Then, we retrieve the key and value tensors from the cache.
+            # This includes all the key and value tensors from previous time steps
+            # as well as the current time step.
+            k, v = kv_cache[layer_idx]
+
+        q = q.contiguous()
+        k = k.contiguous().to(q.dtype)  # Ensure k is the same dtype as q; this is necessary when using mixed precision
+        v = v.contiguous().to(q.dtype)  # Ensure v is the same dtype as q; this is necessary when using mixed precision
+
+        return q, k, v, seq_pos_offset
+
+    def rearrange_output(self, output: torch.Tensor, batch: int, variate: int, seq_len: int) -> torch.Tensor:
+        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
+            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
+            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
+        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
+            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"
+
+        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)  # type: ignore
+
+    def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate):
+        # Determine dimension ranges for attention
+        # Ensure the last query vector index is used from the cache
+        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
+        kv_dim_start, kv_dim_end = 0, v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
+        if self.attention_axis == AttentionAxis.TIME:
+            attention_mask = (
+                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=dropout,
+                is_causal=(attention_mask is None and seq_pos_offset == 0),
+            )
+        elif self.attention_axis == AttentionAxis.SPACE:
+            # We don't use causal masking for space-wise attention
+            attention_mask = (
+                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
+                if torch.is_tensor(attention_mask)
+                else None
+            )
+            return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False)
+        else:
+            raise ValueError("Invalid attention axis")
+
+    def forward(
+        self,
+        layer_idx: int,
+        inputs: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        kv_cache=None,
+    ) -> torch.Tensor:
+        batch_size, variate, seq_len, _ = inputs.shape
+        dropout = self.dropout if self.training else 0.0
+
+        rearranged_inputs = self.rearrange_inputs(inputs)
+        q, k, v = self.get_qkv(rearranged_inputs)
+
+        q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx)
+
+        output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate)
+
+        output = self.rearrange_output(output, batch_size, variate, seq_len)
+        return self.wO(output)
+
+
+class TimeWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes standard multihead causal attention over the time axis.
+    It does this by flattening out the variates along the batch dimension.
+    It also applies rotary position embeddings to the query and key matrices
+    in order to incorporate relative positional information.
+    """
+
+    attention_axis = AttentionAxis.TIME
+
+
+class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
+    """
+    Computes bidirectional multihead attention over the space axis (i.e. across variates within
+    a multi-variate time series). This is done by flattening out the time axis along the batch dimension.
+    This allows the model to attend to different variates at the same time point. By alternating
+    between time-wise and space-wise attention, the model can learn both temporal and cross-variate
+    dependencies in the data.
+
+    Unlike with time-wise attention, don't apply rotary embeddings here
+    because we want cross-variate attention to be invariant to the order of the variates.
+    """
+
+    attention_axis = AttentionAxis.SPACE
+
+
+MultiHeadAttention = Union[TimeWiseMultiheadAttention, SpaceWiseMultiheadAttention]
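The two attention variants above differ only in which axis is folded into the batch dimension before calling scaled_dot_product_attention. A standalone sketch of the rearrange patterns used in rearrange_inputs, with illustrative shapes (batch=2, variate=3, seq_len=8, embed_dim=16) rather than the model's actual configuration:

import torch
from einops import rearrange

x = torch.randn(2, 3, 8, 16)  # (batch, variate, seq_len, embed_dim)

# Time-wise attention: variates are folded into the batch, so attention runs over seq_len.
time_wise = rearrange(x, "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim")
print(time_wise.shape)  # torch.Size([6, 8, 16])

# Space-wise attention: time steps are folded into the batch, so attention runs over variates.
space_wise = rearrange(x, "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim")
print(space_wise.shape)  # torch.Size([16, 3, 16])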
autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/backbone.py
@@ -0,0 +1,262 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+import math
+from typing import NamedTuple, Optional
+
+import torch
+
+from .distribution import MixtureOfStudentTsOutput
+from .kvcache import KVCache
+from .scaler import CausalPatchStdMeanScaler
+from .transformer import Transformer
+
+
+class TotoOutput(NamedTuple):
+    """
+    Output of the Toto model. Contains the output distribution, the location parameters,
+    and the scale parameters.
+    """
+
+    distribution: torch.distributions.Distribution
+    loc: torch.Tensor
+    scale: torch.Tensor
+
+
+def patchify_id_mask(id_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
+    patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size)
+    patched_id_mask_min = patched_id_mask.min(-1).values
+    patched_id_mask_max = patched_id_mask.max(-1).values
+    assert torch.eq(patched_id_mask_min, patched_id_mask_max).all(), "Patches cannot span multiple datasets"
+    return patched_id_mask_min
+
+
+class PatchEmbedding(torch.nn.Module):
+    """
+    Multivariate time series patch embedding.
+    Patchifies each variate separately.
+    """
+
+    def __init__(self, patch_size: int, stride: int, embed_dim: int):
+        super().__init__()
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+        self.stride = stride
+        self.projection = torch.nn.Linear(self.patch_size, self.embed_dim)
+
+    def _patchify(self, x: torch.Tensor) -> torch.Tensor:
+        return x.unfold(dimension=-1, size=self.patch_size, step=self.stride)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        id_mask: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert x.shape[-1] % self.patch_size == 0, (
+            f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})"
+        )
+        x_patched: torch.Tensor = self._patchify(x)
+        id_mask_patched: torch.Tensor = self._patchify(id_mask)
+
+        assert torch.eq(id_mask_patched.min(-1).values, id_mask_patched.max(-1).values).all(), (
+            "Patches cannot span multiple datasets"
+        )
+
+        return (
+            self.projection(x_patched),
+            id_mask_patched.min(-1).values,
+        )
+
+
+class TotoBackbone(torch.nn.Module):
+    """
+    Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model for multivariate
+    time series forecasting. It applies a patch embedding to the input data, followed by a transformer
+    that alternates between time-wise and space-wise attention. The transformer is followed by a linear projection
+    that maps the transformer output to the output distribution.
+
+    The output distribution can be a single distribution (e.g. Gaussian) or a mixture of distributions.
+    If a mixture of distributions is used, the model will learn to predict the mixture weights
+    as well as the parameters of the individual distributions.
+
+    Parameters
+    ----------
+    patch_size
+        Size of the patch to use for the patch embedding.
+    stride
+        Stride to use for the patch embedding.
+    embed_dim
+        Dimension of the model's latent space.
+    num_layers
+        Number of transformer layers to use.
+    num_heads
+        Number of attention heads to use in each self-attention layer.
+    mlp_hidden_dim
+        Dimension of the hidden layer in the feedforward network.
+    dropout
+        Dropout rate to use in the model.
+    spacewise_every_n_layers
+        How many time-wise transformer layers to apply between each space-wise transformer layer.
+    spacewise_first
+        Whether to apply space-wise attention before time-wise attention.
+    scaler_cls
+        Class to use for scaling the input data.
+    output_distribution_classes
+        List of classes to use for the output distribution. If a single class is provided, the model
+        will output a single distribution. If multiple classes are provided, the model will output a
+        learned mixture of distributions.
+    output_distribution_kwargs
+        Keyword arguments to pass to the output distribution class. Note: this currently only works
+        with a single output distribution class.
+    use_memory_efficient_attention:
+        Whether to use memory-efficient attention. If True, the model will use the memory-efficient from xFormers.
+    stabilize_with_global:
+        Whether to use global statistics to stabilize causal statistics by clamping extreme values. Only applies to causal scalers.
+    scale_factor_exponent:
+        Exponent that controls the allowed range of deviation from global scale for causal scalers.
+    """
+
+    def __init__(
+        self,
+        patch_size: int,
+        stride: int,
+        embed_dim: int,
+        num_layers: int,
+        num_heads: int,
+        mlp_hidden_dim: int,
+        dropout: float,
+        spacewise_every_n_layers: int,
+        scaler_cls: str,
+        output_distribution_classes: list[str],
+        spacewise_first: bool = True,
+        output_distribution_kwargs: Optional[dict] = None,
+        use_memory_efficient_attention: bool = True,
+        stabilize_with_global: bool = True,
+        scale_factor_exponent: float = 10.0,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        # strings are used when loading a safetensors checkpoint
+        # Initialize patch-based scalers with the correct patch_size
+
+        self.scaler = CausalPatchStdMeanScaler(
+            patch_size=patch_size,
+            stabilize_with_global=stabilize_with_global,
+            scale_factor_exponent=scale_factor_exponent,
+        )
+        self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim)
+        self.dropout = dropout
+        self.num_layers = num_layers
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.transformer = Transformer(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            num_layers=self.num_layers,
+            mlp_hidden_dim=mlp_hidden_dim,
+            dropout=dropout,
+            spacewise_every_n_layers=spacewise_every_n_layers,
+            spacewise_first=spacewise_first,
+            use_memory_efficient_attention=self.use_memory_efficient_attention,
+        )
+        self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)
+
+        # TODO[BEN] this doesn't need to be a list
+        output_distribution_classes_ = [MixtureOfStudentTsOutput]
+        self.output_distribution = output_distribution_classes_[0](embed_dim, **(output_distribution_kwargs or {}))
+
+    def allocate_kv_cache(
+        self,
+        batch_size: int,
+        num_variates: int,
+        max_time_steps: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> KVCache:
+        return KVCache(
+            batch_size=batch_size,
+            num_variates=num_variates,
+            transformer_layers=list(self.transformer.layers),
+            num_layers=self.num_layers,
+            embed_dim=self.embed_dim,
+            num_heads=self.transformer.layers[0].num_heads,  # type: ignore
+            max_seq_len=math.ceil(max_time_steps / self.patch_embed.stride),
+            device=device,
+            dtype=dtype,
+            use_memory_efficient_attention=self.use_memory_efficient_attention,
+        )
+
+    def backbone(
+        self,
+        inputs: torch.Tensor,
+        input_padding_mask: torch.Tensor,
+        id_mask: torch.Tensor,
+        kv_cache: Optional[KVCache] = None,
+        scaling_prefix_length: Optional[int] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        scaled_inputs: torch.Tensor
+        loc: torch.Tensor
+        scale: torch.Tensor
+
+        # Standard scaling operation, same API but without ID mask.
+        scaled_inputs, loc, scale = self.scaler(
+            inputs,
+            weights=torch.ones_like(inputs, device=inputs.device),
+            padding_mask=input_padding_mask,
+            prefix_length=scaling_prefix_length,
+        )
+
+        if kv_cache is not None:
+            prefix_len = self.patch_embed.stride * kv_cache.current_len(0)
+
+            # Truncate inputs so that the transformer only processes
+            # the last patch in the sequence. We'll use the KVCache
+            # for the earlier patches.
+            scaled_inputs = scaled_inputs[:, :, prefix_len:]
+
+            # As a simplification, when using kv cache we only allow decoding
+            # one step at a time after the initial forward pass.
+            assert (prefix_len == 0) or (scaled_inputs.shape[-1] == self.patch_embed.stride), (
+                "Must decode one step at a time."
+            )
+
+            input_padding_mask = input_padding_mask[:, :, prefix_len:]
+            id_mask = id_mask[:, :, prefix_len:]
+
+        embeddings: torch.Tensor
+        reduced_id_mask: torch.Tensor
+
+        embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask)
+
+        # Apply the transformer on the embeddings
+        transformed: torch.Tensor = self.transformer(embeddings, reduced_id_mask, kv_cache)
+
+        # Unembed and flatten the sequence
+        unembedded = self.unembed(transformed)
+        batch_size, num_variates, seq_len = unembedded.shape[:3]
+        patch_size = unembedded.shape[-1] // self.embed_dim
+        flattened = unembedded.view(batch_size, num_variates, seq_len * patch_size, self.embed_dim)
+        return flattened, loc, scale
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        input_padding_mask: torch.Tensor,
+        id_mask: torch.Tensor,
+        kv_cache: Optional[KVCache] = None,
+        scaling_prefix_length: Optional[int] = None,
+    ) -> TotoOutput:
+        flattened, loc, scale = self.backbone(
+            inputs,
+            input_padding_mask,
+            id_mask,
+            kv_cache,
+            scaling_prefix_length,
+        )
+
+        return TotoOutput(self.output_distribution(flattened), loc, scale)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
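PatchEmbedding above splits each variate's series into windows with torch.Tensor.unfold and projects each window into the latent space. A self-contained sketch of those two steps, with illustrative shapes rather than the model's actual configuration:

import torch

patch_size = stride = 4
x = torch.randn(2, 3, 24)  # (batch, variate, time_steps); 24 is divisible by patch_size

# Same operation as PatchEmbedding._patchify: the time axis is cut into non-overlapping windows.
patches = x.unfold(dimension=-1, size=patch_size, step=stride)
print(patches.shape)  # torch.Size([2, 3, 6, 4]): 6 patches of length 4 per variate

# Each patch is then linearly projected to the embedding dimension, as in PatchEmbedding.forward.
projection = torch.nn.Linear(patch_size, 16)
embeddings = projection(patches)
print(embeddings.shape)  # torch.Size([2, 3, 6, 16])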
autogluon.timeseries-1.4.1b20251011/src/autogluon/timeseries/models/toto/_internal/backbone/distribution.py
@@ -0,0 +1,70 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+from abc import ABC
+
+import torch
+import torch.nn.functional as F
+from gluonts.torch.distributions import AffineTransformed
+from gluonts.torch.distributions.studentT import StudentT
+
+
+class DistributionOutput(ABC, torch.nn.Module):
+    pass
+
+
+class StudentTOutput(DistributionOutput):
+    def __init__(self, embed_dim):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.df = torch.nn.Linear(embed_dim, 1)
+        self.loc_proj = torch.nn.Linear(embed_dim, 1)
+        self.scale_proj = torch.nn.Linear(embed_dim, 1)
+
+    def forward(self, inputs, loc=None, scale=None):
+        eps = torch.finfo(inputs.dtype).eps
+        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps).squeeze(-1)
+        base_loc = self.loc_proj(inputs).squeeze(-1)
+        base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1)
+
+        base_dist = torch.distributions.StudentT(df, base_loc, base_scale, validate_args=False)  # type: ignore
+
+        if loc is not None and scale is not None:
+            return AffineTransformed(
+                base_dist,
+                loc=loc,
+                scale=scale,
+            )
+        return base_dist
+
+
+class MixtureOfStudentTsOutput(DistributionOutput):
+    def __init__(
+        self,
+        embed_dim,
+        k_components,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.k_components = k_components
+
+        self.df = torch.nn.Linear(embed_dim, k_components)
+        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
+        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
+        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)
+
+    def forward(self, inputs, loc=None, scale=None):
+        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
+        loc = self.loc_proj(inputs)
+        scale = F.softplus(self.scale_proj(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
+        logits = self.mixture_weights(inputs)
+        probs = F.softmax(logits, dim=-1)
+        components = StudentT(df, loc, scale)
+        mixture_distribution = torch.distributions.Categorical(probs=probs)
+
+        return torch.distributions.MixtureSameFamily(
+            mixture_distribution,
+            components,
+        )
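MixtureOfStudentTsOutput parameterizes a mixture of Student-T distributions per predicted value; downstream code (the forecaster, not shown in this section) can then draw samples or evaluate log-probabilities from it. A standalone sketch of the same construction using plain torch.distributions and made-up parameter values, rather than the gluonts StudentT used above:

import torch

# Made-up parameters for a single predicted value with k_components = 3.
df = torch.tensor([3.0, 5.0, 10.0])     # degrees of freedom per component
loc = torch.tensor([0.0, 1.0, -0.5])    # location per component
scale = torch.tensor([1.0, 0.5, 2.0])   # scale per component
logits = torch.tensor([0.2, 0.5, 0.3])  # unnormalized mixture weights

mixture = torch.distributions.MixtureSameFamily(
    torch.distributions.Categorical(probs=torch.softmax(logits, dim=-1)),
    torch.distributions.StudentT(df, loc, scale),
)

samples = mixture.sample((1000,))  # Monte Carlo samples, e.g. for quantile forecasts
print(samples.shape, mixture.mean)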