autogluon.timeseries 1.4.1b20250902__py3-none-any.whl → 1.4.1b20251003__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28) hide show
  1. autogluon/timeseries/models/__init__.py +2 -0
  2. autogluon/timeseries/models/chronos/model.py +4 -1
  3. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +50 -0
  4. autogluon/timeseries/models/toto/__init__.py +3 -0
  5. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  6. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  7. autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
  8. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  9. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  10. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  11. autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
  12. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
  13. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  14. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  15. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  16. autogluon/timeseries/models/toto/dataloader.py +108 -0
  17. autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
  18. autogluon/timeseries/models/toto/model.py +234 -0
  19. autogluon/timeseries/version.py +1 -1
  20. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/METADATA +10 -5
  21. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/RECORD +28 -13
  22. /autogluon.timeseries-1.4.1b20250902-py3.9-nspkg.pth → /autogluon.timeseries-1.4.1b20251003-py3.9-nspkg.pth +0 -0
  23. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/LICENSE +0 -0
  24. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/NOTICE +0 -0
  25. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/WHEEL +0 -0
  26. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/namespace_packages.txt +0 -0
  27. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/top_level.txt +0 -0
  28. {autogluon.timeseries-1.4.1b20250902.dist-info → autogluon.timeseries-1.4.1b20251003.dist-info}/zip-safe +0 -0
