autogluon.timeseries 1.4.1b20250907__py3-none-any.whl → 1.5.1b20260122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of autogluon.timeseries might be problematic.

Files changed (95)
  1. autogluon/timeseries/configs/hyperparameter_presets.py +13 -28
  2. autogluon/timeseries/configs/predictor_presets.py +23 -39
  3. autogluon/timeseries/dataset/ts_dataframe.py +97 -86
  4. autogluon/timeseries/learner.py +70 -35
  5. autogluon/timeseries/metrics/__init__.py +4 -4
  6. autogluon/timeseries/metrics/abstract.py +8 -8
  7. autogluon/timeseries/metrics/point.py +9 -9
  8. autogluon/timeseries/metrics/quantile.py +5 -5
  9. autogluon/timeseries/metrics/utils.py +4 -4
  10. autogluon/timeseries/models/__init__.py +4 -1
  11. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -50
  12. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  13. autogluon/timeseries/models/abstract/tunable.py +8 -8
  14. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +58 -62
  15. autogluon/timeseries/models/autogluon_tabular/per_step.py +27 -16
  16. autogluon/timeseries/models/autogluon_tabular/transforms.py +11 -9
  17. autogluon/timeseries/models/chronos/__init__.py +2 -1
  18. autogluon/timeseries/models/chronos/chronos2.py +395 -0
  19. autogluon/timeseries/models/chronos/model.py +127 -89
  20. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +69 -37
  21. autogluon/timeseries/models/ensemble/__init__.py +36 -2
  22. autogluon/timeseries/models/ensemble/abstract.py +14 -46
  23. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  24. autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
  25. autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
  26. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  27. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  28. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  30. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  31. autogluon/timeseries/models/ensemble/{greedy.py → ensemble_selection.py} +41 -61
  32. autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
  33. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  34. autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
  35. autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +25 -22
  36. autogluon/timeseries/models/ensemble/weighted/greedy.py +64 -0
  37. autogluon/timeseries/models/gluonts/abstract.py +32 -31
  38. autogluon/timeseries/models/gluonts/dataset.py +11 -11
  39. autogluon/timeseries/models/gluonts/models.py +0 -7
  40. autogluon/timeseries/models/local/__init__.py +0 -7
  41. autogluon/timeseries/models/local/abstract_local_model.py +15 -18
  42. autogluon/timeseries/models/local/naive.py +2 -2
  43. autogluon/timeseries/models/local/npts.py +7 -1
  44. autogluon/timeseries/models/local/statsforecast.py +13 -13
  45. autogluon/timeseries/models/multi_window/multi_window_model.py +39 -24
  46. autogluon/timeseries/models/registry.py +3 -4
  47. autogluon/timeseries/models/toto/__init__.py +3 -0
  48. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  49. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  50. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  51. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  52. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  56. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  57. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  58. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  59. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  60. autogluon/timeseries/models/toto/dataloader.py +108 -0
  61. autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
  62. autogluon/timeseries/models/toto/model.py +249 -0
  63. autogluon/timeseries/predictor.py +541 -162
  64. autogluon/timeseries/regressor.py +27 -30
  65. autogluon/timeseries/splitter.py +3 -27
  66. autogluon/timeseries/trainer/ensemble_composer.py +444 -0
  67. autogluon/timeseries/trainer/model_set_builder.py +9 -9
  68. autogluon/timeseries/trainer/prediction_cache.py +16 -16
  69. autogluon/timeseries/trainer/trainer.py +300 -279
  70. autogluon/timeseries/trainer/utils.py +17 -0
  71. autogluon/timeseries/transforms/covariate_scaler.py +8 -8
  72. autogluon/timeseries/transforms/target_scaler.py +15 -15
  73. autogluon/timeseries/utils/constants.py +10 -0
  74. autogluon/timeseries/utils/datetime/lags.py +1 -3
  75. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  76. autogluon/timeseries/utils/features.py +31 -14
  77. autogluon/timeseries/utils/forecast.py +6 -7
  78. autogluon/timeseries/utils/timer.py +173 -0
  79. autogluon/timeseries/version.py +1 -1
  80. autogluon.timeseries-1.5.1b20260122-py3.11-nspkg.pth +1 -0
  81. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/METADATA +39 -22
  82. autogluon_timeseries-1.5.1b20260122.dist-info/RECORD +103 -0
  83. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/WHEEL +1 -1
  84. autogluon/timeseries/evaluator.py +0 -6
  85. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
  86. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  87. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
  88. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -580
  89. autogluon.timeseries-1.4.1b20250907-py3.9-nspkg.pth +0 -1
  90. autogluon.timeseries-1.4.1b20250907.dist-info/RECORD +0 -75
  91. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/LICENSE +0 -0
  92. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/NOTICE +0 -0
  93. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/namespace_packages.txt +0 -0
  94. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/top_level.txt +0 -0
  95. {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py (new file)
@@ -0,0 +1,342 @@
+ # Source: https://github.com/lucidrains/rotary-embedding-torch
+ # MIT License
+ #
+ # Copyright (c) 2021 Phil Wang
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+
+ from __future__ import annotations
+
+ from math import pi
+ from typing import Literal
+
+ import torch
+ from einops import rearrange, repeat
+ from torch import Tensor, broadcast_tensors, einsum, is_tensor, nn, tensor
+ from torch.amp import autocast
+ from torch.nn import Module
+
+ # helper functions
+
+
+ def exists(val):
+     return val is not None
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ def slice_at_dim(t, dim_slice: slice, *, dim):
+     dim += t.ndim if dim < 0 else 0
+     colons = [slice(None)] * t.ndim
+     colons[dim] = dim_slice
+     return t[tuple(colons)]
+
+
+ # rotary embedding helper functions
+
+
+ def rotate_half(x):
+     x = rearrange(x, "... (d r) -> ... d r", r=2)
+     x1, x2 = x.unbind(dim=-1)
+     x = torch.stack((-x2, x1), dim=-1)
+     return rearrange(x, "... d r -> ... (d r)")
+
+
+ @autocast("cuda", enabled=False)
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None):
+     dtype = t.dtype
+
+     if not exists(freqs_seq_dim):
+         if freqs.ndim == 2 or t.ndim == 3:
+             freqs_seq_dim = 0
+
+     if t.ndim == 3 or exists(freqs_seq_dim):
+         seq_len = t.shape[seq_dim]
+         freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)
+
+     rot_dim = freqs.shape[-1]
+     end_index = start_index + rot_dim
+
+     assert rot_dim <= t.shape[-1], (
+         f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+     )
+
+     # Split t into three parts: left, middle (to be transformed), and right
+     t_left = t[..., :start_index]
+     t_middle = t[..., start_index:end_index]
+     t_right = t[..., end_index:]
+
+     # Apply rotary embeddings without modifying t in place
+     t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
+
+     out = torch.cat((t_left, t_transformed, t_right), dim=-1)
+
+     return out.type(dtype)
+
+
+ # learned rotation helpers
+
+
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+     if exists(freq_ranges):
+         rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+         rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+     rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+     return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+ # classes
+
+
+ class RotaryEmbedding(Module):
+     def __init__(
+         self,
+         dim,
+         custom_freqs: Tensor | None = None,
+         freqs_for: Literal["lang", "pixel", "constant"] = "lang",
+         theta=10000,
+         max_freq=10,
+         num_freqs=1,
+         learned_freq=False,
+         use_xpos=False,
+         xpos_scale_base=512,
+         interpolate_factor=1.0,
+         theta_rescale_factor=1.0,
+         seq_before_head_dim=False,
+         cache_if_possible=True,
+         cache_max_seq_len=8192,
+     ):
+         super().__init__()
+         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+         # has some connection to NTK literature
+         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+
+         theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+         self.freqs_for = freqs_for
+
+         if exists(custom_freqs):
+             freqs = custom_freqs
+         elif freqs_for == "lang":
+             freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+         elif freqs_for == "pixel":
+             freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+         elif freqs_for == "constant":
+             freqs = torch.ones(num_freqs).float()
+
+         self.cache_if_possible = cache_if_possible
+         self.cache_max_seq_len = cache_max_seq_len
+
+         self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
+         self.cached_freqs_seq_len = 0
+
+         self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
+
+         self.learned_freq = learned_freq
+
+         # dummy for device
+
+         self.register_buffer("dummy", torch.tensor(0), persistent=False)
+
+         # default sequence dimension
+
+         self.seq_before_head_dim = seq_before_head_dim
+         self.default_seq_dim = -3 if seq_before_head_dim else -2
+
+         # interpolation factors
+
+         assert interpolate_factor >= 1.0
+         self.interpolate_factor = interpolate_factor
+
+         # xpos
+
+         self.use_xpos = use_xpos
+
+         if not use_xpos:
+             return
+
+         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+         self.scale_base = xpos_scale_base
+
+         self.register_buffer("scale", scale, persistent=False)
+         self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
+         self.cached_scales_seq_len = 0
+
+         # add apply_rotary_emb as static method
+
+         self.apply_rotary_emb = staticmethod(apply_rotary_emb)
+
+     @property
+     def device(self):
+         return self.dummy.device
+
+     def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
+         device = default(device, self.device)
+         dtype = default(dtype, self.cached_freqs.dtype)
+
+         return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
+
+     def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+
+         assert not self.use_xpos or exists(scale), (
+             "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
+         )
+
+         device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
+
+         seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
+
+         freqs = self.forward(seq, seq_len=seq_len, offset=offset)
+
+         if seq_dim == -3:
+             freqs = rearrange(freqs, "n d -> n 1 d")
+
+         return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
+
+     def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
+         dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
+
+         q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
+         assert q_len <= k_len
+
+         q_scale = k_scale = 1.0
+
+         if self.use_xpos:
+             seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
+
+             q_scale = self.get_scale(seq[-q_len:]).type(dtype)
+             k_scale = self.get_scale(seq).type(dtype)
+
+         rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
+         rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
+
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+
+         return rotated_q, rotated_k
+
+     def rotate_queries_and_keys(self, q, k, seq_dim=None):
+         seq_dim = default(seq_dim, self.default_seq_dim)
+
+         assert self.use_xpos
+         device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+
+         seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
+
+         freqs = self.forward(seq, seq_len=seq_len)
+         scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
+
+         if seq_dim == -3:
+             freqs = rearrange(freqs, "n d -> n 1 d")
+             scale = rearrange(scale, "n d -> n 1 d")
+
+         rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
+         rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
+
+         rotated_q = rotated_q.type(q.dtype)
+         rotated_k = rotated_k.type(k.dtype)
+
+         return rotated_q, rotated_k
+
+     def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
+         assert self.use_xpos
+
+         should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len
+
+         if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len:
+             return self.cached_scales[offset : (offset + seq_len)]
+
+         scale = 1.0
+         if self.use_xpos:
+             power = (t - len(t) // 2) / self.scale_base
+             scale = self.scale ** rearrange(power, "n -> n 1")
+             scale = repeat(scale, "n d -> n (d r)", r=2)
+
+         if should_cache and offset == 0:
+             self.cached_scales[:seq_len] = scale.detach()
+             self.cached_scales_seq_len = seq_len
+
+         return scale
+
+     def get_axial_freqs(self, *dims, offsets: (tuple[int | float, ...] | Tensor | None) = None):
+         Colon = slice(None)
+         all_freqs = []
+
+         # handle offset
+
+         if exists(offsets):
+             if not is_tensor(offsets):
+                 offsets = tensor(offsets)
+
+             assert len(offsets) == len(dims)
+
+         # get frequencies for each axis
+
+         for ind, dim in enumerate(dims):
+             offset = 0
+             if exists(offsets):
+                 offset = offsets[ind]
+
+             if self.freqs_for == "pixel":
+                 pos = torch.linspace(-1, 1, steps=dim, device=self.device)
+             else:
+                 pos = torch.arange(dim, device=self.device)
+
+             pos = pos + offset
+
+             freqs = self.forward(pos, seq_len=dim)
+
+             all_axis = [None] * len(dims)
+             all_axis[ind] = Colon
+
+             new_axis_slice = (Ellipsis, *all_axis, Colon)
+             all_freqs.append(freqs[new_axis_slice])
+
+         # concat all freqs
+
+         all_freqs = broadcast_tensors(*all_freqs)
+         return torch.cat(all_freqs, dim=-1)
+
+     @autocast("cuda", enabled=False)
+     def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
+         should_cache = (
+             self.cache_if_possible
+             and not self.learned_freq
+             and exists(seq_len)
+             and self.freqs_for != "pixel"
+             and (offset + seq_len) <= self.cache_max_seq_len
+         )
+
+         if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len:
+             return self.cached_freqs[offset : (offset + seq_len)].detach()
+
+         freqs = self.freqs
+
+         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+         freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+         if should_cache and offset == 0:
+             self.cached_freqs[:seq_len] = freqs.detach()
+             self.cached_freqs_seq_len = seq_len
+
+         return freqs
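
The vendored module above follows the upstream rotary-embedding-torch API: a RotaryEmbedding is constructed for the number of head-dimension features to rotate, and it is then applied to query and key tensors shaped (batch, heads, seq_len, head_dim) before attention. A minimal usage sketch follows; the shapes, values, and import path are illustrative assumptions based on the file location in this diff, not code from the release itself.

    import torch

    # assumed import path, matching where the vendored file sits in this wheel
    from autogluon.timeseries.models.toto._internal.backbone.rotary_embedding_torch import RotaryEmbedding

    rotary = RotaryEmbedding(dim=32)     # rotate the first 32 features of each attention head

    q = torch.randn(1, 8, 1024, 64)      # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 8, 1024, 64)

    # positions are generated internally; the default "lang" schedule uses theta=10000 frequencies
    q = rotary.rotate_queries_or_keys(q)
    k = rotary.rotate_queries_or_keys(k)
    # q and k now encode relative positions and can be fed into attention as usual
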
autogluon/timeseries/models/toto/_internal/backbone/scaler.py (new file)
@@ -0,0 +1,305 @@
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+ #
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
+ # Copyright 2025 Datadog, Inc.
+
+ import warnings
+
+ import torch
+ from einops import repeat
+ from gluonts.core.component import validated
+ from gluonts.torch.scaler import Scaler
+
+
+ def compute_causal_statistics(
+     data: torch.Tensor,
+     weights: torch.Tensor,
+     padding_mask: torch.Tensor,
+     dim: int,
+     minimum_scale: float,
+     use_bessel_correction: bool = True,
+     stabilize_with_global: bool = False,
+     scale_factor_exponent: float = 10.0,
+     prefix_length: int | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Compute causal mean and scale statistics along a specified dimension using
+     a vectorized implementation of Welford's algorithm for numerical stability.
+
+     This implementation avoids explicit loops while maintaining the numerical stability
+     of Welford's algorithm, achieving better performance with the same robustness
+     against overflow issues.
+
+
+     Can optionally use global statistics to stabilize causal statistics by clamping
+     extreme values, preventing instability while preserving a relaxed version of the
+     causal property. This allows a controlled amount of future information leakage,
+     introducing an explicit tradeoff between causality and stability.
+     extreme values, preventing instability while preserving the causal property.
+
+     Parameters
+     ----------
+     data
+         The input data tensor
+     weights
+         The weight tensor (same shape as data)
+     padding_mask
+         The padding mask tensor (same shape as data)
+     dim
+         The dimension along which to compute statistics (must be -1, the time dimension)
+     minimum_scale
+         Minimum scale value to use
+     use_bessel_correction
+         Whether to use Bessel's correction to get an unbiased estimator
+     stabilize_with_global
+         Whether to use global statistics to stabilize the causal statistics by clamping
+         extreme values
+     scale_factor_exponent
+         Exponent that controls the allowed range of deviation from global scale.
+         For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
+         With exponent=2.0, the range would be 0.01x to 100x.
+     prefix_length
+         If specified, the global statistics will be computed using only the prefix length
+         requested. This is used for multistep decoding, where we only want to use the
+         initial historical data to compute the global statistics. If stabilize_with_global
+         is False, this parameter is ignored.
+
+     Returns
+     -------
+     tuple[torch.Tensor, torch.Tensor]
+         Causal mean and scale tensors, potentially stabilized with global statistics
+     """
+     # Assert that dim is -1 (last dimension)
+     assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)"
+
+     with torch.no_grad():
+         # Apply padding mask to weights
+         weights = weights * padding_mask
+
+         # Try to use higher precision for numerical stability
+         try:
+             high_precision_data = data.to(torch.float64)
+             high_precision_weights = weights.to(torch.float64)
+         except TypeError:
+             # Fallback for devices that don't support float64
+             warnings.warn(
+                 f"Float64 is not supported by device {data.device}. "
+                 "Using float32 instead for causal scaler calculations. "
+                 "This may lead to numerical issues if the data contains extreme values.",
+                 RuntimeWarning,
+             )
+             high_precision_data = data.to(torch.float32)
+             high_precision_weights = weights.to(torch.float32)
+
+         # Check if deterministic algorithms are enabled and we're using CUDA.
+         # Cumsum operations do not support deterministic mode in CUDA,
+         # so we need to disable it for just this section.
+         prev_deterministic = torch.are_deterministic_algorithms_enabled()
+         if prev_deterministic and data.device.type == "cuda":
+             # Disable deterministic algorithms for operations
+             torch.use_deterministic_algorithms(False)
+
+         try:
+             # Create weighted data
+             weighted_data = high_precision_weights * high_precision_data
+
+             # Compute cumulative sum of weights and weighted data along time dimension
+             cum_weights = torch.cumsum(high_precision_weights, dim=dim)
+             cum_values = torch.cumsum(weighted_data, dim=dim)
+
+             # Avoid division by zero for the first time step or when no valid values
+             denominator = cum_weights.clamp_min(1.0)
+
+             # Compute causal means at each time step
+             causal_means = cum_values / denominator
+
+             # For Welford's algorithm, we need to compute the correction term
+             # using the difference between the current value and the current mean
+
+             # Create shifted version of causal means to compute delta efficiently
+             # First item in shifted_means will be zero
+             shifted_means = torch.zeros_like(causal_means)
+             shifted_means[..., 1:] = causal_means[..., :-1]
+
+             # Compute delta between current data point and previous mean
+             # For t=0, this is just the first data point
+             delta = high_precision_data - shifted_means
+
+             # Compute the increment term for Welford's algorithm.
+             # This is defined as the product of the delta and the difference between the current data point and the causal mean.
+             # This is where we avoid the traditional E[X²] - E[X]² computation
+             increment = delta * (high_precision_data - causal_means) * high_precision_weights
+
+             # The Welford algorithm uses the term m_2, which is the cumulative sum of the increment term.
+             # This is an accumulator that helps us compute the second moment (hence m_2) of the distribution.
+             # Compute cumulative sum of the increment term
+             m_2 = torch.cumsum(increment, dim=dim)
+
+             # Compute variance according to Welford's algorithm
+             if use_bessel_correction:
+                 causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0)
+             else:
+                 causal_variance = m_2 / denominator
+
+             # Add minimum scale but keep in high precision for now
+             causal_scale = torch.sqrt(causal_variance + minimum_scale)
+
+             # Apply stabilization with global statistics if requested
+             if stabilize_with_global:
+                 if prefix_length is not None:
+                     # Create a prefix mask for global statistics computation
+                     prefix_mask = torch.zeros_like(weights)
+                     prefix_mask[..., :prefix_length] = 1.0
+
+                     # Apply prefix mask to restrict computation to prefix
+                     weighted_data = weighted_data * prefix_mask
+                     weights = weights * prefix_mask
+                     padding_mask = padding_mask * prefix_mask
+
+                 # Calculate scale factors from the exponent
+                 scale_factor_min = 10.0 ** (-scale_factor_exponent)
+                 scale_factor_max = 10.0**scale_factor_exponent
+
+                 global_denominator = (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0)
+                 global_means = (weighted_data).sum(dim, keepdim=True) / global_denominator
+                 global_means = torch.nan_to_num(global_means)
+
+                 global_variance = (((high_precision_data - global_means) * weights * padding_mask) ** 2).sum(
+                     dim, keepdim=True
+                 ) / global_denominator
+                 global_scale = torch.sqrt(global_variance + minimum_scale)
+
+                 # Expand global statistics to match the time dimension
+                 expanded_global_scale = global_scale.expand_as(causal_scale)
+
+                 # Define bounds using scale factors
+                 min_allowed_scale = expanded_global_scale * scale_factor_min
+                 max_allowed_scale = expanded_global_scale * scale_factor_max
+
+                 # Clamp the causal scale between min_allowed_scale and max_allowed_scale
+                 causal_scale = torch.clamp(
+                     causal_scale,
+                     min=torch.max(torch.tensor(minimum_scale, device=causal_scale.device), min_allowed_scale),
+                     max=max_allowed_scale,
+                 )
+
+             # Now convert means and scale to original dtype after all numerical operations
+             causal_means = causal_means.to(data.dtype)
+             causal_scale = causal_scale.to(data.dtype)
+
+         finally:
+             # Restore original deterministic setting if it was changed
+             if prev_deterministic and data.device.type == "cuda":
+                 torch.use_deterministic_algorithms(True)
+
+     return causal_means, causal_scale
+
+
+ class CausalPatchStdMeanScaler(Scaler):
+     """
+     Causally scales data in patches, where each patch uses statistics computed
+     from all data up to and including that patch. Within each patch, all timesteps
+     use the same scaling values.
+
+     This approach provides more stability than per-timestep causal scaling while
+     still maintaining the causal property (not using future data).
+
+     Can optionally stabilize causal statistics using global statistics to prevent
+     extreme values, while preserving the causal property.
+
+     The statistics are computed using Welford's algorithm, which provides better
+     numerical stability compared to the direct computation of variance, especially
+     when dealing with large values or a large number of data points.
+
+     Note: This scaler only works with the following constraints:
+     - The input must have shape [batch, variates, time_steps]
+     - It only operates on the last dimension (-1)
+     - The time_steps must be divisible by patch_size
+
+     Parameters
+     ----------
+     dim
+         dimension along which to compute the causal scale. Must be -1 (the last dimension).
+     patch_size
+         number of timesteps in each patch
+     minimum_scale
+         default scale that is used for elements that are constantly zero
+         along dimension `dim` or for the first patch.
+     use_bessel_correction
+         whether to use Bessel's correction to get an unbiased estimator
+     stabilize_with_global
+         whether to use global statistics to stabilize extreme causal statistics
+     scale_factor_exponent
+         exponent that controls the allowed range of deviation from global scale.
+         For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
+         With exponent=2.0, the range would be 0.01x to 100x.
+     """
+
+     @validated()
+     def __init__(
+         self,
+         dim: int = -1,
+         patch_size: int = 32,
+         minimum_scale: float = 0.1,
+         use_bessel_correction: bool = True,
+         stabilize_with_global: bool = False,
+         scale_factor_exponent: float = 10.0,
+     ) -> None:
+         super().__init__()
+         assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
+         self.dim = dim
+         self.patch_size = patch_size
+         self.minimum_scale = minimum_scale
+         self.use_bessel_correction = use_bessel_correction
+         self.stabilize_with_global = stabilize_with_global
+         self.scale_factor_exponent = scale_factor_exponent
+
+     def __call__(  # type: ignore[override]
+         self,
+         data: torch.Tensor,
+         padding_mask: torch.Tensor,
+         weights: torch.Tensor,
+         prefix_length: int | None = None,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         assert data.shape == weights.shape, "data and weights must have same shape"
+         assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]"
+
+         with torch.no_grad():
+             # Get the number of time steps (last dimension)
+             time_steps = data.shape[-1]
+
+             # Assert that time_steps is divisible by patch_size
+             assert time_steps % self.patch_size == 0, (
+                 f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})"
+             )
+
+             # First compute causal statistics with optional stabilization
+             causal_means, causal_scale = compute_causal_statistics(
+                 data,
+                 weights,
+                 padding_mask,
+                 -1,
+                 self.minimum_scale,
+                 self.use_bessel_correction,
+                 self.stabilize_with_global,
+                 self.scale_factor_exponent,
+                 prefix_length,
+             )
+
+             # Unfold the causal means and scales to get the patches
+             means_unfolded = causal_means.unfold(-1, self.patch_size, self.patch_size)
+             scales_unfolded = causal_scale.unfold(-1, self.patch_size, self.patch_size)
+
+             # Get the last element of each patch (the most recent statistic)
+             patch_stats_means = means_unfolded[..., -1]
+             patch_stats_scales = scales_unfolded[..., -1]
+
+             # Tile the patch statistics across time dimension using einops.repeat
+             # With our fixed [batch, variates, num_patches] shape this is much simpler
+             patch_means = repeat(patch_stats_means, "b v p -> b v (p s)", s=self.patch_size)
+             patch_scales = repeat(patch_stats_scales, "b v p -> b v (p s)", s=self.patch_size)
+
+             # Apply normalization
+             scaled_data = (data - patch_means) / patch_scales
+
+         return scaled_data, patch_means, patch_scales
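
CausalPatchStdMeanScaler wraps compute_causal_statistics: it takes a (batch, variates, time_steps) tensor, computes running (causal) mean and scale up to the end of each patch, and normalizes every timestep in a patch with that patch's statistics. A minimal usage sketch follows; the shapes, values, and import path are illustrative assumptions based on the file location in this diff (it also assumes torch, einops, and gluonts are installed), not code from the release itself.

    import torch

    # assumed import path, matching where the vendored file sits in this wheel
    from autogluon.timeseries.models.toto._internal.backbone.scaler import CausalPatchStdMeanScaler

    scaler = CausalPatchStdMeanScaler(patch_size=4, minimum_scale=0.1)

    data = torch.randn(2, 3, 16)          # (batch, variates, time_steps); 16 is divisible by patch_size=4
    padding_mask = torch.ones_like(data)  # 1 where an observation is real, 0 where padded
    weights = torch.ones_like(data)       # per-element weights for the running statistics

    scaled, patch_means, patch_scales = scaler(data, padding_mask, weights)

    # Each patch of 4 steps is normalized with the causal mean/std accumulated up to the
    # end of that patch, so no future values leak into the scaling.
    assert scaled.shape == data.shape
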