autogluon.timeseries 1.4.1b20250907__py3-none-any.whl → 1.5.1b20260122__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- autogluon/timeseries/configs/hyperparameter_presets.py +13 -28
- autogluon/timeseries/configs/predictor_presets.py +23 -39
- autogluon/timeseries/dataset/ts_dataframe.py +97 -86
- autogluon/timeseries/learner.py +70 -35
- autogluon/timeseries/metrics/__init__.py +4 -4
- autogluon/timeseries/metrics/abstract.py +8 -8
- autogluon/timeseries/metrics/point.py +9 -9
- autogluon/timeseries/metrics/quantile.py +5 -5
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +4 -1
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -50
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +8 -8
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +58 -62
- autogluon/timeseries/models/autogluon_tabular/per_step.py +27 -16
- autogluon/timeseries/models/autogluon_tabular/transforms.py +11 -9
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +395 -0
- autogluon/timeseries/models/chronos/model.py +127 -89
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +69 -37
- autogluon/timeseries/models/ensemble/__init__.py +36 -2
- autogluon/timeseries/models/ensemble/abstract.py +14 -46
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/{greedy.py → ensemble_selection.py} +41 -61
- autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
- autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +25 -22
- autogluon/timeseries/models/ensemble/weighted/greedy.py +64 -0
- autogluon/timeseries/models/gluonts/abstract.py +32 -31
- autogluon/timeseries/models/gluonts/dataset.py +11 -11
- autogluon/timeseries/models/gluonts/models.py +0 -7
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +15 -18
- autogluon/timeseries/models/local/naive.py +2 -2
- autogluon/timeseries/models/local/npts.py +7 -1
- autogluon/timeseries/models/local/statsforecast.py +13 -13
- autogluon/timeseries/models/multi_window/multi_window_model.py +39 -24
- autogluon/timeseries/models/registry.py +3 -4
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
- autogluon/timeseries/models/toto/model.py +249 -0
- autogluon/timeseries/predictor.py +541 -162
- autogluon/timeseries/regressor.py +27 -30
- autogluon/timeseries/splitter.py +3 -27
- autogluon/timeseries/trainer/ensemble_composer.py +444 -0
- autogluon/timeseries/trainer/model_set_builder.py +9 -9
- autogluon/timeseries/trainer/prediction_cache.py +16 -16
- autogluon/timeseries/trainer/trainer.py +300 -279
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/covariate_scaler.py +8 -8
- autogluon/timeseries/transforms/target_scaler.py +15 -15
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +1 -3
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/features.py +31 -14
- autogluon/timeseries/utils/forecast.py +6 -7
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.5.1b20260122-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/METADATA +39 -22
- autogluon_timeseries-1.5.1b20260122.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/WHEEL +1 -1
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -580
- autogluon.timeseries-1.4.1b20250907-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.4.1b20250907.dist-info/RECORD +0 -75
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
# Source: https://github.com/lucidrains/rotary-embedding-torch
|
|
2
|
+
# MIT License
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2021 Phil Wang
|
|
5
|
+
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from math import pi
|
|
27
|
+
from typing import Literal
|
|
28
|
+
|
|
29
|
+
import torch
|
|
30
|
+
from einops import rearrange, repeat
|
|
31
|
+
from torch import Tensor, broadcast_tensors, einsum, is_tensor, nn, tensor
|
|
32
|
+
from torch.amp import autocast
|
|
33
|
+
from torch.nn import Module
|
|
34
|
+
|
|
35
|
+
# helper functions
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def slice_at_dim(t, dim_slice: slice, *, dim):
    """Apply ``dim_slice`` to a single axis of ``t`` and leave every other axis whole.

    Parameters
    ----------
    t
        Object exposing ``ndim`` and supporting tuple indexing (e.g. a tensor).
    dim_slice
        The slice to apply along ``dim``.
    dim
        Axis to slice; a negative value is shifted by ``t.ndim`` before use.
    """
    axis = dim if dim >= 0 else dim + t.ndim
    # Start from "take everything" on each axis, then narrow the requested one.
    selector = [slice(None) for _ in range(t.ndim)]
    selector[axis] = dim_slice
    return t[tuple(selector)]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# rotary embedding helper functions
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def rotate_half(x):
    """Map each adjacent feature pair ``(a, b)`` on the last axis of ``x`` to ``(-b, a)``.

    The last axis is viewed as consecutive pairs, each pair is swapped with the
    new first element negated, and the pairs are flattened back so the output
    has the same shape as the input.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    swapped = torch.stack((-second, first), dim=-1)
    return rearrange(swapped, "... d r -> ... (d r)")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None):
    """Rotate a slice of the last (feature) axis of ``t`` by the angles in ``freqs``.

    Runs with CUDA autocast disabled. The result is a new tensor cast back to
    the dtype ``t`` had on entry; ``t`` itself is not modified.

    Parameters
    ----------
    freqs
        Rotation angles. Its last axis length is the number of features rotated.
    t
        Tensor to rotate.
    start_index
        First feature index of the rotated slice; features before it and after
        ``start_index + freqs.shape[-1]`` are passed through unchanged.
    scale
        Multiplier applied to both the cosine and the sine term.
    seq_dim
        Axis of ``t`` holding the sequence; its length is used to trim ``freqs``.
    freqs_seq_dim
        Axis of ``freqs`` holding the sequence. When ``None`` it is taken to be
        0 if ``freqs`` is 2-D or ``t`` is 3-D; otherwise ``freqs`` is not trimmed.
    """
    original_dtype = t.dtype

    # Guess the sequence axis of `freqs` only when the caller did not supply one.
    if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3):
        freqs_seq_dim = 0

    if t.ndim == 3 or freqs_seq_dim is not None:
        # Keep only the trailing `seq_len` positions so freqs lines up with t.
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    )

    # Partition the feature axis: pass-through head, rotated span, pass-through tail.
    head = t[..., :start_index]
    span = t[..., start_index:end_index]
    tail = t[..., end_index:]

    cos_term = span * freqs.cos() * scale
    sin_term = rotate_half(span) * freqs.sin() * scale
    rotated = cos_term + sin_term

    return torch.cat((head, rotated, tail), dim=-1).type(original_dtype)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# learned rotation helpers
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Apply ``rotations`` to ``t`` via :func:`apply_rotary_emb`.

    When ``freq_ranges`` is given, ``rotations`` is first multiplied against it
    (outer product on a new trailing axis) and the last two axes are merged.
    Each angle is then duplicated so that it covers a pair of adjacent features.

    Parameters
    ----------
    rotations
        Tensor of rotation angles.
    t
        Tensor to rotate.
    start_index
        Forwarded to :func:`apply_rotary_emb`.
    freq_ranges
        Optional 1-D tensor of multipliers for ``rotations``.
    """
    if freq_ranges is not None:
        expanded = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(expanded, "... r f -> ... (r f)")

    paired = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(paired, t, start_index=start_index)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# classes
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class RotaryEmbedding(Module):
    """Rotary position embedding with optional xPos scaling and an angle cache.

    The module holds a vector of base frequencies (``self.freqs``), turns
    positions into rotation angles in :meth:`forward`, and applies them to
    query/key tensors through the ``rotate_*`` helpers.

    Parameters
    ----------
    dim
        Number of features that are rotated.
    custom_freqs
        If given, used as the base frequencies instead of a built-in schedule.
    freqs_for
        Built-in frequency schedule: ``"lang"``, ``"pixel"`` or ``"constant"``.
    theta
        Base of the ``"lang"`` schedule.
    max_freq
        Upper bound used by the ``"pixel"`` schedule.
    num_freqs
        Number of frequencies for the ``"constant"`` schedule.
    learned_freq
        Whether ``self.freqs`` requires gradients. Caching of angles is disabled
        when this is true.
    use_xpos
        Enable xPos scaling buffers and :meth:`get_scale`.
    xpos_scale_base
        Divisor applied to centred positions when computing the xPos power.
    interpolate_factor
        Positions are divided by this value; must be ``>= 1.0``.
    theta_rescale_factor
        ``theta`` is multiplied by ``theta_rescale_factor ** (dim / (dim - 2))``.
    seq_before_head_dim
        If true the default sequence axis is ``-3``, otherwise ``-2``.
    cache_if_possible
        Allow caching of computed angles / scales.
    cache_max_seq_len
        Number of positions held by the caches.

    Raises
    ------
    ValueError
        If ``custom_freqs`` is not given and ``freqs_for`` is not a supported schedule.
    """

    def __init__(
        self,
        dim,
        custom_freqs: Tensor | None = None,
        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
        cache_max_seq_len=8192,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise ValueError(
                f"Unknown freqs_for={freqs_for!r}; expected 'lang', 'pixel' or 'constant', or pass custom_freqs"
            )

        self.cache_if_possible = cache_if_possible
        self.cache_max_seq_len = cache_max_seq_len

        # Non-persistent buffers: they follow the module across devices but are
        # not written to the state dict.
        self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.cached_freqs_seq_len = 0

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy for device

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # xpos

        self.use_xpos = use_xpos

        # Everything below is only set up when xPos is enabled.
        if not use_xpos:
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base

        self.register_buffer("scale", scale, persistent=False)
        self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.cached_scales_seq_len = 0

        # add apply_rotary_emb as static method

        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        """Device of the module, read from the ``dummy`` buffer."""
        return self.dummy.device

    def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
        """Return positions ``offset .. offset + seq_len - 1`` divided by ``interpolate_factor``.

        ``device`` defaults to the module's device and ``dtype`` to the dtype of
        the ``cached_freqs`` buffer.
        """
        device = default(device, self.device)
        dtype = default(dtype, self.cached_freqs.dtype)

        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
        """Rotate a single query or key tensor ``t``.

        Parameters
        ----------
        t
            Tensor to rotate.
        seq_dim
            Sequence axis of ``t``; defaults to ``self.default_seq_dim``.
        offset
            Position of the first element of ``t`` along the sequence axis.
        scale
            Optional multiplier passed to ``apply_rotary_emb``; required when
            ``use_xpos`` is enabled (otherwise an ``AssertionError`` is raised).
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert not self.use_xpos or exists(scale), (
            "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
        )

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)

        freqs = self.forward(seq, seq_len=seq_len, offset=offset)

        if seq_dim == -3:
            # Insert a singleton axis so the angles broadcast over the axis that
            # sits between the sequence and feature axes.
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """Rotate queries ``q`` against a longer (or equal-length) key tensor ``k``.

        The queries are treated as the trailing ``q_len`` positions of the key
        sequence: they are rotated with offset ``k_len - q_len + offset`` while
        the keys are rotated from position 0. With xPos enabled, queries are
        multiplied by the scale and keys by its inverse.

        Returns
        -------
        tuple
            ``(rotated_q, rotated_k)`` cast back to the dtypes of ``q`` and ``k``.
        """
        dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        q_scale = k_scale = 1.0

        if self.use_xpos:
            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)

            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
            k_scale = self.get_scale(seq).type(dtype)

        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        """Rotate ``q`` and ``k`` together using xPos scaling (requires ``use_xpos``).

        Positions are derived from the length of ``q`` along ``seq_dim``. Queries
        are multiplied by the xPos scale and keys by its inverse.

        Returns
        -------
        tuple
            ``(rotated_q, rotated_k)`` cast back to the dtypes of ``q`` and ``k``.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
        """Return xPos scale factors for the 1-D position tensor ``t`` (requires ``use_xpos``).

        When ``seq_len`` is given and the request fits in the cache, a cached
        slice is returned if available; freshly computed values are written to
        the cache only for ``offset == 0``.
        """
        assert self.use_xpos

        should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len

        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len:
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            # Positions are centred on len(t) // 2 before being turned into an exponent.
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            # Duplicate each factor so it covers a pair of adjacent features.
            scale = repeat(scale, "n d -> n (d r)", r=2)

        if should_cache and offset == 0:
            self.cached_scales[:seq_len] = scale.detach()
            self.cached_scales_seq_len = seq_len

        return scale

    def get_axial_freqs(self, *dims, offsets: (tuple[int | float, ...] | Tensor | None) = None):
        """Build rotation angles for a multi-axis grid of sizes ``dims``.

        Angles are computed per axis, broadcast against each other over the grid
        and concatenated along the last axis.

        Parameters
        ----------
        *dims
            Size of each grid axis.
        offsets
            Optional per-axis value added to that axis' positions; must have one
            entry per element of ``dims``.
        """
        Colon = slice(None)
        all_freqs = []

        # handle offset

        has_offsets = exists(offsets)

        if has_offsets:
            if not is_tensor(offsets):
                offsets = tensor(offsets)

            assert len(offsets) == len(dims)

        # get frequencies for each axis

        for ind, dim in enumerate(dims):
            offset = offsets[ind] if has_offsets else 0

            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            pos = pos + offset

            # The cache in forward() is indexed by (offset, seq_len) only and
            # knows nothing about the shift added to `pos` above. Passing
            # seq_len with shifted positions would either return cached
            # un-shifted angles or overwrite the shared cache with shifted
            # ones, so the cache is bypassed whenever offsets are supplied.
            freqs = self.forward(pos, seq_len=None if has_offsets else dim)

            # Place this axis' angles on its own grid axis; every other grid
            # axis becomes a broadcastable singleton.
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        # concat all freqs

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
        """Turn positions ``t`` into rotation angles, each repeated twice along the last axis.

        Runs with CUDA autocast disabled. Caching is used only when it is
        enabled, the frequencies are not learned, ``seq_len`` is given, the
        schedule is not ``"pixel"`` and ``offset + seq_len`` fits in the cache.
        A cache hit returns a detached slice; the cache is only written for
        ``offset == 0``.
        """
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for != "pixel"
            and (offset + seq_len) <= self.cache_max_seq_len
        )

        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len:
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        # Outer product of positions and base frequencies on a new trailing axis.
        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
            self.cached_freqs_seq_len = seq_len

        return freqs
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
import warnings
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from einops import repeat
|
|
10
|
+
from gluonts.core.component import validated
|
|
11
|
+
from gluonts.torch.scaler import Scaler
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def compute_causal_statistics(
    data: torch.Tensor,
    weights: torch.Tensor,
    padding_mask: torch.Tensor,
    dim: int,
    minimum_scale: float,
    use_bessel_correction: bool = True,
    stabilize_with_global: bool = False,
    scale_factor_exponent: float = 10.0,
    prefix_length: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute causal mean and scale statistics along a specified dimension using
    a vectorized implementation of Welford's algorithm for numerical stability.

    This implementation avoids explicit loops while maintaining the numerical stability
    of Welford's algorithm, achieving better performance with the same robustness
    against overflow issues.

    Can optionally use global statistics to stabilize causal statistics by clamping
    extreme values, preventing instability while preserving a relaxed version of the
    causal property. This allows a controlled amount of future information leakage,
    introducing an explicit tradeoff between causality and stability.

    The whole computation runs under ``torch.no_grad()``. If deterministic algorithms
    are enabled and ``data`` lives on a CUDA device, deterministic mode is switched
    off for the duration of the computation and restored afterwards, including the
    ``warn_only`` setting that was active on entry.

    Parameters
    ----------
    data
        The input data tensor
    weights
        The weight tensor (same shape as data)
    padding_mask
        The padding mask tensor (same shape as data)
    dim
        The dimension along which to compute statistics (must be -1, the time dimension)
    minimum_scale
        Value added to the variance before taking the square root; also used as a
        lower bound for the scale when ``stabilize_with_global`` is enabled
    use_bessel_correction
        Whether to use Bessel's correction to get an unbiased estimator
    stabilize_with_global
        Whether to use global statistics to stabilize the causal statistics by clamping
        extreme values
    scale_factor_exponent
        Exponent that controls the allowed range of deviation from global scale.
        For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
        With exponent=2.0, the range would be 0.01x to 100x.
    prefix_length
        If specified, the global statistics will be computed using only the prefix length
        requested. This is used for multistep decoding, where we only want to use the
        initial historical data to compute the global statistics. If stabilize_with_global
        is False, this parameter is ignored.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Causal mean and scale tensors, potentially stabilized with global statistics
    """
    # Assert that dim is -1 (last dimension)
    assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)"

    with torch.no_grad():
        # Apply padding mask to weights
        weights = weights * padding_mask

        # Try to use higher precision for numerical stability
        try:
            high_precision_data = data.to(torch.float64)
            high_precision_weights = weights.to(torch.float64)
        except TypeError:
            # Fallback for devices that don't support float64
            warnings.warn(
                f"Float64 is not supported by device {data.device}. "
                "Using float32 instead for causal scaler calculations. "
                "This may lead to numerical issues if the data contains extreme values.",
                RuntimeWarning,
            )
            high_precision_data = data.to(torch.float32)
            high_precision_weights = weights.to(torch.float32)

        # Check if deterministic algorithms are enabled and we're using CUDA.
        # Cumsum operations do not support deterministic mode in CUDA,
        # so we need to disable it for just this section.
        prev_deterministic = torch.are_deterministic_algorithms_enabled()
        # Remember the warn_only flag too: re-enabling with a bare `True` would
        # silently promote a caller's warn-only configuration to strict mode.
        prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
        toggle_deterministic = prev_deterministic and data.device.type == "cuda"
        if toggle_deterministic:
            # Disable deterministic algorithms for operations
            torch.use_deterministic_algorithms(False)

        try:
            # Create weighted data
            weighted_data = high_precision_weights * high_precision_data

            # Compute cumulative sum of weights and weighted data along time dimension
            cum_weights = torch.cumsum(high_precision_weights, dim=dim)
            cum_values = torch.cumsum(weighted_data, dim=dim)

            # Avoid division by zero for the first time step or when no valid values
            denominator = cum_weights.clamp_min(1.0)

            # Compute causal means at each time step
            causal_means = cum_values / denominator

            # For Welford's algorithm, we need to compute the correction term
            # using the difference between the current value and the current mean

            # Create shifted version of causal means to compute delta efficiently
            # First item in shifted_means will be zero
            shifted_means = torch.zeros_like(causal_means)
            shifted_means[..., 1:] = causal_means[..., :-1]

            # Compute delta between current data point and previous mean
            # For t=0, this is just the first data point
            delta = high_precision_data - shifted_means

            # Compute the increment term for Welford's algorithm.
            # This is defined as the product of the delta and the difference between the current data point and the causal mean.
            # This is where we avoid the traditional E[X²] - E[X]² computation
            increment = delta * (high_precision_data - causal_means) * high_precision_weights

            # The Welford algorithm uses the term m_2, which is the cumulative sum of the increment term.
            # This is an accumulator that helps us compute the second moment (hence m_2) of the distribution.
            # Compute cumulative sum of the increment term
            m_2 = torch.cumsum(increment, dim=dim)

            # Compute variance according to Welford's algorithm
            if use_bessel_correction:
                causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0)
            else:
                causal_variance = m_2 / denominator

            # Add minimum scale but keep in high precision for now
            causal_scale = torch.sqrt(causal_variance + minimum_scale)

            # Apply stabilization with global statistics if requested
            if stabilize_with_global:
                if prefix_length is not None:
                    # Create a prefix mask for global statistics computation
                    prefix_mask = torch.zeros_like(weights)
                    prefix_mask[..., :prefix_length] = 1.0

                    # Apply prefix mask to restrict computation to prefix
                    weighted_data = weighted_data * prefix_mask
                    weights = weights * prefix_mask
                    padding_mask = padding_mask * prefix_mask

                # Calculate scale factors from the exponent
                scale_factor_min = 10.0 ** (-scale_factor_exponent)
                scale_factor_max = 10.0**scale_factor_exponent

                global_denominator = (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0)
                global_means = (weighted_data).sum(dim, keepdim=True) / global_denominator
                global_means = torch.nan_to_num(global_means)

                global_variance = (((high_precision_data - global_means) * weights * padding_mask) ** 2).sum(
                    dim, keepdim=True
                ) / global_denominator
                global_scale = torch.sqrt(global_variance + minimum_scale)

                # Expand global statistics to match the time dimension
                expanded_global_scale = global_scale.expand_as(causal_scale)

                # Define bounds using scale factors
                min_allowed_scale = expanded_global_scale * scale_factor_min
                max_allowed_scale = expanded_global_scale * scale_factor_max

                # Clamp the causal scale between min_allowed_scale and max_allowed_scale
                causal_scale = torch.clamp(
                    causal_scale,
                    min=torch.max(torch.tensor(minimum_scale, device=causal_scale.device), min_allowed_scale),
                    max=max_allowed_scale,
                )

            # Now convert means and scale to original dtype after all numerical operations
            causal_means = causal_means.to(data.dtype)
            causal_scale = causal_scale.to(data.dtype)

        finally:
            # Restore original deterministic setting (and its warn_only flag) if it was changed
            if toggle_deterministic:
                torch.use_deterministic_algorithms(True, warn_only=prev_warn_only)

        return causal_means, causal_scale
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class CausalPatchStdMeanScaler(Scaler):
    """
    Patch-wise causal standardization of a time series batch.

    Running (causal) mean and scale are computed for every timestep by
    ``compute_causal_statistics``. The statistics at the last timestep of each
    patch are then shared by every timestep of that patch, so a patch is
    normalized with statistics covering all data up to and including itself.

    Optionally, the causal scale can be clamped around a global scale
    (``stabilize_with_global``) to avoid extreme values.

    Note: This scaler only works with the following constraints:
    - The input must have shape [batch, variates, time_steps]
    - It only operates on the last dimension (-1)
    - The time_steps must be divisible by patch_size

    Parameters
    ----------
    dim
        dimension along which to compute the causal scale. Must be -1 (the last dimension).
    patch_size
        number of timesteps in each patch
    minimum_scale
        minimum scale value forwarded to ``compute_causal_statistics``
    use_bessel_correction
        whether to use Bessel's correction to get an unbiased estimator
    stabilize_with_global
        whether to use global statistics to stabilize extreme causal statistics
    scale_factor_exponent
        exponent that controls the allowed range of deviation from global scale.
        For example, with exponent=1.0, causal scale must be between 0.1x and 10x the global scale.
        With exponent=2.0, the range would be 0.01x to 100x.
    """

    @validated()
    def __init__(
        self,
        dim: int = -1,
        patch_size: int = 32,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.patch_size = patch_size
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(  # type: ignore[override]
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalize ``data`` patch by patch.

        Parameters
        ----------
        data
            tensor of shape [batch, variates, time_steps]
        padding_mask
            padding mask forwarded to ``compute_causal_statistics``
        weights
            weight tensor with the same shape as ``data``
        prefix_length
            forwarded to ``compute_causal_statistics``

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            the scaled data, and the per-timestep means and scales that were
            used to scale it (constant within each patch)
        """
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]"

        with torch.no_grad():
            time_steps = data.shape[-1]
            assert time_steps % self.patch_size == 0, (
                f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})"
            )

            # Per-timestep running statistics over the time axis.
            running_means, running_scales = compute_causal_statistics(
                data=data,
                weights=weights,
                padding_mask=padding_mask,
                dim=-1,
                minimum_scale=self.minimum_scale,
                use_bessel_correction=self.use_bessel_correction,
                stabilize_with_global=self.stabilize_with_global,
                scale_factor_exponent=self.scale_factor_exponent,
                prefix_length=prefix_length,
            )

            # For each statistic: cut the time axis into non-overlapping patches,
            # keep the value at the end of every patch (the most recent one), and
            # tile it back over all timesteps of that patch.
            patch_means, patch_scales = (
                repeat(
                    running.unfold(-1, self.patch_size, self.patch_size)[..., -1],
                    "b v p -> b v (p s)",
                    s=self.patch_size,
                )
                for running in (running_means, running_scales)
            )

            scaled_data = (data - patch_means) / patch_scales

            return scaled_data, patch_means, patch_scales
|