@@ -0,0 +1,136 @@
1
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
2
+ #
3
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
4
+ # Copyright 2025 Datadog, Inc.
5
+
6
+ from dataclasses import dataclass, field
7
+
8
+ import torch
9
+
10
+ from .attention import TimeWiseMultiheadAttention
11
+
12
# Shorthand aliases used in annotations: a key tensor, a value tensor, and a (key, value) pair.
K = V = torch.Tensor
KV = tuple[K, V]
15
+
16
+
17
@dataclass
class KVCache:
    """
    Key/Value cache for storing intermediate attention values during multistep
    inference. Only stores KV cache for timewise layers, skipping spacewise layers.

    Storage layout of ``_keys`` / ``_values`` (leading axis = timewise-layer slot):

    * ``use_memory_efficient_attention=True``:
      ``(n_slots, batch_size * num_variates, max_seq_len, num_heads, head_dim)``
    * ``use_memory_efficient_attention=False``:
      ``(n_slots, batch_size * num_variates, num_heads, max_seq_len, head_dim)``

    ``n_slots`` is the number of timewise layers (at least 1). ``_current_idx[slot]``
    holds how many sequence positions of that slot are currently filled.

    NOTE(review): global layer indices whose attention is not a
    ``TimeWiseMultiheadAttention`` are mapped to slot 0 by ``_layer_cache_map``, so
    reading/appending with a spacewise layer index aliases the first timewise layer's
    cache. Presumably callers only pass timewise layer indices — confirm.
    """

    batch_size: int
    num_variates: int
    transformer_layers: list
    num_layers: int
    embed_dim: int
    num_heads: int
    max_seq_len: int
    device: torch.device = torch.device("cpu")
    dtype: torch.dtype = torch.float32
    use_memory_efficient_attention: bool = True

    # Internal buffers; allocated in __post_init__, not constructor arguments.
    _keys: torch.Tensor = field(init=False)
    _values: torch.Tensor = field(init=False)
    _current_idx: torch.Tensor = field(init=False)
    _layer_cache_map: torch.Tensor = field(init=False)

    def __post_init__(self):
        """
        - Determine timewise vs. spacewise layers and allocate KV only for timewise.
        - Create a fast tensor-based mapping from global layer_idx -> timewise layer_idx.
        """
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        head_dim = self.embed_dim // self.num_heads

        # Compute which layers are timewise
        time_layer_indices = [
            i
            for i in range(self.num_layers)
            if isinstance(self.transformer_layers[i].attention, TimeWiseMultiheadAttention)
        ]

        # Allocate at least one slot so the buffers are well-formed even with no timewise layers.
        time_layer_count = max(1, len(time_layer_indices))
        # Allocate for only the timewise layers
        if self.use_memory_efficient_attention:
            shape = (
                time_layer_count,
                self.batch_size * self.num_variates,
                self.max_seq_len,
                self.num_heads,
                head_dim,
            )
        else:
            shape = (
                time_layer_count,
                self.batch_size * self.num_variates,
                self.num_heads,
                self.max_seq_len,
                head_dim,
            )
        self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype)
        self._values = torch.zeros_like(self._keys)
        self._current_idx = torch.zeros(time_layer_count, device=self.device, dtype=torch.int)
        # Build a tensor lookup for global -> timewise layer index (default to 0)
        self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device)
        for cache_idx, layer_idx in enumerate(time_layer_indices):
            self._layer_cache_map[layer_idx] = int(cache_idx)  # Assign correct indices

    def __getitem__(self, layer_idx: int) -> "KV":
        """Return views of the filled ``(keys, values)`` for global layer ``layer_idx``."""
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        end_idx = int(self._current_idx[cache_idx].item())

        if self.use_memory_efficient_attention:
            return self._keys[cache_idx, :, :end_idx, :, :], self._values[cache_idx, :, :end_idx, :, :]
        else:
            return self._keys[cache_idx, :, :, :end_idx, :], self._values[cache_idx, :, :, :end_idx, :]

    def current_len(self, cache_idx: int) -> int:
        """Number of filled sequence positions in timewise cache slot ``cache_idx``."""
        return int(self._current_idx[cache_idx].item()) if self._current_idx.numel() > 0 else 0

    def seq_len(self, layer_idx: int) -> int:
        """Number of filled sequence positions for global layer ``layer_idx``."""
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        return self.current_len(cache_idx)

    def append(self, layer_idx: int, kv: "KV"):
        """
        Write ``kv = (keys, values)`` after the filled positions of layer ``layer_idx``.

        ``keys``/``values`` must be 4-D and match the layout chosen by
        ``use_memory_efficient_attention`` (see the class docstring, minus the slot axis).

        Raises
        ------
        AssertionError
            On a shape mismatch or when the write would exceed ``max_seq_len``.
        """
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        keys, values = kv

        # Validate dimensions
        assert keys.shape == values.shape, "keys and values must have the same shape"
        assert keys.shape[0] == self.batch_size * self.num_variates, (
            "keys and values must have batch_size * num_variates as their first dimension"
        )

        if self.use_memory_efficient_attention:
            assert keys.shape[2] == self.num_heads, "keys and values must have num_heads as their third dimension"
        else:
            assert keys.shape[1] == self.num_heads, "keys and values must have num_heads as their second dimension"
        assert keys.shape[3] == self.embed_dim // self.num_heads, (
            "keys and values must have head_dim as their fourth dimension"
        )

        # Read the fill level once as a Python int (as __getitem__/current_len do) so the
        # slice bounds below are plain ints rather than 0-dim tensors, avoiding a device
        # sync per use and keeping the overflow message readable.
        start_idx = int(self._current_idx[cache_idx].item())
        seq_axis = 1 if self.use_memory_efficient_attention else 2
        end_idx = start_idx + keys.shape[seq_axis]
        assert end_idx <= self.max_seq_len, (
            f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"
        )

        if self.use_memory_efficient_attention:
            self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys
            self._values[cache_idx, :, start_idx:end_idx, :, :] = values
        else:
            self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys
            self._values[cache_idx, :, :, start_idx:end_idx, :] = values

        self._current_idx[cache_idx] = end_idx

    def reset(self):
        """Zero all cached keys/values and mark every slot as empty."""
        self._keys.zero_()
        self._values.zero_()
        self._current_idx.zero_()
