autogluon.timeseries 1.4.1b20250907__py3-none-any.whl → 1.5.1b20260122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (95)
  1. autogluon/timeseries/configs/hyperparameter_presets.py +13 -28
  2. autogluon/timeseries/configs/predictor_presets.py +23 -39
  3. autogluon/timeseries/dataset/ts_dataframe.py +97 -86
  4. autogluon/timeseries/learner.py +70 -35
  5. autogluon/timeseries/metrics/__init__.py +4 -4
  6. autogluon/timeseries/metrics/abstract.py +8 -8
  7. autogluon/timeseries/metrics/point.py +9 -9
  8. autogluon/timeseries/metrics/quantile.py +5 -5
  9. autogluon/timeseries/metrics/utils.py +4 -4
  10. autogluon/timeseries/models/__init__.py +4 -1
  11. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -50
  12. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  13. autogluon/timeseries/models/abstract/tunable.py +8 -8
  14. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +58 -62
  15. autogluon/timeseries/models/autogluon_tabular/per_step.py +27 -16
  16. autogluon/timeseries/models/autogluon_tabular/transforms.py +11 -9
  17. autogluon/timeseries/models/chronos/__init__.py +2 -1
  18. autogluon/timeseries/models/chronos/chronos2.py +395 -0
  19. autogluon/timeseries/models/chronos/model.py +127 -89
  20. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +69 -37
  21. autogluon/timeseries/models/ensemble/__init__.py +36 -2
  22. autogluon/timeseries/models/ensemble/abstract.py +14 -46
  23. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  24. autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
  25. autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
  26. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  27. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  28. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  31. autogluon/timeseries/models/ensemble/{greedy.py → ensemble_selection.py} +41 -61
  32. autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
  33. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  34. autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
  35. autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +25 -22
  36. autogluon/timeseries/models/ensemble/weighted/greedy.py +64 -0
  37. autogluon/timeseries/models/gluonts/abstract.py +32 -31
  38. autogluon/timeseries/models/gluonts/dataset.py +11 -11
  39. autogluon/timeseries/models/gluonts/models.py +0 -7
  40. autogluon/timeseries/models/local/__init__.py +0 -7
  41. autogluon/timeseries/models/local/abstract_local_model.py +15 -18
  42. autogluon/timeseries/models/local/naive.py +2 -2
  43. autogluon/timeseries/models/local/npts.py +7 -1
  44. autogluon/timeseries/models/local/statsforecast.py +13 -13
  45. autogluon/timeseries/models/multi_window/multi_window_model.py +39 -24
  46. autogluon/timeseries/models/registry.py +3 -4
  47. autogluon/timeseries/models/toto/__init__.py +3 -0
  48. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  49. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  50. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  51. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  52. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  58. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  59. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  60. autogluon/timeseries/models/toto/dataloader.py +108 -0
  61. autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
  62. autogluon/timeseries/models/toto/model.py +249 -0
  63. autogluon/timeseries/predictor.py +541 -162
  64. autogluon/timeseries/regressor.py +27 -30
  65. autogluon/timeseries/splitter.py +3 -27
  66. autogluon/timeseries/trainer/ensemble_composer.py +444 -0
  67. autogluon/timeseries/trainer/model_set_builder.py +9 -9
  68. autogluon/timeseries/trainer/prediction_cache.py +16 -16
  69. autogluon/timeseries/trainer/trainer.py +300 -279
  70. autogluon/timeseries/trainer/utils.py +17 -0
  71. autogluon/timeseries/transforms/covariate_scaler.py +8 -8
  72. autogluon/timeseries/transforms/target_scaler.py +15 -15
  73. autogluon/timeseries/utils/constants.py +10 -0
  74. autogluon/timeseries/utils/datetime/lags.py +1 -3
  75. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  76. autogluon/timeseries/utils/features.py +31 -14
  77. autogluon/timeseries/utils/forecast.py +6 -7
  78. autogluon/timeseries/utils/timer.py +173 -0
  79. autogluon/timeseries/version.py +1 -1
  80. autogluon.timeseries-1.5.1b20260122-py3.11-nspkg.pth +1 -0
  81. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/METADATA +39 -22
  82. autogluon_timeseries-1.5.1b20260122.dist-info/RECORD +103 -0
  83. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/WHEEL +1 -1
  84. autogluon/timeseries/evaluator.py +0 -6
  85. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
  86. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  87. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
  88. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -580
  89. autogluon.timeseries-1.4.1b20250907-py3.9-nspkg.pth +0 -1
  90. autogluon.timeseries-1.4.1b20250907.dist-info/RECORD +0 -75
  91. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/LICENSE +0 -0
  92. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/NOTICE +0 -0
  93. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/namespace_packages.txt +0 -0
  94. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/top_level.txt +0 -0
  95. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/toto/_internal/backbone/backbone.py
@@ -0,0 +1,262 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+ import math
+ from typing import NamedTuple
+
+ import torch
+
+ from .distribution import MixtureOfStudentTsOutput
+ from .kvcache import KVCache
+ from .scaler import CausalPatchStdMeanScaler
+ from .transformer import Transformer
+
+
+ class TotoOutput(NamedTuple):
+     """
+     Output of the Toto model. Contains the output distribution, the location parameters,
+     and the scale parameters.
+     """
+
+     distribution: torch.distributions.Distribution
+     loc: torch.Tensor
+     scale: torch.Tensor
+
+
+ def patchify_id_mask(id_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
+     patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size)
+     patched_id_mask_min = patched_id_mask.min(-1).values
+     patched_id_mask_max = patched_id_mask.max(-1).values
+     assert torch.eq(patched_id_mask_min, patched_id_mask_max).all(), "Patches cannot span multiple datasets"
+     return patched_id_mask_min
+
+
+ class PatchEmbedding(torch.nn.Module):
+     """
+     Multivariate time series patch embedding.
+     Patchifies each variate separately.
+     """
+
+     def __init__(self, patch_size: int, stride: int, embed_dim: int):
+         super().__init__()
+         self.patch_size = patch_size
+         self.embed_dim = embed_dim
+         self.stride = stride
+         self.projection = torch.nn.Linear(self.patch_size, self.embed_dim)
+
+     def _patchify(self, x: torch.Tensor) -> torch.Tensor:
+         return x.unfold(dimension=-1, size=self.patch_size, step=self.stride)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         id_mask: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         assert x.shape[-1] % self.patch_size == 0, (
+             f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})"
+         )
+         x_patched: torch.Tensor = self._patchify(x)
+         id_mask_patched: torch.Tensor = self._patchify(id_mask)
+
+         assert torch.eq(id_mask_patched.min(-1).values, id_mask_patched.max(-1).values).all(), (
+             "Patches cannot span multiple datasets"
+         )
+
+         return (
+             self.projection(x_patched),
+             id_mask_patched.min(-1).values,
+         )
+
+
+ class TotoBackbone(torch.nn.Module):
+     """
+     Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model for multivariate
+     time series forecasting. It applies a patch embedding to the input data, followed by a transformer
+     that alternates between time-wise and space-wise attention. The transformer is followed by a linear projection
+     that maps the transformer output to the output distribution.
+
+     The output distribution can be a single distribution (e.g. Gaussian) or a mixture of distributions.
+     If a mixture of distributions is used, the model will learn to predict the mixture weights
+     as well as the parameters of the individual distributions.
+
+     Parameters
+     ----------
+     patch_size
+         Size of the patch to use for the patch embedding.
+     stride
+         Stride to use for the patch embedding.
+     embed_dim
+         Dimension of the model's latent space.
+     num_layers
+         Number of transformer layers to use.
+     num_heads
+         Number of attention heads to use in each self-attention layer.
+     mlp_hidden_dim
+         Dimension of the hidden layer in the feedforward network.
+     dropout
+         Dropout rate to use in the model.
+     spacewise_every_n_layers
+         How many time-wise transformer layers to apply between each space-wise transformer layer.
+     spacewise_first
+         Whether to apply space-wise attention before time-wise attention.
+     scaler_cls
+         Class to use for scaling the input data.
+     output_distribution_classes
+         List of classes to use for the output distribution. If a single class is provided, the model
+         will output a single distribution. If multiple classes are provided, the model will output a
+         learned mixture of distributions.
+     output_distribution_kwargs
+         Keyword arguments to pass to the output distribution class. Note: this currently only works
+         with a single output distribution class.
+     use_memory_efficient_attention:
+         Whether to use memory-efficient attention. If True, the model will use the memory-efficient attention from xFormers.
+     stabilize_with_global:
+         Whether to use global statistics to stabilize causal statistics by clamping extreme values. Only applies to causal scalers.
+     scale_factor_exponent:
+         Exponent that controls the allowed range of deviation from global scale for causal scalers.
+     """
+
+     def __init__(
+         self,
+         patch_size: int,
+         stride: int,
+         embed_dim: int,
+         num_layers: int,
+         num_heads: int,
+         mlp_hidden_dim: int,
+         dropout: float,
+         spacewise_every_n_layers: int,
+         scaler_cls: str,
+         output_distribution_classes: list[str],
+         spacewise_first: bool = True,
+         output_distribution_kwargs: dict | None = None,
+         use_memory_efficient_attention: bool = True,
+         stabilize_with_global: bool = True,
+         scale_factor_exponent: float = 10.0,
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         # strings are used when loading a safetensors checkpoint
+         # Initialize patch-based scalers with the correct patch_size
+
+         self.scaler = CausalPatchStdMeanScaler(
+             patch_size=patch_size,
+             stabilize_with_global=stabilize_with_global,
+             scale_factor_exponent=scale_factor_exponent,
+         )
+         self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim)
+         self.dropout = dropout
+         self.num_layers = num_layers
+         self.use_memory_efficient_attention = use_memory_efficient_attention
+         self.transformer = Transformer(
+             embed_dim=embed_dim,
+             num_heads=num_heads,
+             num_layers=self.num_layers,
+             mlp_hidden_dim=mlp_hidden_dim,
+             dropout=dropout,
+             spacewise_every_n_layers=spacewise_every_n_layers,
+             spacewise_first=spacewise_first,
+             use_memory_efficient_attention=self.use_memory_efficient_attention,
+         )
+         self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)
+
+         # TODO[BEN] this doesn't need to be a list
+         output_distribution_classes_ = [MixtureOfStudentTsOutput]
+         self.output_distribution = output_distribution_classes_[0](embed_dim, **(output_distribution_kwargs or {}))
+
+     def allocate_kv_cache(
+         self,
+         batch_size: int,
+         num_variates: int,
+         max_time_steps: int,
+         device: torch.device,
+         dtype: torch.dtype,
+     ) -> KVCache:
+         return KVCache(
+             batch_size=batch_size,
+             num_variates=num_variates,
+             transformer_layers=list(self.transformer.layers),
+             num_layers=self.num_layers,
+             embed_dim=self.embed_dim,
+             num_heads=self.transformer.layers[0].num_heads,  # type: ignore
+             max_seq_len=math.ceil(max_time_steps / self.patch_embed.stride),
+             device=device,
+             dtype=dtype,
+             use_memory_efficient_attention=self.use_memory_efficient_attention,
+         )
+
+     def backbone(
+         self,
+         inputs: torch.Tensor,
+         input_padding_mask: torch.Tensor,
+         id_mask: torch.Tensor,
+         kv_cache: KVCache | None = None,
+         scaling_prefix_length: int | None = None,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         scaled_inputs: torch.Tensor
+         loc: torch.Tensor
+         scale: torch.Tensor
+
+         # Standard scaling operation, same API but without ID mask.
+         scaled_inputs, loc, scale = self.scaler(
+             inputs,
+             weights=torch.ones_like(inputs, device=inputs.device),
+             padding_mask=input_padding_mask,
+             prefix_length=scaling_prefix_length,
+         )
+
+         if kv_cache is not None:
+             prefix_len = self.patch_embed.stride * kv_cache.current_len(0)
+
+             # Truncate inputs so that the transformer only processes
+             # the last patch in the sequence. We'll use the KVCache
+             # for the earlier patches.
+             scaled_inputs = scaled_inputs[:, :, prefix_len:]
+
+             # As a simplification, when using kv cache we only allow decoding
+             # one step at a time after the initial forward pass.
+             assert (prefix_len == 0) or (scaled_inputs.shape[-1] == self.patch_embed.stride), (
+                 "Must decode one step at a time."
+             )
+
+             input_padding_mask = input_padding_mask[:, :, prefix_len:]
+             id_mask = id_mask[:, :, prefix_len:]
+
+         embeddings: torch.Tensor
+         reduced_id_mask: torch.Tensor
+
+         embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask)
+
+         # Apply the transformer on the embeddings
+         transformed: torch.Tensor = self.transformer(embeddings, reduced_id_mask, kv_cache)
+
+         # Unembed and flatten the sequence
+         unembedded = self.unembed(transformed)
+         batch_size, num_variates, seq_len = unembedded.shape[:3]
+         patch_size = unembedded.shape[-1] // self.embed_dim
+         flattened = unembedded.view(batch_size, num_variates, seq_len * patch_size, self.embed_dim)
+         return flattened, loc, scale
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         input_padding_mask: torch.Tensor,
+         id_mask: torch.Tensor,
+         kv_cache: KVCache | None = None,
+         scaling_prefix_length: int | None = None,
+     ) -> TotoOutput:
+         flattened, loc, scale = self.backbone(
+             inputs,
+             input_padding_mask,
+             id_mask,
+             kv_cache,
+             scaling_prefix_length,
+         )
+
+         return TotoOutput(self.output_distribution(flattened), loc, scale)
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
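
The PatchEmbedding module in the hunk above is plain torch.nn, so its shape contract can be checked in isolation. A minimal sketch, assuming this hunk is the new autogluon/timeseries/models/toto/_internal/backbone/backbone.py from the file list and that the installed wheel's Toto dependencies (torch, gluonts) are importable:

import torch

from autogluon.timeseries.models.toto._internal.backbone.backbone import PatchEmbedding

# (batch, variates, time): the series length must be divisible by patch_size.
x = torch.randn(2, 3, 64)
id_mask = torch.zeros(2, 3, 64)  # all series in one "dataset", so no patch spans two datasets

patch_embed = PatchEmbedding(patch_size=16, stride=16, embed_dim=128)
tokens, reduced_id_mask = patch_embed(x, id_mask)

print(tokens.shape)           # torch.Size([2, 3, 4, 128]) -- 64 / 16 = 4 patches per variate
print(reduced_id_mask.shape)  # torch.Size([2, 3, 4])      -- one dataset id per patch

Each variate is patchified independently, which keeps the (batch, variates, patches, embed_dim) layout that the alternating time-wise/space-wise transformer described in the TotoBackbone docstring operates on.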
autogluon/timeseries/models/toto/_internal/backbone/distribution.py
@@ -0,0 +1,70 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+ from abc import ABC
+
+ import torch
+ import torch.nn.functional as F
+ from gluonts.torch.distributions import AffineTransformed
+ from gluonts.torch.distributions.studentT import StudentT
+
+
+ class DistributionOutput(ABC, torch.nn.Module):
+     pass
+
+
+ class StudentTOutput(DistributionOutput):
+     def __init__(self, embed_dim):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.df = torch.nn.Linear(embed_dim, 1)
+         self.loc_proj = torch.nn.Linear(embed_dim, 1)
+         self.scale_proj = torch.nn.Linear(embed_dim, 1)
+
+     def forward(self, inputs, loc=None, scale=None):
+         eps = torch.finfo(inputs.dtype).eps
+         df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps).squeeze(-1)
+         base_loc = self.loc_proj(inputs).squeeze(-1)
+         base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1)
+
+         base_dist = torch.distributions.StudentT(df, base_loc, base_scale, validate_args=False)  # type: ignore
+
+         if loc is not None and scale is not None:
+             return AffineTransformed(
+                 base_dist,
+                 loc=loc,
+                 scale=scale,
+             )
+         return base_dist
+
+
+ class MixtureOfStudentTsOutput(DistributionOutput):
+     def __init__(
+         self,
+         embed_dim,
+         k_components,
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.k_components = k_components
+
+         self.df = torch.nn.Linear(embed_dim, k_components)
+         self.loc_proj = torch.nn.Linear(embed_dim, k_components)
+         self.scale_proj = torch.nn.Linear(embed_dim, k_components)
+         self.mixture_weights = torch.nn.Linear(embed_dim, k_components)
+
+     def forward(self, inputs, loc=None, scale=None):
+         df = 2.0 + F.softplus(self.df(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
+         loc = self.loc_proj(inputs)
+         scale = F.softplus(self.scale_proj(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
+         logits = self.mixture_weights(inputs)
+         probs = F.softmax(logits, dim=-1)
+         components = StudentT(df, loc, scale)
+         mixture_distribution = torch.distributions.Categorical(probs=probs)
+
+         return torch.distributions.MixtureSameFamily(
+             mixture_distribution,
+             components,
+         )
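
For reference, this is how the mixture head above behaves end to end; a minimal sketch, assuming this hunk is the new .../backbone/distribution.py and that torch and gluonts (both dependencies of this code) are installed:

import torch

from autogluon.timeseries.models.toto._internal.backbone.distribution import MixtureOfStudentTsOutput

head = MixtureOfStudentTsOutput(embed_dim=128, k_components=3)
features = torch.randn(2, 8, 128)  # any leading shape works; the last dim must equal embed_dim

dist = head(features)              # torch.distributions.MixtureSameFamily over 3 Student-T components
samples = dist.sample((500,))      # Monte Carlo draws, shape (500, 2, 8)

print(dist.mean.shape)                     # torch.Size([2, 8])
print(samples.quantile(0.9, dim=0).shape)  # torch.Size([2, 8]) -- an empirical 0.9-quantile forecast

Quantile forecasts can be read off such a mixture by sampling, as in the sketch, since the mixture has no closed-form inverse CDF.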
autogluon/timeseries/models/toto/_internal/backbone/kvcache.py
@@ -0,0 +1,136 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+ from dataclasses import dataclass, field
+
+ import torch
+
+ from .attention import TimeWiseMultiheadAttention
+
+ K = torch.Tensor
+ V = torch.Tensor
+ KV = tuple[torch.Tensor, torch.Tensor]
+
+
+ @dataclass
+ class KVCache:
+     """
+     Key/Value cache for storing intermediate attention values
+     during multistep inference. Only stores KV cache for timewise layers, skipping spacewise layers.
+     """
+
+     batch_size: int
+     num_variates: int
+     transformer_layers: list
+     num_layers: int
+     embed_dim: int
+     num_heads: int
+     max_seq_len: int
+     device: torch.device = torch.device("cpu")
+     dtype: torch.dtype = torch.float32
+     use_memory_efficient_attention: bool = True
+
+     _keys: torch.Tensor = field(init=False)
+     _values: torch.Tensor = field(init=False)
+     _current_idx: torch.Tensor = field(init=False)
+     _layer_cache_map: torch.Tensor = field(init=False)
+
+     def __post_init__(self):
+         """
+         - Determine timewise vs. spacewise layers and allocate KV only for timewise.
+         - Create a fast tensor-based mapping from global layer_idx -> timewise layer_idx.
+         """
+         assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
+         head_dim = self.embed_dim // self.num_heads
+
+         # Compute which layers are timewise
+         time_layer_indices = [
+             i
+             for i in range(self.num_layers)
+             if isinstance(self.transformer_layers[i].attention, TimeWiseMultiheadAttention)
+         ]
+
+         time_layer_count = max(1, len(time_layer_indices))  # handle edge case for no timewise layers
+         # Allocate for only the timewise layers
+         if self.use_memory_efficient_attention:
+             shape = (
+                 time_layer_count,
+                 self.batch_size * self.num_variates,
+                 self.max_seq_len,
+                 self.num_heads,
+                 head_dim,
+             )
+         else:
+             shape = (
+                 time_layer_count,
+                 self.batch_size * self.num_variates,
+                 self.num_heads,
+                 self.max_seq_len,
+                 head_dim,
+             )
+         self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype)
+         self._values = torch.zeros_like(self._keys)
+         self._current_idx = torch.zeros(time_layer_count, device=self.device, dtype=torch.int)
+         # Build a tensor lookup for global -> timewise layer index (default to 0)
+         self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device)
+         for cache_idx, layer_idx in enumerate(time_layer_indices):
+             self._layer_cache_map[layer_idx] = int(cache_idx)  # Assign correct indices
+
+     def __getitem__(self, layer_idx: int) -> KV:
+         cache_idx = int(self._layer_cache_map[layer_idx].item())
+         end_idx = int(self._current_idx[cache_idx].item())
+
+         if self.use_memory_efficient_attention:
+             return self._keys[cache_idx, :, :end_idx, :, :], self._values[cache_idx, :, :end_idx, :, :]
+         else:
+             return self._keys[cache_idx, :, :, :end_idx, :], self._values[cache_idx, :, :, :end_idx, :]
+
+     def current_len(self, cache_idx: int) -> int:
+         return int(self._current_idx[cache_idx].item()) if self._current_idx.numel() > 0 else 0
+
+     def seq_len(self, layer_idx: int) -> int:
+         cache_idx = int(self._layer_cache_map[layer_idx].item())
+         return self.current_len(cache_idx)
+
+     def append(self, layer_idx: int, kv: KV):
+         cache_idx = int(self._layer_cache_map[layer_idx].item())
+         keys, values = kv
+
+         # Validate dimensions
+         assert keys.shape == values.shape, "keys and values must have the same shape"
+         assert keys.shape[0] == self.batch_size * self.num_variates, (
+             "keys and values must have batch_size * num_variates as their first dimension"
+         )
+
+         if self.use_memory_efficient_attention:
+             assert keys.shape[2] == self.num_heads, "keys and values must have num_heads as their third dimension"
+         else:
+             assert keys.shape[1] == self.num_heads, "keys and values must have num_heads as their second dimension"
+         assert keys.shape[3] == self.embed_dim // self.num_heads, (
+             "keys and values must have head_dim as their fourth dimension"
+         )
+
+         start_idx = self._current_idx[cache_idx]
+         if self.use_memory_efficient_attention:
+             end_idx = start_idx + keys.shape[1]
+         else:
+             end_idx = start_idx + keys.shape[2]
+         assert end_idx <= self.max_seq_len, (
+             f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"
+         )
+
+         if self.use_memory_efficient_attention:
+             self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys
+             self._values[cache_idx, :, start_idx:end_idx, :, :] = values
+         else:
+             self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys
+             self._values[cache_idx, :, :, start_idx:end_idx, :] = values
+
+         self._current_idx[cache_idx] = end_idx
+
+     def reset(self):
+         self._keys.zero_()
+         self._values.zero_()
+         self._current_idx.zero_()
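
The cache's bookkeeping is slice-assignment into a preallocated buffer plus a per-layer write cursor. An illustrative pure-torch sketch of that pattern (not the KVCache class itself, which additionally maps global layer indices to timewise layers and handles both memory layouts):

import torch

# One timewise layer, non-xformers layout: (batch * variates, heads, seq, head_dim).
batch_variates, num_heads, max_seq_len, head_dim = 4, 2, 16, 8
keys = torch.zeros(batch_variates, num_heads, max_seq_len, head_dim)
values = torch.zeros_like(keys)
current = 0  # write cursor, mirrors KVCache._current_idx for this layer

def append(k: torch.Tensor, v: torch.Tensor) -> None:
    """Write new key/value patches at the cursor, analogous to KVCache.append."""
    global current
    end = current + k.shape[2]
    assert end <= max_seq_len, "max_seq_len exceeded"
    keys[:, :, current:end] = k
    values[:, :, current:end] = v
    current = end

# Prefill with 3 patches, then decode one additional patch.
append(torch.randn(batch_variates, num_heads, 3, head_dim), torch.randn(batch_variates, num_heads, 3, head_dim))
append(torch.randn(batch_variates, num_heads, 1, head_dim), torch.randn(batch_variates, num_heads, 1, head_dim))

# What __getitem__ hands back to attention: only the filled prefix of the buffer.
print(keys[:, :, :current].shape)  # torch.Size([4, 2, 4, 8])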
autogluon/timeseries/models/toto/_internal/backbone/rope.py
@@ -0,0 +1,89 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+
+ import torch
+ from einops import rearrange
+
+ from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb, default
+
+
+ class TimeAwareRotaryEmbedding(RotaryEmbedding):
+     """
+     A variant of the rotary position embedding that (optionally) uses the time index
+     to compute the sinusoidal and cosine embeddings. This is useful for
+     time series data, where the time index is the most important positional
+     information.
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         # If the parent stored `freqs` as a Parameter, remove it and register as a buffer
+         # Register buffer is needed for sharding with FSDP
+         if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
+             # Extract the underlying Tensor
+             freqs_data = self.freqs.data
+
+             # Remove `freqs` from the module's parameters
+             self._parameters.pop("freqs")
+
+             # Register as non-persistent buffer
+             self.register_buffer("freqs", freqs_data, persistent=False)
+
+     def rotate_queries_and_keys(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         seq_dim: int | None = None,
+         seq_pos: torch.Tensor | None = None,
+         seq_pos_offset: int = 0,
+     ):
+         """
+         This method is the same as the one on the base class, except it allows you to override
+         the sequence position tensor with a custom one. It also removes the ability
+         to cache the position encodings, since we have to compute them dynamically
+         based on the timesteps in the input data.
+         """
+         if seq_dim is None:
+             seq_dim = self.default_seq_dim
+
+         assert self.use_xpos
+         device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+
+         seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device))
+         seq = seq + seq_pos_offset  # type: ignore
+
+         freqs = self.forward(seq)
+
+         scale = self.get_scale(seq).to(dtype)
+
+         # used for xformers
+         if seq_dim == -3:
+             num_heads = q.shape[-2]
+             freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
+             scale = scale.unsqueeze(1).expand(-1, num_heads, -1)
+
+         rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)  # type: ignore
+         rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)  # type: ignore
+
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+
+         return rotated_q, rotated_k
+
+     def get_scale(self, t: torch.Tensor, seq_len: int | None = None, offset=0):
+         """
+         This method is adapted closely from the base class, but it knows how to handle
+         when `t` has more than 1 dim (as is the case when we're using time-aware RoPE, and have a different
+         sequence position vector for each time series).
+         """
+         assert self.use_xpos
+
+         power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base
+
+         scale = self.scale ** rearrange(power, "... n -> ... n 1")  # type: ignore
+         scale = torch.cat((scale, scale), dim=-1)
+
+         return scale
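
The key difference from standard RoPE is that the rotation angles can be driven by an explicit per-series position vector (seq_pos) instead of arange(seq_len). A toy illustration of that idea in plain torch (not the vendored rotary_embedding_torch API; rope_angles is a hypothetical helper for illustration only):

import torch

def rope_angles(seq_pos: torch.Tensor, head_dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Sinusoid angles for rotary embeddings, driven by an arbitrary position vector."""
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    return torch.einsum("...n,d->...nd", seq_pos.float(), inv_freq)  # (..., seq_len, head_dim // 2)

default_pos = torch.arange(6)                     # what vanilla RoPE would use
custom_pos = torch.tensor([0, 4, 8, 12, 16, 20])  # e.g. positions derived from actual timestamps

print(rope_angles(default_pos, head_dim=8).shape)                                  # torch.Size([6, 4])
print(torch.allclose(rope_angles(custom_pos, 8), rope_angles(default_pos, 8) * 4)) # True: stretching positions stretches the angles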