autogluon.timeseries 1.4.1b20251115__py3-none-any.whl → 1.5.0b20251221__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic; see the package registry's release page for more details.
- autogluon/timeseries/configs/hyperparameter_presets.py +13 -28
- autogluon/timeseries/configs/predictor_presets.py +23 -39
- autogluon/timeseries/dataset/ts_dataframe.py +32 -34
- autogluon/timeseries/learner.py +67 -33
- autogluon/timeseries/metrics/__init__.py +4 -4
- autogluon/timeseries/metrics/abstract.py +8 -8
- autogluon/timeseries/metrics/point.py +9 -9
- autogluon/timeseries/metrics/quantile.py +4 -4
- autogluon/timeseries/models/__init__.py +2 -1
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -50
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +8 -8
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +30 -26
- autogluon/timeseries/models/autogluon_tabular/per_step.py +13 -11
- autogluon/timeseries/models/autogluon_tabular/transforms.py +2 -2
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +395 -0
- autogluon/timeseries/models/chronos/model.py +30 -25
- autogluon/timeseries/models/chronos/utils.py +5 -5
- autogluon/timeseries/models/ensemble/__init__.py +17 -10
- autogluon/timeseries/models/ensemble/abstract.py +13 -9
- autogluon/timeseries/models/ensemble/array_based/__init__.py +2 -2
- autogluon/timeseries/models/ensemble/array_based/abstract.py +24 -31
- autogluon/timeseries/models/ensemble/array_based/models.py +146 -11
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +2 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +6 -5
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +44 -83
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +21 -55
- autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
- autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +7 -3
- autogluon/timeseries/models/ensemble/weighted/basic.py +26 -13
- autogluon/timeseries/models/ensemble/weighted/greedy.py +21 -144
- autogluon/timeseries/models/gluonts/abstract.py +30 -29
- autogluon/timeseries/models/gluonts/dataset.py +9 -9
- autogluon/timeseries/models/gluonts/models.py +0 -7
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +13 -16
- autogluon/timeseries/models/local/naive.py +2 -2
- autogluon/timeseries/models/local/npts.py +7 -1
- autogluon/timeseries/models/local/statsforecast.py +13 -13
- autogluon/timeseries/models/multi_window/multi_window_model.py +38 -23
- autogluon/timeseries/models/registry.py +3 -4
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +3 -4
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +6 -6
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +4 -9
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +2 -3
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +10 -10
- autogluon/timeseries/models/toto/_internal/dataset.py +2 -2
- autogluon/timeseries/models/toto/_internal/forecaster.py +8 -8
- autogluon/timeseries/models/toto/dataloader.py +4 -4
- autogluon/timeseries/models/toto/hf_pretrained_model.py +97 -16
- autogluon/timeseries/models/toto/model.py +30 -17
- autogluon/timeseries/predictor.py +531 -136
- autogluon/timeseries/regressor.py +18 -23
- autogluon/timeseries/splitter.py +2 -2
- autogluon/timeseries/trainer/ensemble_composer.py +323 -129
- autogluon/timeseries/trainer/model_set_builder.py +9 -9
- autogluon/timeseries/trainer/prediction_cache.py +16 -16
- autogluon/timeseries/trainer/trainer.py +235 -145
- autogluon/timeseries/trainer/utils.py +3 -4
- autogluon/timeseries/transforms/covariate_scaler.py +7 -7
- autogluon/timeseries/transforms/target_scaler.py +8 -8
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +1 -3
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/features.py +22 -9
- autogluon/timeseries/utils/forecast.py +1 -2
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/version.py +1 -1
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/METADATA +23 -21
- autogluon_timeseries-1.5.0b20251221.dist-info/RECORD +103 -0
- autogluon_timeseries-1.4.1b20251115.dist-info/RECORD +0 -96
- /autogluon.timeseries-1.4.1b20251115-py3.9-nspkg.pth → /autogluon.timeseries-1.5.0b20251221-py3.11-nspkg.pth +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/WHEEL +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/licenses/LICENSE +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/licenses/NOTICE +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/namespace_packages.txt +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/top_level.txt +0 -0
- {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
# Source: https://github.com/lucidrains/rotary-embedding-torch
|
|
2
|
+
# MIT License
|
|
3
|
+
#
|
|
4
|
+
# Copyright (c) 2021 Phil Wang
|
|
5
|
+
|
|
6
|
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
7
|
+
# of this software and associated documentation files (the "Software"), to deal
|
|
8
|
+
# in the Software without restriction, including without limitation the rights
|
|
9
|
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
10
|
+
# copies of the Software, and to permit persons to whom the Software is
|
|
11
|
+
# furnished to do so, subject to the following conditions:
|
|
12
|
+
|
|
13
|
+
# The above copyright notice and this permission notice shall be included in all
|
|
14
|
+
# copies or substantial portions of the Software.
|
|
15
|
+
|
|
16
|
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
17
|
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
18
|
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
19
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
20
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
21
|
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
22
|
+
# SOFTWARE.
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
from math import pi
|
|
27
|
+
from typing import Literal
|
|
28
|
+
|
|
29
|
+
import torch
|
|
30
|
+
from einops import rearrange, repeat
|
|
31
|
+
from torch import Tensor, broadcast_tensors, einsum, is_tensor, nn, tensor
|
|
32
|
+
from torch.amp import autocast
|
|
33
|
+
from torch.nn import Module
|
|
34
|
+
|
|
35
|
+
# helper functions
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def exists(val):
    """Report whether ``val`` was provided, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as present.
    """
    missing = val is None
    return not missing
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    Only ``None`` triggers the fallback, so falsy values like ``0`` or ``False`` are kept.
    """
    if val is None:
        return d
    return val
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def slice_at_dim(t, dim_slice: slice, *, dim):
    """Index ``t`` with ``dim_slice`` along axis ``dim`` and take every other axis whole.

    A negative ``dim`` is first shifted by ``t.ndim``, so ``dim=-1`` addresses the last axis.
    """
    if dim < 0:
        dim = dim + t.ndim
    index = [slice(None) for _ in range(t.ndim)]
    index[dim] = dim_slice
    return t[tuple(index)]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# rotary embedding helper functions
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def rotate_half(x):
    """Map each adjacent feature pair ``(x1, x2)`` on the last axis to ``(-x2, x1)``.

    The last axis is read as interleaved pairs (``r=2``), so its size must be divisible by 2.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    swapped = torch.stack((-second, first), dim=-1)
    return rearrange(swapped, "... d r -> ... (d r)")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None):
    """Rotate the feature window ``t[..., start_index:start_index + rot_dim]`` by the angles in ``freqs``.

    ``rot_dim`` is ``freqs.shape[-1]``. Features before and after the window are passed through
    unchanged. If ``freqs_seq_dim`` is given, or can be inferred (0 when ``freqs`` is 2-D or ``t`` is
    3-D), ``freqs`` is trimmed along that axis to its last ``t.shape[seq_dim]`` entries. CUDA autocast
    is disabled for this function.

    Args:
        freqs: rotation angles; the last dimension is the number of rotated features.
        t: tensor whose last dimension holds the features.
        start_index: first feature index to rotate.
        scale: multiplier applied to the rotated window (a scalar or a broadcastable tensor).
        seq_dim: sequence axis of ``t``, used only to trim ``freqs``.
        freqs_seq_dim: sequence axis of ``freqs``; inferred as described above when omitted.

    Returns:
        A new tensor (``t`` is not modified in place) cast back to ``t``'s dtype.

    Raises:
        AssertionError: if the rotation window does not fit inside ``t``'s last dimension.
    """
    dtype = t.dtype

    if not exists(freqs_seq_dim):
        if freqs.ndim == 2 or t.ndim == 3:
            freqs_seq_dim = 0

    if t.ndim == 3 or exists(freqs_seq_dim):
        # keep only the angles for the last seq_len positions so freqs lines up with the end of t
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    # Validate the whole window, not just rot_dim: with start_index > 0 a window running past the
    # last feature would otherwise surface later as an opaque broadcasting/rearrange error.
    assert rot_dim <= t.shape[-1] and end_index <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
        f" starting at index {start_index}"
    )

    # Split t into three parts: left, middle (to be transformed), and right
    t_left = t[..., :start_index]
    t_middle = t[..., start_index:end_index]
    t_right = t[..., end_index:]

    # Apply rotary embeddings without modifying t in place
    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)

    out = torch.cat((t_left, t_transformed, t_right), dim=-1)

    return out.type(dtype)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# learned rotation helpers
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Rotate features of ``t`` by explicitly supplied rotation angles.

    When ``freq_ranges`` is given, every rotation value is multiplied by every frequency range
    (an outer product over a new last axis), and the trailing ``(r, f)`` axes are flattened.
    Each resulting angle is then duplicated for its feature pair and applied with
    ``apply_rotary_emb`` starting at ``start_index``.
    """
    if freq_ranges is not None:
        scaled = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(scaled, "... r f -> ... (r f)")

    paired_angles = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(paired_angles, t, start_index=start_index)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# classes
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class RotaryEmbedding(Module):
    """Rotary position embedding module (vendored from lucidrains/rotary-embedding-torch).

    Holds one frequency per rotated feature pair. ``forward`` turns positions into rotation angles,
    which ``apply_rotary_emb`` applies to query/key tensors. It optionally supports xpos scaling
    (``use_xpos``) and caches angles/scales for up to ``cache_max_seq_len`` positions.

    Args:
        dim: feature size the built-in schedules are created for; ``"lang"`` and ``"pixel"``
            create ``dim // 2`` frequencies.
        custom_freqs: explicit frequency tensor; takes precedence over ``freqs_for``.
        freqs_for: frequency schedule. ``"lang"`` is geometric in ``theta``, ``"pixel"`` is linear up
            to ``max_freq / 2`` times pi, and ``"constant"`` is ``num_freqs`` ones.
        theta: base of the ``"lang"`` schedule.
        max_freq: upper frequency bound for ``"pixel"``.
        num_freqs: number of frequencies for ``"constant"``.
        learned_freq: make the frequencies trainable; this also disables the angle cache.
        use_xpos: enable xpos scaling; queries and keys must then be rotated together.
        xpos_scale_base: divisor of the centred positions in the xpos exponent.
        interpolate_factor: positions are divided by this value (must be >= 1).
        theta_rescale_factor: NTK-aware rescaling of ``theta``.
        seq_before_head_dim: default sequence axis is -3 when True, -2 otherwise.
        cache_if_possible: enable angle/scale caching.
        cache_max_seq_len: number of positions the caches can hold.

    Raises:
        ValueError: if ``freqs_for`` is not recognised and no ``custom_freqs`` are given.
    """

    def __init__(
        self,
        dim,
        custom_freqs: Tensor | None = None,
        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
        cache_max_seq_len=8192,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            # previously fell through and crashed with UnboundLocalError on `freqs`
            raise ValueError(
                f"unknown freqs_for {freqs_for!r}; expected 'lang', 'pixel' or 'constant', or pass custom_freqs"
            )

        self.cache_if_possible = cache_if_possible
        self.cache_max_seq_len = cache_max_seq_len

        # forward() repeats every frequency twice, so cached angles are 2 * len(freqs) wide. That equals
        # `dim` for the built-in schedules with even `dim`, but not for odd `dim`, "constant" or custom
        # freqs, where a `dim`-wide buffer made the cache write fail with a shape mismatch.
        rot_dim = 2 * freqs.shape[-1]
        self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, rot_dim), persistent=False)
        self.cached_freqs_seq_len = 0

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy buffer, only used to report the module's current device

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # add apply_rotary_emb as static method; done before the xpos early return below so the
        # attribute exists for every configuration (it used to be set only when use_xpos=True)

        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

        # xpos

        self.use_xpos = use_xpos

        if not use_xpos:
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base

        self.register_buffer("scale", scale, persistent=False)
        # get_scale() repeats each scale twice, so size the cache to match its output width
        self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, 2 * scale.shape[-1]), persistent=False)
        self.cached_scales_seq_len = 0

    @property
    def device(self):
        """Device the module currently lives on (read from the ``dummy`` buffer)."""
        return self.dummy.device

    def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
        """Return positions ``(arange(seq_len) + offset) / interpolate_factor``.

        ``device`` defaults to the module's device and ``dtype`` to the angle cache's dtype.
        """
        device = default(device, self.device)
        dtype = default(dtype, self.cached_freqs.dtype)

        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
        """Rotate a single query or key tensor ``t`` according to its sequence positions.

        Positions start at ``offset``. With xpos enabled a ``scale`` must be supplied; otherwise use
        ``rotate_queries_and_keys``. When the sequence axis is -3, a singleton axis is inserted into
        the angles so they broadcast over the axis between sequence and features (presumably heads).
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert not self.use_xpos or exists(scale), (
            "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
        )

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)

        freqs = self.forward(seq, seq_len=seq_len, offset=offset)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """Rotate queries ``q`` and keys ``k`` when there are at least as many keys as queries.

        The queries are positioned at the last ``q_len`` of the ``k_len`` key positions (shifted by
        ``offset``). With xpos, queries get ``scale`` and keys get ``scale ** -1``.
        """
        dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        q_scale = k_scale = 1.0

        if self.use_xpos:
            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)

            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
            k_scale = self.get_scale(seq).type(dtype)

        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        """Rotate queries and keys together with xpos scaling (requires ``use_xpos``).

        Positions come from ``q``'s sequence length. Queries are multiplied by the xpos scale and keys
        by its inverse.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
        """Return xpos scales for the 1-D positions ``t`` (requires ``use_xpos``).

        Each exponent is ``(t - len(t) // 2) / scale_base``, i.e. centred on the middle of the
        sequence, so the scales depend on the total sequence length. Cached scales are therefore reused
        only for the exact length they were computed for (at offset 0). The previous code also served a
        prefix of a longer cached sequence, which is wrong for a shorter one.
        """
        assert self.use_xpos

        should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len

        if should_cache and offset == 0 and seq_len == self.cached_scales_seq_len:
            return self.cached_scales[:seq_len]

        scale = 1.0
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = repeat(scale, "n d -> n (d r)", r=2)

        if should_cache and offset == 0:
            self.cached_scales[:seq_len] = scale.detach()
            self.cached_scales_seq_len = seq_len

        return scale

    def get_axial_freqs(self, *dims, offsets: (tuple[int | float, ...] | Tensor | None) = None):
        """Return rotation angles for a multi-axis grid with the given per-axis sizes ``dims``.

        Each axis gets its own positions: ``linspace(-1, 1)`` for ``"pixel"``, ``arange`` otherwise,
        plus the optional per-axis ``offsets``. The per-axis angles are broadcast against each other
        and concatenated on the last dimension.
        """
        Colon = slice(None)
        all_freqs = []

        # handle offset

        if exists(offsets):
            if not is_tensor(offsets):
                offsets = tensor(offsets)

            assert len(offsets) == len(dims)

        # forward()'s cache assumes positions are exactly arange(n) / interpolate_factor. The grid
        # positions below only match that with no offsets and no interpolation; otherwise reading the
        # cache would ignore the offsets and writing it would poison later lookups.
        use_cache = not exists(offsets) and self.interpolate_factor == 1.0

        # get frequencies for each axis

        for ind, dim in enumerate(dims):
            offset = 0
            if exists(offsets):
                offset = offsets[ind]

            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            pos = pos + offset

            freqs = self.forward(pos, seq_len=dim if use_cache else None)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        # concat all freqs

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
        """Return rotation angles ``t[..., None] * freqs``, with each angle repeated for its feature pair.

        Caching applies only when ``cache_if_possible`` is set, the frequencies are not learned,
        ``seq_len`` is given, the schedule is not ``"pixel"``, and ``offset + seq_len`` fits in
        ``cache_max_seq_len``. In that case angles for ``[offset, offset + seq_len)`` are served from
        ``cached_freqs`` if already computed; the cache is only written when ``offset == 0``. The cache
        assumes ``t`` holds exactly those positions, as produced by ``get_seq_pos``.
        """
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for != "pixel"
            and (offset + seq_len) <= self.cache_max_seq_len
        )

        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len:
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
            self.cached_freqs_seq_len = seq_len

        return freqs
|
|
@@ -4,7 +4,6 @@
|
|
|
4
4
|
# Copyright 2025 Datadog, Inc.
|
|
5
5
|
|
|
6
6
|
import warnings
|
|
7
|
-
from typing import Optional
|
|
8
7
|
|
|
9
8
|
import torch
|
|
10
9
|
from einops import repeat
|
|
@@ -21,7 +20,7 @@ def compute_causal_statistics(
|
|
|
21
20
|
use_bessel_correction: bool = True,
|
|
22
21
|
stabilize_with_global: bool = False,
|
|
23
22
|
scale_factor_exponent: float = 10.0,
|
|
24
|
-
prefix_length:
|
|
23
|
+
prefix_length: int | None = None,
|
|
25
24
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
26
25
|
"""
|
|
27
26
|
Compute causal mean and scale statistics along a specified dimension using
|
|
@@ -260,7 +259,7 @@ class CausalPatchStdMeanScaler(Scaler):
|
|
|
260
259
|
data: torch.Tensor,
|
|
261
260
|
padding_mask: torch.Tensor,
|
|
262
261
|
weights: torch.Tensor,
|
|
263
|
-
prefix_length:
|
|
262
|
+
prefix_length: int | None = None,
|
|
264
263
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
265
264
|
assert data.shape == weights.shape, "data and weights must have same shape"
|
|
266
265
|
assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]"
|
|
@@ -3,12 +3,11 @@
|
|
|
3
3
|
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
4
|
# Copyright 2025 Datadog, Inc.
|
|
5
5
|
|
|
6
|
-
from typing import
|
|
6
|
+
from typing import cast
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
9
|
import torch.nn.functional as F
|
|
10
10
|
from einops import rearrange
|
|
11
|
-
from rotary_embedding_torch import RotaryEmbedding
|
|
12
11
|
|
|
13
12
|
from .attention import (
|
|
14
13
|
AttentionAxis,
|
|
@@ -18,6 +17,7 @@ from .attention import (
|
|
|
18
17
|
)
|
|
19
18
|
from .kvcache import KVCache
|
|
20
19
|
from .rope import TimeAwareRotaryEmbedding
|
|
20
|
+
from .rotary_embedding_torch import RotaryEmbedding
|
|
21
21
|
|
|
22
22
|
|
|
23
23
|
class SwiGLU(torch.nn.Module):
|
|
@@ -37,7 +37,7 @@ class RMSNorm(torch.nn.Module):
|
|
|
37
37
|
super(RMSNorm, self).__init__()
|
|
38
38
|
self.eps = eps
|
|
39
39
|
if include_weight:
|
|
40
|
-
self.scale:
|
|
40
|
+
self.scale: torch.nn.Parameter | None = torch.nn.Parameter(torch.ones(dim))
|
|
41
41
|
else:
|
|
42
42
|
self.scale = None
|
|
43
43
|
|
|
@@ -86,7 +86,7 @@ class TransformerLayer(torch.nn.Module):
|
|
|
86
86
|
num_heads: int,
|
|
87
87
|
mlp_hidden_dim: int,
|
|
88
88
|
dropout: float,
|
|
89
|
-
rotary_emb:
|
|
89
|
+
rotary_emb: RotaryEmbedding | None = None,
|
|
90
90
|
attention_axis: AttentionAxis = AttentionAxis.TIME,
|
|
91
91
|
RMS_norm: bool = True,
|
|
92
92
|
use_memory_efficient_attention: bool = True,
|
|
@@ -99,8 +99,8 @@ class TransformerLayer(torch.nn.Module):
|
|
|
99
99
|
self.attention_axis = attention_axis
|
|
100
100
|
|
|
101
101
|
if RMS_norm:
|
|
102
|
-
self.norm1:
|
|
103
|
-
self.norm2:
|
|
102
|
+
self.norm1: RMSNorm | torch.nn.LayerNorm = RMSNorm(embed_dim)
|
|
103
|
+
self.norm2: RMSNorm | torch.nn.LayerNorm = RMSNorm(embed_dim)
|
|
104
104
|
|
|
105
105
|
else:
|
|
106
106
|
self.norm1 = torch.nn.LayerNorm(embed_dim)
|
|
@@ -138,8 +138,8 @@ class TransformerLayer(torch.nn.Module):
|
|
|
138
138
|
self,
|
|
139
139
|
layer_idx: int,
|
|
140
140
|
inputs: torch.Tensor,
|
|
141
|
-
attention_mask:
|
|
142
|
-
kv_cache:
|
|
141
|
+
attention_mask: torch.Tensor | None = None,
|
|
142
|
+
kv_cache: KVCache | None = None,
|
|
143
143
|
) -> torch.Tensor:
|
|
144
144
|
pre_norm_1 = self.norm1(inputs)
|
|
145
145
|
hidden_state = inputs + self.attention(layer_idx, pre_norm_1, attention_mask, kv_cache).contiguous()
|
|
@@ -221,7 +221,7 @@ class Transformer(torch.nn.Module):
|
|
|
221
221
|
self,
|
|
222
222
|
num_heads: int,
|
|
223
223
|
dtype: torch.dtype,
|
|
224
|
-
id_mask:
|
|
224
|
+
id_mask: torch.Tensor | None = None,
|
|
225
225
|
) -> torch.Tensor:
|
|
226
226
|
"""
|
|
227
227
|
Unified method to create and process space-wise masks.
|
|
@@ -302,7 +302,7 @@ class Transformer(torch.nn.Module):
|
|
|
302
302
|
self,
|
|
303
303
|
inputs: torch.Tensor,
|
|
304
304
|
id_mask: torch.Tensor,
|
|
305
|
-
kv_cache:
|
|
305
|
+
kv_cache: KVCache | None = None,
|
|
306
306
|
) -> torch.Tensor:
|
|
307
307
|
batch, _, seq_len, _ = inputs.shape
|
|
308
308
|
# Get the sequence length by looking up a timewise layer in the kv cache.
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# Copyright 2025 Datadog, Inc.
|
|
5
5
|
|
|
6
6
|
from functools import reduce
|
|
7
|
-
from typing import NamedTuple
|
|
7
|
+
from typing import NamedTuple
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import pandas as pd
|
|
@@ -142,7 +142,7 @@ def replace_extreme_values(t: torch.Tensor, replacement: float = 0.0) -> torch.T
|
|
|
142
142
|
return torch.where(is_extreme_value(t), torch.tensor(replacement, dtype=t.dtype, device=t.device), t)
|
|
143
143
|
|
|
144
144
|
|
|
145
|
-
def freq_to_seconds(freq:
|
|
145
|
+
def freq_to_seconds(freq: str | pd.offsets.BaseOffset) -> float:
|
|
146
146
|
# Modified from: https://github.com/DataDog/toto/blob/846d599f4b8d377db3088d5cd1a736d050cef5ac/toto/inference/gluonts_predictor.py#L58
|
|
147
147
|
if isinstance(freq, str):
|
|
148
148
|
freq = pd.tseries.frequencies.to_offset(freq)
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
# Copyright 2025 Datadog, Inc.
|
|
5
5
|
|
|
6
6
|
from dataclasses import dataclass
|
|
7
|
-
from typing import
|
|
7
|
+
from typing import cast
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
10
|
import torch
|
|
@@ -24,9 +24,9 @@ from .dataset import (
|
|
|
24
24
|
@dataclass(frozen=True)
|
|
25
25
|
class Forecast:
|
|
26
26
|
mean: torch.Tensor
|
|
27
|
-
samples:
|
|
27
|
+
samples: torch.Tensor | None
|
|
28
28
|
|
|
29
|
-
def quantile(self, q:
|
|
29
|
+
def quantile(self, q: float | torch.Tensor) -> torch.Tensor:
|
|
30
30
|
"""
|
|
31
31
|
Compute the quantile of the forecast samples.
|
|
32
32
|
"""
|
|
@@ -88,7 +88,7 @@ class TotoForecaster:
|
|
|
88
88
|
self,
|
|
89
89
|
inputs: MaskedTimeseries,
|
|
90
90
|
prediction_length: int,
|
|
91
|
-
num_samples:
|
|
91
|
+
num_samples: int | None = None,
|
|
92
92
|
samples_per_batch: int = 10,
|
|
93
93
|
use_kv_cache: bool = True,
|
|
94
94
|
) -> Forecast:
|
|
@@ -187,8 +187,8 @@ class TotoForecaster:
|
|
|
187
187
|
prediction_length: int,
|
|
188
188
|
timestamp_seconds: torch.Tensor,
|
|
189
189
|
time_interval_seconds: torch.Tensor,
|
|
190
|
-
input_padding_mask:
|
|
191
|
-
id_mask:
|
|
190
|
+
input_padding_mask: torch.Tensor | None = None,
|
|
191
|
+
id_mask: torch.Tensor | None = None,
|
|
192
192
|
use_kv_cache: bool = False,
|
|
193
193
|
) -> torch.Tensor:
|
|
194
194
|
"""
|
|
@@ -262,8 +262,8 @@ class TotoForecaster:
|
|
|
262
262
|
num_samples: int,
|
|
263
263
|
timestamp_seconds: torch.Tensor,
|
|
264
264
|
time_interval_seconds: torch.Tensor,
|
|
265
|
-
input_padding_mask:
|
|
266
|
-
id_mask:
|
|
265
|
+
input_padding_mask: torch.Tensor | None = None,
|
|
266
|
+
id_mask: torch.Tensor | None = None,
|
|
267
267
|
sampling_batch_size: int = 10,
|
|
268
268
|
use_kv_cache: bool = False,
|
|
269
269
|
) -> torch.Tensor:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import time
|
|
3
|
-
from typing import Any, Callable, Iterator
|
|
3
|
+
from typing import Any, Callable, Iterator
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import torch
|
|
@@ -44,9 +44,9 @@ class TotoDataLoader:
|
|
|
44
44
|
def __init__(
|
|
45
45
|
self,
|
|
46
46
|
dataset: TotoInferenceDataset,
|
|
47
|
-
freq:
|
|
47
|
+
freq: str | None = None,
|
|
48
48
|
batch_size: int = 1,
|
|
49
|
-
time_limit:
|
|
49
|
+
time_limit: int | float | None = None,
|
|
50
50
|
device: Any = None,
|
|
51
51
|
):
|
|
52
52
|
self.device = torch.device(device)
|
|
@@ -60,7 +60,7 @@ class TotoDataLoader:
|
|
|
60
60
|
self.freq: str = freq or dataset.freq or "h"
|
|
61
61
|
|
|
62
62
|
@staticmethod
|
|
63
|
-
def _get_timeout_callback(seconds:
|
|
63
|
+
def _get_timeout_callback(seconds: float | None) -> Callable:
|
|
64
64
|
start_time = time.monotonic()
|
|
65
65
|
|
|
66
66
|
def callback() -> None:
|