@@ -0,0 +1,94 @@
1
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
2
+ #
3
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
4
+ # Copyright 2025 Datadog, Inc.
5
+
6
+ from typing import Optional
7
+
8
+ import torch
9
+ from einops import rearrange
10
+ from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
11
+ from rotary_embedding_torch.rotary_embedding_torch import default
12
+
13
+
14
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``."""
    if val is None:
        return False
    return True
16
+
17
+
18
class TimeAwareRotaryEmbedding(RotaryEmbedding):
    """
    A variant of the rotary position embedding that (optionally) uses the time index
    to compute the sinusoidal and cosine embeddings. This is useful for
    time series data, where the time index is the most important positional
    information.

    Differences from the base ``RotaryEmbedding`` visible here:
    - ``freqs`` is re-registered as a non-persistent buffer instead of a Parameter.
    - ``rotate_queries_and_keys`` accepts an explicit ``seq_pos`` tensor and offset, and
      never caches position encodings.
    - ``get_scale`` supports position tensors with more than one dimension.
    Both rotation paths require the base class to be built with xpos enabled
    (``assert self.use_xpos``).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # If the parent stored `freqs` as a Parameter, remove it and register as a buffer
        # Register buffer is needed for sharding with FSDP
        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
            # Extract the underlying Tensor
            freqs_data = self.freqs.data

            # Remove `freqs` from the module's parameters
            self._parameters.pop("freqs")

            # Register as non-persistent buffer: it moves with .to()/device placement but is
            # excluded from state_dict(), so it is always rebuilt by the base constructor.
            self.register_buffer("freqs", freqs_data, persistent=False)

    def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: Optional[int] = None,
        seq_pos: Optional[torch.Tensor] = None,
        seq_pos_offset: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This method is the same as the one on the base class, except it allows you to override
        the sequence position tensor with a custom one. It also removes the ability
        to cache the position encodings, since we have to compute them dynamically
        based on the timesteps in the input data.

        Parameters
        ----------
        q, k
            Query and key tensors; ``q.shape[seq_dim]`` is taken as the sequence length.
        seq_dim
            Sequence axis; defaults to the base class's ``default_seq_dim``.
        seq_pos
            Optional explicit positions; when ``None`` the base ``get_seq_pos`` is used.
        seq_pos_offset
            Added to the positions before computing frequencies and xpos scales.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Rotated ``(q, k)``, cast back to the dtypes of ``q`` and ``k`` respectively.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # NOTE(review): `default(...)` receives an already-evaluated fallback, so
        # get_seq_pos runs even when `seq_pos` is supplied (result discarded).
        seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device))
        seq = seq + seq_pos_offset  # type: ignore

        freqs = self.forward(seq)

        scale = self.get_scale(seq).to(dtype)

        # used for xformers
        # Presumably the (..., seq, heads, head_dim) layout when seq_dim == -3 — confirm.
        # Insert a heads axis so freqs/scale broadcast per head; this expand assumes
        # freqs/scale are 2-D here (i.e. a 1-D `seq`) — TODO confirm for batched seq_pos.
        if seq_dim == -3:
            num_heads = q.shape[-2]
            freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
            scale = scale.unsqueeze(1).expand(-1, num_heads, -1)

        # xpos: queries are multiplied by the scale, keys by its reciprocal.
        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)  # type: ignore
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)  # type: ignore

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: torch.Tensor, seq_len: Optional[int] = None, offset=0) -> torch.Tensor:
        """
        This method is adapted closely from the base class, but it knows how to handle
        when `t` has more than 1 dim (as is the case when we're using time-aware RoPE, and have a different
        sequence position vector for each time series).

        ``seq_len`` and ``offset`` are accepted for signature compatibility but are not used.

        Returns a tensor shaped ``t.shape + (2 * self.scale.shape[-1],)`` (the per-position
        xpos scale, duplicated along the last axis by ``torch.cat``).
        """
        assert self.use_xpos

        # Centre positions around half the maximum along the last axis. Note that `//`
        # binds tighter than `-`: this is t - (max(t) // 2).
        power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base

        scale = self.scale ** rearrange(power, "... n -> ... n 1")  # type: ignore
        scale = torch.cat((scale, scale), dim=-1)

        return scale
