autogluon.timeseries 1.4.1b20250906__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of autogluon.timeseries might be problematic.

Files changed (93)
  1. autogluon/timeseries/configs/hyperparameter_presets.py +2 -2
  2. autogluon/timeseries/dataset/ts_dataframe.py +97 -86
  3. autogluon/timeseries/learner.py +68 -35
  4. autogluon/timeseries/metrics/__init__.py +4 -4
  5. autogluon/timeseries/metrics/abstract.py +8 -8
  6. autogluon/timeseries/metrics/point.py +9 -9
  7. autogluon/timeseries/metrics/quantile.py +5 -5
  8. autogluon/timeseries/metrics/utils.py +4 -4
  9. autogluon/timeseries/models/__init__.py +4 -1
  10. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -39
  11. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  12. autogluon/timeseries/models/abstract/tunable.py +8 -8
  13. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +58 -62
  14. autogluon/timeseries/models/autogluon_tabular/per_step.py +26 -15
  15. autogluon/timeseries/models/autogluon_tabular/transforms.py +11 -9
  16. autogluon/timeseries/models/chronos/__init__.py +2 -1
  17. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  18. autogluon/timeseries/models/chronos/model.py +125 -87
  19. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +68 -36
  20. autogluon/timeseries/models/ensemble/__init__.py +34 -2
  21. autogluon/timeseries/models/ensemble/abstract.py +5 -42
  22. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  23. autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
  24. autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
  25. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  26. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  27. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
  28. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  30. autogluon/timeseries/models/ensemble/{greedy.py → ensemble_selection.py} +41 -61
  31. autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
  32. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  33. autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
  34. autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +6 -16
  35. autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
  36. autogluon/timeseries/models/gluonts/abstract.py +25 -25
  37. autogluon/timeseries/models/gluonts/dataset.py +11 -11
  38. autogluon/timeseries/models/local/__init__.py +0 -7
  39. autogluon/timeseries/models/local/abstract_local_model.py +15 -18
  40. autogluon/timeseries/models/local/naive.py +2 -2
  41. autogluon/timeseries/models/local/npts.py +1 -1
  42. autogluon/timeseries/models/local/statsforecast.py +12 -12
  43. autogluon/timeseries/models/multi_window/multi_window_model.py +39 -24
  44. autogluon/timeseries/models/registry.py +3 -4
  45. autogluon/timeseries/models/toto/__init__.py +3 -0
  46. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  47. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  48. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  49. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  50. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  51. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  52. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  56. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  57. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  58. autogluon/timeseries/models/toto/dataloader.py +108 -0
  59. autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
  60. autogluon/timeseries/models/toto/model.py +236 -0
  61. autogluon/timeseries/predictor.py +301 -103
  62. autogluon/timeseries/regressor.py +27 -30
  63. autogluon/timeseries/splitter.py +3 -27
  64. autogluon/timeseries/trainer/ensemble_composer.py +439 -0
  65. autogluon/timeseries/trainer/model_set_builder.py +9 -9
  66. autogluon/timeseries/trainer/prediction_cache.py +16 -16
  67. autogluon/timeseries/trainer/trainer.py +300 -275
  68. autogluon/timeseries/trainer/utils.py +17 -0
  69. autogluon/timeseries/transforms/covariate_scaler.py +8 -8
  70. autogluon/timeseries/transforms/target_scaler.py +15 -15
  71. autogluon/timeseries/utils/constants.py +10 -0
  72. autogluon/timeseries/utils/datetime/lags.py +1 -3
  73. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  74. autogluon/timeseries/utils/features.py +18 -14
  75. autogluon/timeseries/utils/forecast.py +6 -7
  76. autogluon/timeseries/utils/timer.py +173 -0
  77. autogluon/timeseries/version.py +1 -1
  78. autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
  79. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +39 -22
  80. autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
  81. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
  82. autogluon/timeseries/evaluator.py +0 -6
  83. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
  84. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  85. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
  86. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -580
  87. autogluon.timeseries-1.4.1b20250906-py3.9-nspkg.pth +0 -1
  88. autogluon.timeseries-1.4.1b20250906.dist-info/RECORD +0 -75
  89. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
  90. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
  91. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
  92. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
  93. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
@@ -0,0 +1,305 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+ import warnings
+
+ import torch
+ from einops import repeat
+ from gluonts.core.component import validated
+ from gluonts.torch.scaler import Scaler
+
+
+ def compute_causal_statistics(
+     data: torch.Tensor,
+     weights: torch.Tensor,
+     padding_mask: torch.Tensor,
+     dim: int,
+     minimum_scale: float,
+     use_bessel_correction: bool = True,
+     stabilize_with_global: bool = False,
+     scale_factor_exponent: float = 10.0,
+     prefix_length: int | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Compute causal mean and scale statistics along a specified dimension using
+     a vectorized implementation of Welford's algorithm for numerical stability.
+
+     This implementation avoids explicit loops while maintaining the numerical stability
+     of Welford's algorithm, achieving better performance with the same robustness
+     against overflow issues.
+
+     Can optionally use global statistics to stabilize the causal statistics by clamping
+     extreme values, preventing instability while preserving a relaxed version of the
+     causal property. This allows a controlled amount of future information leakage,
+     introducing an explicit tradeoff between causality and stability.
+
+     Parameters
+     ----------
+     data
+         The input data tensor
+     weights
+         The weight tensor (same shape as data)
+     padding_mask
+         The padding mask tensor (same shape as data)
+     dim
+         The dimension along which to compute statistics (must be -1, the time dimension)
+     minimum_scale
+         Minimum scale value to use
+     use_bessel_correction
+         Whether to use Bessel's correction to get an unbiased estimator
+     stabilize_with_global
+         Whether to use global statistics to stabilize the causal statistics by clamping
+         extreme values
+     scale_factor_exponent
+         Exponent that controls the allowed range of deviation from global scale.
+         For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
+         With exponent=2.0, the range would be 0.01x to 100x.
+     prefix_length
+         If specified, the global statistics will be computed using only the prefix length
+         requested. This is used for multistep decoding, where we only want to use the
+         initial historical data to compute the global statistics. If stabilize_with_global
+         is False, this parameter is ignored.
+
+     Returns
+     -------
+     tuple[torch.Tensor, torch.Tensor]
+         Causal mean and scale tensors, potentially stabilized with global statistics
+     """
+     # Assert that dim is -1 (last dimension)
+     assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)"
+
+     with torch.no_grad():
+         # Apply padding mask to weights
+         weights = weights * padding_mask
+
+         # Try to use higher precision for numerical stability
+         try:
+             high_precision_data = data.to(torch.float64)
+             high_precision_weights = weights.to(torch.float64)
+         except TypeError:
+             # Fallback for devices that don't support float64
+             warnings.warn(
+                 f"Float64 is not supported by device {data.device}. "
+                 "Using float32 instead for causal scaler calculations. "
+                 "This may lead to numerical issues if the data contains extreme values.",
+                 RuntimeWarning,
+             )
+             high_precision_data = data.to(torch.float32)
+             high_precision_weights = weights.to(torch.float32)
+
+         # Check if deterministic algorithms are enabled and we're using CUDA.
+         # Cumsum operations do not support deterministic mode in CUDA,
+         # so we need to disable it for just this section.
+         prev_deterministic = torch.are_deterministic_algorithms_enabled()
+         if prev_deterministic and data.device.type == "cuda":
+             # Disable deterministic algorithms for operations
+             torch.use_deterministic_algorithms(False)
+
+         try:
+             # Create weighted data
+             weighted_data = high_precision_weights * high_precision_data
+
+             # Compute cumulative sum of weights and weighted data along time dimension
+             cum_weights = torch.cumsum(high_precision_weights, dim=dim)
+             cum_values = torch.cumsum(weighted_data, dim=dim)
+
+             # Avoid division by zero for the first time step or when no valid values
+             denominator = cum_weights.clamp_min(1.0)
+
+             # Compute causal means at each time step
+             causal_means = cum_values / denominator
+
+             # For Welford's algorithm, we need to compute the correction term
+             # using the difference between the current value and the current mean
+
+             # Create shifted version of causal means to compute delta efficiently
+             # First item in shifted_means will be zero
+             shifted_means = torch.zeros_like(causal_means)
+             shifted_means[..., 1:] = causal_means[..., :-1]
+
+             # Compute delta between current data point and previous mean
+             # For t=0, this is just the first data point
+             delta = high_precision_data - shifted_means
+
+             # Compute the increment term for Welford's algorithm.
+             # This is defined as the product of the delta and the difference between the current data point and the causal mean.
+             # This is where we avoid the traditional E[X²] - E[X]² computation
+             increment = delta * (high_precision_data - causal_means) * high_precision_weights
+
+             # The Welford algorithm uses the term m_2, which is the cumulative sum of the increment term.
+             # This is an accumulator that helps us compute the second moment (hence m_2) of the distribution.
+             # Compute cumulative sum of the increment term
+             m_2 = torch.cumsum(increment, dim=dim)
+
+             # Compute variance according to Welford's algorithm
+             if use_bessel_correction:
+                 causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0)
+             else:
+                 causal_variance = m_2 / denominator
+
+             # Add minimum scale but keep in high precision for now
+             causal_scale = torch.sqrt(causal_variance + minimum_scale)
+
+             # Apply stabilization with global statistics if requested
+             if stabilize_with_global:
+                 if prefix_length is not None:
+                     # Create a prefix mask for global statistics computation
+                     prefix_mask = torch.zeros_like(weights)
+                     prefix_mask[..., :prefix_length] = 1.0
+
+                     # Apply prefix mask to restrict computation to prefix
+                     weighted_data = weighted_data * prefix_mask
+                     weights = weights * prefix_mask
+                     padding_mask = padding_mask * prefix_mask
+
+                 # Calculate scale factors from the exponent
+                 scale_factor_min = 10.0 ** (-scale_factor_exponent)
+                 scale_factor_max = 10.0**scale_factor_exponent
+
+                 global_denominator = (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0)
+                 global_means = (weighted_data).sum(dim, keepdim=True) / global_denominator
+                 global_means = torch.nan_to_num(global_means)
+
+                 global_variance = (((high_precision_data - global_means) * weights * padding_mask) ** 2).sum(
+                     dim, keepdim=True
+                 ) / global_denominator
+                 global_scale = torch.sqrt(global_variance + minimum_scale)
+
+                 # Expand global statistics to match the time dimension
+                 expanded_global_scale = global_scale.expand_as(causal_scale)
+
+                 # Define bounds using scale factors
+                 min_allowed_scale = expanded_global_scale * scale_factor_min
+                 max_allowed_scale = expanded_global_scale * scale_factor_max
+
+                 # Clamp the causal scale between min_allowed_scale and max_allowed_scale
+                 causal_scale = torch.clamp(
+                     causal_scale,
+                     min=torch.max(torch.tensor(minimum_scale, device=causal_scale.device), min_allowed_scale),
+                     max=max_allowed_scale,
+                 )
+
+             # Now convert means and scale to original dtype after all numerical operations
+             causal_means = causal_means.to(data.dtype)
+             causal_scale = causal_scale.to(data.dtype)
+
+         finally:
+             # Restore original deterministic setting if it was changed
+             if prev_deterministic and data.device.type == "cuda":
+                 torch.use_deterministic_algorithms(True)
+
+     return causal_means, causal_scale
+
+
+ class CausalPatchStdMeanScaler(Scaler):
+     """
+     Causally scales data in patches, where each patch uses statistics computed
+     from all data up to and including that patch. Within each patch, all timesteps
+     use the same scaling values.
+
+     This approach provides more stability than per-timestep causal scaling while
+     still maintaining the causal property (not using future data).
+
+     Can optionally stabilize causal statistics using global statistics to prevent
+     extreme values, while preserving the causal property.
+
+     The statistics are computed using Welford's algorithm, which provides better
+     numerical stability compared to the direct computation of variance, especially
+     when dealing with large values or a large number of data points.
+
+     Note: This scaler only works with the following constraints:
+     - The input must have shape [batch, variates, time_steps]
+     - It only operates on the last dimension (-1)
+     - The time_steps must be divisible by patch_size
+
+     Parameters
+     ----------
+     dim
+         dimension along which to compute the causal scale. Must be -1 (the last dimension).
+     patch_size
+         number of timesteps in each patch
+     minimum_scale
+         default scale that is used for elements that are constantly zero
+         along dimension `dim` or for the first patch.
+     use_bessel_correction
+         whether to use Bessel's correction to get an unbiased estimator
+     stabilize_with_global
+         whether to use global statistics to stabilize extreme causal statistics
+     scale_factor_exponent
+         exponent that controls the allowed range of deviation from global scale.
+         For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
+         With exponent=2.0, the range would be 0.01x to 100x.
+     """
+
+     @validated()
+     def __init__(
+         self,
+         dim: int = -1,
+         patch_size: int = 32,
+         minimum_scale: float = 0.1,
+         use_bessel_correction: bool = True,
+         stabilize_with_global: bool = False,
+         scale_factor_exponent: float = 10.0,
+     ) -> None:
+         super().__init__()
+         assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
+         self.dim = dim
+         self.patch_size = patch_size
+         self.minimum_scale = minimum_scale
+         self.use_bessel_correction = use_bessel_correction
+         self.stabilize_with_global = stabilize_with_global
+         self.scale_factor_exponent = scale_factor_exponent
+
+     def __call__(  # type: ignore[override]
+         self,
+         data: torch.Tensor,
+         padding_mask: torch.Tensor,
+         weights: torch.Tensor,
+         prefix_length: int | None = None,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         assert data.shape == weights.shape, "data and weights must have same shape"
+         assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]"
+
+         with torch.no_grad():
+             # Get the number of time steps (last dimension)
+             time_steps = data.shape[-1]
+
+             # Assert that time_steps is divisible by patch_size
+             assert time_steps % self.patch_size == 0, (
+                 f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})"
+             )
+
+             # First compute causal statistics with optional stabilization
+             causal_means, causal_scale = compute_causal_statistics(
+                 data,
+                 weights,
+                 padding_mask,
+                 -1,
+                 self.minimum_scale,
+                 self.use_bessel_correction,
+                 self.stabilize_with_global,
+                 self.scale_factor_exponent,
+                 prefix_length,
+             )
+
+             # Unfold the causal means and scales to get the patches
+             means_unfolded = causal_means.unfold(-1, self.patch_size, self.patch_size)
+             scales_unfolded = causal_scale.unfold(-1, self.patch_size, self.patch_size)
+
+             # Get the last element of each patch (the most recent statistic)
+             patch_stats_means = means_unfolded[..., -1]
+             patch_stats_scales = scales_unfolded[..., -1]
+
+             # Tile the patch statistics across time dimension using einops.repeat
+             # With our fixed [batch, variates, num_patches] shape this is much simpler
+             patch_means = repeat(patch_stats_means, "b v p -> b v (p s)", s=self.patch_size)
+             patch_scales = repeat(patch_stats_scales, "b v p -> b v (p s)", s=self.patch_size)
+
+             # Apply normalization
+             scaled_data = (data - patch_means) / patch_scales
+
+             return scaled_data, patch_means, patch_scales
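
The scaler above is self-contained, so its patch-wise behavior can be sanity-checked in isolation. A minimal usage sketch (the shapes and values below are hypothetical, and it assumes torch, einops, and gluonts are installed, as the new toto module requires):

import torch

# Scale 2 series x 3 variates x 64 time steps with patch_size=32: each of the
# two patches is normalized with statistics computed only from data up to and
# including that patch.
scaler = CausalPatchStdMeanScaler(patch_size=32, minimum_scale=0.1)
data = torch.randn(2, 3, 64)
padding_mask = torch.ones_like(data)  # all timesteps observed
weights = torch.ones_like(data)       # uniform observation weights

scaled, means, scales = scaler(data, padding_mask, weights)
assert scaled.shape == means.shape == scales.shape == data.shape
# All timesteps within a patch share the same mean/scale, so the first patch
# is a single statistic repeated 32 times along the time axis.
assert torch.equal(means[..., :32], means[..., :1].expand(-1, -1, 32))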
@@ -0,0 +1,333 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+ from typing import cast
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+
+ from .attention import (
+     AttentionAxis,
+     MultiHeadAttention,
+     SpaceWiseMultiheadAttention,
+     TimeWiseMultiheadAttention,
+ )
+ from .kvcache import KVCache
+ from .rope import TimeAwareRotaryEmbedding
+ from .rotary_embedding_torch import RotaryEmbedding
+
+
+ class SwiGLU(torch.nn.Module):
+     """
+     https://arxiv.org/abs/2002.05202
+     NOTE: x should be 2x the size you want
+     """
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Note this ordering is unusual, but is done so to match xFormers
+         gate, x = x.chunk(2, dim=-1)
+         return F.silu(gate) * x
+
+
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-8):
+         super().__init__()
+         self.eps = eps
+         if include_weight:
+             self.scale: torch.nn.Parameter | None = torch.nn.Parameter(torch.ones(dim))
+         else:
+             self.scale = None
+
+     def forward(self, x: torch.Tensor):
+         x_normed = x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+         return x_normed if self.scale is None else x_normed * self.scale
+
+     def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
+         """
+         If you need the fused addition with RMS norm, do the same check here.
+         """
+         return self.forward(x + y)
+
+
+ def make_batched_block_mask(t: torch.Tensor) -> torch.Tensor:
+     unsqueezed = rearrange(t, "... d -> ... 1 d")
+     return unsqueezed == unsqueezed.transpose(-1, -2)
+
+
+ class TransformerLayer(torch.nn.Module):
+     """
+     A transformer block that applies multihead attention followed by a feedforward network.
+
+     The transformer can be configured to apply time-wise attention (i.e. attention over the time axis)
+     or space-wise attention (i.e. attention over the variate axis).
+
+     The transformer block uses pre-norm, a variant of the transformer architecture where
+     normalization is applied before each sublayer rather than after. This is the approach taken in
+     LLaMA and other recent transformer-based models.
+
+     The transformer block also uses SwiGLU, a variant of the Gated Linear Unit (GLU) that applies
+     the Swish activation function to its gate. This activation has been used extensively in recent
+     transformer-based models and has been shown to improve performance.
+     """
+
+     embed_dim: int
+     num_heads: int
+     mlp_hidden_dim: int
+     dropout: float
+     attention_axis: AttentionAxis
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         mlp_hidden_dim: int,
+         dropout: float,
+         rotary_emb: RotaryEmbedding | None = None,
+         attention_axis: AttentionAxis = AttentionAxis.TIME,
+         RMS_norm: bool = True,
+         use_memory_efficient_attention: bool = True,
+     ):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.mlp_hidden_dim = mlp_hidden_dim
+         self.dropout = dropout
+         self.attention_axis = attention_axis
+
+         if RMS_norm:
+             self.norm1: RMSNorm | torch.nn.LayerNorm = RMSNorm(embed_dim)
+             self.norm2: RMSNorm | torch.nn.LayerNorm = RMSNorm(embed_dim)
+         else:
+             self.norm1 = torch.nn.LayerNorm(embed_dim)
+             self.norm2 = torch.nn.LayerNorm(embed_dim)
+
+         self.attention: MultiHeadAttention
+
+         if attention_axis == AttentionAxis.TIME:
+             self.attention = TimeWiseMultiheadAttention(
+                 embed_dim=embed_dim,
+                 num_heads=num_heads,
+                 dropout=dropout,
+                 rotary_emb=rotary_emb,  # type: ignore
+                 use_memory_efficient_attention=use_memory_efficient_attention,
+             )
+         elif attention_axis == AttentionAxis.SPACE:
+             self.attention = SpaceWiseMultiheadAttention(
+                 embed_dim=embed_dim,
+                 num_heads=num_heads,
+                 dropout=dropout,
+                 rotary_emb=None,
+                 use_memory_efficient_attention=use_memory_efficient_attention,
+             )
+         else:
+             raise ValueError("Invalid attention axis")
+
+         self.mlp = torch.nn.Sequential(
+             torch.nn.Linear(embed_dim, 2 * mlp_hidden_dim),
+             SwiGLU(),
+             torch.nn.Linear(mlp_hidden_dim, embed_dim),
+             torch.nn.Dropout(dropout),
+         )
+
+     def forward(
+         self,
+         layer_idx: int,
+         inputs: torch.Tensor,
+         attention_mask: torch.Tensor | None = None,
+         kv_cache: KVCache | None = None,
+     ) -> torch.Tensor:
+         pre_norm_1 = self.norm1(inputs)
+         hidden_state = inputs + self.attention(layer_idx, pre_norm_1, attention_mask, kv_cache).contiguous()
+
+         pre_norm_2 = self.norm2(hidden_state)
+         return hidden_state + self.mlp(pre_norm_2)
+
+
+ class Transformer(torch.nn.Module):
+     """
+     A stack of transformer layers. The transformer alternates between time-wise and space-wise attention
+     to learn both temporal and cross-variate dependencies in the data.
+
+     Based on the intuition that time-wise attention is more important overall than space-wise attention
+     (because an individual variate is more likely to be correlated with itself across time than with other variates),
+     the transformer can be configured to apply space-wise attention less frequently than time-wise attention.
+     This is controlled by the `spacewise_every_n_layers` parameter, which specifies how many time-wise transformer
+     layers to apply between every space-wise transformer layer.
+
+     Parameters
+     ----------
+     num_layers
+         Number of transformer layers to use.
+     num_heads
+         Number of attention heads to use in each self-attention layer.
+     mlp_hidden_dim
+         Dimension of the hidden layer in the feedforward network.
+     dropout
+         Dropout rate to use in the model.
+     spacewise_every_n_layers
+         How many time-wise transformer layers to apply between each space-wise transformer layer.
+     spacewise_first
+         Whether to apply space-wise attention before time-wise attention.
+     use_memory_efficient_attention
+         Whether to use memory-efficient attention. If True, the model will use the
+         memory-efficient attention implementation from xFormers.
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ num_layers: int,
183
+ embed_dim: int,
184
+ num_heads: int,
185
+ mlp_hidden_dim: int,
186
+ dropout: float,
187
+ spacewise_every_n_layers: int,
188
+ spacewise_first: bool,
189
+ use_memory_efficient_attention: bool = True,
190
+ ):
191
+ super().__init__()
192
+
193
+ assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads."
194
+
195
+ self.rotary_emb = TimeAwareRotaryEmbedding(
196
+ embed_dim // num_heads,
197
+ use_xpos=True,
198
+ cache_if_possible=True,
199
+ seq_before_head_dim=use_memory_efficient_attention,
200
+ )
201
+ attention_axes = self._get_layer_types(num_layers, spacewise_every_n_layers, spacewise_first)
202
+
203
+ self.use_memory_efficient_attention = use_memory_efficient_attention
204
+
205
+ self.layers = torch.nn.ModuleList(
206
+ [
207
+ TransformerLayer(
208
+ embed_dim=embed_dim,
209
+ num_heads=num_heads,
210
+ mlp_hidden_dim=mlp_hidden_dim,
211
+ dropout=dropout,
212
+ rotary_emb=self.rotary_emb,
213
+ attention_axis=attention_axes[i],
214
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
215
+ )
216
+ for i in range(num_layers)
217
+ ]
218
+ )
219
+
220
+ def _get_mask(
221
+ self,
222
+ num_heads: int,
223
+ dtype: torch.dtype,
224
+ id_mask: torch.Tensor | None = None,
225
+ ) -> torch.Tensor:
+         """
+         Create and process the space-wise attention mask.
+
+         Args:
+             num_heads: Number of attention heads.
+             dtype: Desired dtype for the bias tensor.
+             id_mask: Mask identifying which variates belong to the same group
+                 (required for space-wise masks).
+
+         Returns:
+             Processed attention mask tensor with the correct shape for space-wise attention.
+         """
+
+         if id_mask is None:
+             raise ValueError("id_mask must be provided for spacewise masks.")
+
+         # Create spacewise mask
+         mask = make_batched_block_mask(id_mask.transpose(-1, -2))
+
+         if self.use_memory_efficient_attention:
+             mask = self._pad_to_multiple(mask)
+             mask = mask.float().masked_fill(~mask, float("-inf")).masked_fill(mask, 0.0).to(dtype)
+
+         # Rearrange for space-wise attention
+         mask = rearrange(mask, "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2")
+         # Stack along num_heads dimension
+         return mask.expand(-1, num_heads, -1, -1).contiguous()
+
+     def _pad_to_multiple(
+         self,
+         tensor: torch.Tensor,
+         multiple: int = 8,
+         causal: bool = False,  # New flag to indicate causal mask extension
+     ) -> torch.Tensor:
+         """
+         Pads the last two dimensions of a tensor to be divisible by `multiple`.
+         For causal masks, the padded area is filled with the continued lower-triangular pattern,
+         rather than with zeros.
+         """
+         pad_amount = (multiple - tensor.shape[-1] % multiple) % multiple
+         if pad_amount > 0:
+             new_size = tensor.shape[-1] + pad_amount
+             if causal:
+                 # Create a full causal mask for the new size.
+                 full_mask = torch.tril(torch.ones((new_size, new_size), dtype=tensor.dtype, device=tensor.device))
+                 # Preserve any modifications from the original mask (e.g., condition tokens in top-left)
+                 full_mask[: tensor.shape[-1], : tensor.shape[-1]] = tensor
+                 tensor = full_mask
+             else:
+                 tensor = F.pad(tensor, (0, pad_amount, 0, pad_amount))
+         return tensor
+
+     def _get_layer_types(
+         self,
+         num_layers: int,
+         spacewise_every_n_layers: int,
+         spacewise_first: bool,
+     ) -> list[AttentionAxis]:
+         if spacewise_every_n_layers == -1:
+             return [AttentionAxis.TIME] * num_layers
+         assert num_layers % spacewise_every_n_layers == 0
+
+         block = [AttentionAxis.TIME] * (spacewise_every_n_layers - 1)
+
+         if spacewise_first:
+             block = [AttentionAxis.SPACE] + block
+         else:
+             block = block + [AttentionAxis.SPACE]
+
+         layer_types = block * (num_layers // spacewise_every_n_layers)
+
+         return layer_types
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         id_mask: torch.Tensor,
+         kv_cache: KVCache | None = None,
+     ) -> torch.Tensor:
+         batch, _, seq_len, _ = inputs.shape
+         # Get the sequence length by looking up a timewise layer in the kv cache.
+         # Regardless of whether spacewise is first in the stack, the layer
+         # at index 1 is always a timewise layer.
+         seq_len = (kv_cache.seq_len(1) if kv_cache else 0) + seq_len
+
+         num_heads: int = cast(int, self.layers[0].num_heads)
+
+         timewise_attention_mask = None
+
+         # We create a space-wise ID mask by creating a block triangular mask from the ID mask
+         # in the space-wise direction. This ensures that the model can only attend to
+         # variates in the same group.
+         spacewise_attention_mask = self._get_mask(
+             num_heads=num_heads,
+             dtype=inputs.dtype,
+             id_mask=id_mask,
+         )
+
+         for layer_idx, layer in enumerate(self.layers):
+             inputs = layer(
+                 layer_idx,
+                 inputs,
+                 (timewise_attention_mask if layer.attention_axis == AttentionAxis.TIME else spacewise_attention_mask),
+                 kv_cache,
+             )
+         return inputs
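
To make the time/space alternation rule concrete, here is a self-contained sketch of the pattern _get_layer_types produces. AttentionAxis is re-declared below as a stand-in for the enum imported from .attention, and the helper simply mirrors the method above:

from enum import Enum

class AttentionAxis(Enum):  # stand-in for .attention.AttentionAxis
    TIME = "time"
    SPACE = "space"

def get_layer_types(num_layers: int, spacewise_every_n_layers: int, spacewise_first: bool) -> list[AttentionAxis]:
    # Mirrors Transformer._get_layer_types above.
    if spacewise_every_n_layers == -1:
        return [AttentionAxis.TIME] * num_layers
    assert num_layers % spacewise_every_n_layers == 0
    block = [AttentionAxis.TIME] * (spacewise_every_n_layers - 1)
    block = [AttentionAxis.SPACE] + block if spacewise_first else block + [AttentionAxis.SPACE]
    return block * (num_layers // spacewise_every_n_layers)

# 6 layers with a space-wise layer every 3rd layer, time-wise layers first:
# prints ['TIME', 'TIME', 'SPACE', 'TIME', 'TIME', 'SPACE']
print([axis.name for axis in get_layer_types(6, 3, spacewise_first=False)])

Note that with spacewise_first=False the layer at index 1 is always time-wise, which is why Transformer.forward reads the cached sequence length from layer index 1; with spacewise_first=True the space-wise layer sits at index 0, so index 1 remains time-wise in both configurations.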