@@ -0,0 +1,306 @@
1
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
2
+ #
3
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
4
+ # Copyright 2025 Datadog, Inc.
5
+
6
+ import warnings
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from einops import repeat
11
+ from gluonts.core.component import validated
12
+ from gluonts.torch.scaler import Scaler
13
+
14
+
15
def compute_causal_statistics(
    data: torch.Tensor,
    weights: torch.Tensor,
    padding_mask: torch.Tensor,
    dim: int,
    minimum_scale: float,
    use_bessel_correction: bool = True,
    stabilize_with_global: bool = False,
    scale_factor_exponent: float = 10.0,
    prefix_length: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute causal mean and scale statistics along a specified dimension using
    a vectorized implementation of Welford's algorithm for numerical stability.

    This implementation avoids explicit loops while maintaining the numerical stability
    of Welford's algorithm, achieving better performance with the same robustness
    against overflow issues.

    Can optionally use global statistics to stabilize causal statistics by clamping
    extreme values, preventing instability while preserving a relaxed version of the
    causal property. This allows a controlled amount of future information leakage,
    introducing an explicit tradeoff between causality and stability.

    Parameters
    ----------
    data
        The input data tensor
    weights
        The weight tensor (same shape as data). It is multiplied by ``padding_mask``
        before use and enters every mean/variance linearly.
    padding_mask
        The padding mask tensor (same shape as data)
    dim
        The dimension along which to compute statistics (must be -1, the time dimension)
    minimum_scale
        Added to the variance before taking the square root, and used as a lower bound
        when stabilizing with global statistics.
    use_bessel_correction
        Whether to use Bessel's correction to get an unbiased estimator
    stabilize_with_global
        Whether to use global statistics to stabilize the causal statistics by clamping
        extreme values
    scale_factor_exponent
        Exponent that controls the allowed range of deviation from global scale.
        For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
        With exponent=2.0, the range would be 0.01x to 100x.
    prefix_length
        If specified, the global statistics will be computed using only the prefix length
        requested. This is used for multistep decoding, where we only want to use the
        initial historical data to compute the global statistics. If stabilize_with_global
        is False, this parameter is ignored.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Causal mean and scale tensors, cast back to ``data.dtype`` and potentially
        stabilized with global statistics. The scale at each step is
        ``sqrt(causal_variance + minimum_scale)`` before stabilization.

    Notes
    -----
    If deterministic algorithms are enabled and ``data`` is on CUDA, they are disabled
    for the cumulative sums and then re-enabled afterwards with the caller's original
    ``warn_only`` setting.
    """
    # Assert that dim is -1 (last dimension)
    assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)"

    with torch.no_grad():
        # Apply padding mask to weights
        weights = weights * padding_mask

        # Try to use higher precision for numerical stability
        try:
            high_precision_data = data.to(torch.float64)
            high_precision_weights = weights.to(torch.float64)
        except TypeError:
            # Fallback for devices that don't support float64
            warnings.warn(
                f"Float64 is not supported by device {data.device}. "
                "Using float32 instead for causal scaler calculations. "
                "This may lead to numerical issues if the data contains extreme values.",
                RuntimeWarning,
            )
            high_precision_data = data.to(torch.float32)
            high_precision_weights = weights.to(torch.float32)

        # Check if deterministic algorithms are enabled and we're using CUDA.
        # Cumsum operations do not support deterministic mode in CUDA,
        # so we need to disable it for just this section. Remember warn_only too,
        # so restoring does not silently turn a warn-only setup into a strict one.
        prev_deterministic = torch.are_deterministic_algorithms_enabled()
        prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
        toggle_determinism = prev_deterministic and data.device.type == "cuda"
        if toggle_determinism:
            # Disable deterministic algorithms for operations
            torch.use_deterministic_algorithms(False)

        try:
            # Create weighted data
            weighted_data = high_precision_weights * high_precision_data

            # Compute cumulative sum of weights and weighted data along time dimension
            cum_weights = torch.cumsum(high_precision_weights, dim=dim)
            cum_values = torch.cumsum(weighted_data, dim=dim)

            # Avoid division by zero for the first time step or when no valid values
            denominator = cum_weights.clamp_min(1.0)

            # Compute causal means at each time step
            causal_means = cum_values / denominator

            # Welford: delta is measured against the mean *before* the current step.
            # shifted_means[..., t] == causal_means[..., t-1], with 0 at t=0.
            shifted_means = torch.zeros_like(causal_means)
            shifted_means[..., 1:] = causal_means[..., :-1]

            # Compute delta between current data point and previous mean
            # For t=0, this is just the first data point
            delta = high_precision_data - shifted_means

            # Increment term: delta * (x_t - mean_t) * w_t. This avoids the numerically
            # fragile E[X^2] - E[X]^2 formulation.
            increment = delta * (high_precision_data - causal_means) * high_precision_weights

            # m_2 is the running second-moment accumulator of Welford's algorithm.
            m_2 = torch.cumsum(increment, dim=dim)

            # Compute variance according to Welford's algorithm
            if use_bessel_correction:
                causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0)
            else:
                causal_variance = m_2 / denominator

            # Add minimum scale but keep in high precision for now
            causal_scale = torch.sqrt(causal_variance + minimum_scale)

            # Apply stabilization with global statistics if requested
            if stabilize_with_global:
                if prefix_length is not None:
                    # Create a prefix mask for global statistics computation
                    prefix_mask = torch.zeros_like(weights)
                    prefix_mask[..., :prefix_length] = 1.0

                    # Apply prefix mask to restrict computation to prefix
                    weighted_data = weighted_data * prefix_mask
                    weights = weights * prefix_mask
                    padding_mask = padding_mask * prefix_mask

                # Calculate scale factors from the exponent
                scale_factor_min = 10.0 ** (-scale_factor_exponent)
                scale_factor_max = 10.0**scale_factor_exponent

                global_denominator = (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0)
                global_means = (weighted_data).sum(dim, keepdim=True) / global_denominator
                global_means = torch.nan_to_num(global_means)

                # Weighted (population) variance: weights enter linearly, matching the
                # global mean above and the causal Welford increment. Squaring the weights
                # together with the deviation would under-weight fractional weights.
                global_variance = (
                    ((high_precision_data - global_means) ** 2) * weights * padding_mask
                ).sum(dim, keepdim=True) / global_denominator
                global_scale = torch.sqrt(global_variance + minimum_scale)

                # Expand global statistics to match the time dimension
                expanded_global_scale = global_scale.expand_as(causal_scale)

                # Define bounds using scale factors
                min_allowed_scale = expanded_global_scale * scale_factor_min
                max_allowed_scale = expanded_global_scale * scale_factor_max

                # Clamp the causal scale between min_allowed_scale and max_allowed_scale
                causal_scale = torch.clamp(
                    causal_scale,
                    min=torch.max(torch.tensor(minimum_scale, device=causal_scale.device), min_allowed_scale),
                    max=max_allowed_scale,
                )

            # Now convert means and scale to original dtype after all numerical operations
            causal_means = causal_means.to(data.dtype)
            causal_scale = causal_scale.to(data.dtype)

        finally:
            # Restore original deterministic setting (including warn_only) if it was changed
            if toggle_determinism:
                torch.use_deterministic_algorithms(True, warn_only=prev_warn_only)

    return causal_means, causal_scale
197
+
198
+
199
class CausalPatchStdMeanScaler(Scaler):
    """
    Standardize series patch by patch using causal (running) statistics.

    For each window of ``patch_size`` consecutive steps, the mean and scale are the
    running statistics evaluated at the window's final step, i.e. computed from all
    data up to and including that patch. Every step inside the patch shares those
    values, which is more stable than per-step causal scaling while still not using
    later patches.

    The running statistics come from ``compute_causal_statistics`` (a vectorized
    Welford update). When ``stabilize_with_global`` is set, they are additionally
    clamped relative to global statistics. That global pass uses data beyond the
    current patch (restricted to the first ``prefix_length`` steps when given), so it
    relaxes strict causality.

    Constraints:
    - the input must have shape [batch, variates, time_steps]
    - only the last dimension (-1) is supported
    - time_steps must be a multiple of patch_size

    Parameters
    ----------
    dim
        Dimension along which to compute the causal scale. Must be -1 (the last dimension).
    patch_size
        Number of timesteps in each patch.
    minimum_scale
        Added to the variance before the square root; acts as the floor for series
        that are constant along ``dim`` and for the earliest steps.
    use_bessel_correction
        Whether to use Bessel's correction to get an unbiased estimator.
    stabilize_with_global
        Whether to clamp the causal scale relative to the global scale.
    scale_factor_exponent
        Exponent that controls the allowed range of deviation from global scale.
        For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
        With exponent=2.0, the range would be 0.01x to 100x.
    """

    @validated()
    def __init__(
        self,
        dim: int = -1,
        patch_size: int = 32,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.patch_size = patch_size
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(  # type: ignore[override]
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Scale ``data`` with per-patch causal statistics.

        ``prefix_length`` is forwarded to ``compute_causal_statistics``, which only uses
        it for global stabilization.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            ``(scaled_data, patch_means, patch_scales)``, where
            ``scaled_data = (data - patch_means) / patch_scales``.
        """
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]"

        with torch.no_grad():
            patch = self.patch_size
            n_steps = data.shape[-1]
            assert n_steps % patch == 0, (
                f"Time steps ({n_steps}) must be divisible by patch size ({self.patch_size})"
            )

            running_means, running_scales = compute_causal_statistics(
                data,
                weights,
                padding_mask,
                dim=-1,
                minimum_scale=self.minimum_scale,
                use_bessel_correction=self.use_bessel_correction,
                stabilize_with_global=self.stabilize_with_global,
                scale_factor_exponent=self.scale_factor_exponent,
                prefix_length=prefix_length,
            )

            # Cut the time axis into non-overlapping windows, keep the statistic at each
            # window's final step, then tile it back across that window's steps.
            patch_means, patch_scales = (
                repeat(stat.unfold(-1, patch, patch)[..., -1], "b v p -> b v (p s)", s=patch)
                for stat in (running_means, running_scales)
            )

            return (data - patch_means) / patch_scales, patch_means, patch_scales