autogluon.timeseries 1.4.1b20251115__py3-none-any.whl → 1.5.0b20251221__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.timeseries might be problematic; see the registry's advisory page for details.

Files changed (82)
  1. autogluon/timeseries/configs/hyperparameter_presets.py +13 -28
  2. autogluon/timeseries/configs/predictor_presets.py +23 -39
  3. autogluon/timeseries/dataset/ts_dataframe.py +32 -34
  4. autogluon/timeseries/learner.py +67 -33
  5. autogluon/timeseries/metrics/__init__.py +4 -4
  6. autogluon/timeseries/metrics/abstract.py +8 -8
  7. autogluon/timeseries/metrics/point.py +9 -9
  8. autogluon/timeseries/metrics/quantile.py +4 -4
  9. autogluon/timeseries/models/__init__.py +2 -1
  10. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -50
  11. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  12. autogluon/timeseries/models/abstract/tunable.py +8 -8
  13. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +30 -26
  14. autogluon/timeseries/models/autogluon_tabular/per_step.py +13 -11
  15. autogluon/timeseries/models/autogluon_tabular/transforms.py +2 -2
  16. autogluon/timeseries/models/chronos/__init__.py +2 -1
  17. autogluon/timeseries/models/chronos/chronos2.py +395 -0
  18. autogluon/timeseries/models/chronos/model.py +30 -25
  19. autogluon/timeseries/models/chronos/utils.py +5 -5
  20. autogluon/timeseries/models/ensemble/__init__.py +17 -10
  21. autogluon/timeseries/models/ensemble/abstract.py +13 -9
  22. autogluon/timeseries/models/ensemble/array_based/__init__.py +2 -2
  23. autogluon/timeseries/models/ensemble/array_based/abstract.py +24 -31
  24. autogluon/timeseries/models/ensemble/array_based/models.py +146 -11
  25. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +2 -0
  26. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +6 -5
  27. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
  28. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +44 -83
  29. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +21 -55
  30. autogluon/timeseries/models/ensemble/ensemble_selection.py +167 -0
  31. autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
  32. autogluon/timeseries/models/ensemble/weighted/abstract.py +7 -3
  33. autogluon/timeseries/models/ensemble/weighted/basic.py +26 -13
  34. autogluon/timeseries/models/ensemble/weighted/greedy.py +21 -144
  35. autogluon/timeseries/models/gluonts/abstract.py +30 -29
  36. autogluon/timeseries/models/gluonts/dataset.py +9 -9
  37. autogluon/timeseries/models/gluonts/models.py +0 -7
  38. autogluon/timeseries/models/local/__init__.py +0 -7
  39. autogluon/timeseries/models/local/abstract_local_model.py +13 -16
  40. autogluon/timeseries/models/local/naive.py +2 -2
  41. autogluon/timeseries/models/local/npts.py +7 -1
  42. autogluon/timeseries/models/local/statsforecast.py +13 -13
  43. autogluon/timeseries/models/multi_window/multi_window_model.py +38 -23
  44. autogluon/timeseries/models/registry.py +3 -4
  45. autogluon/timeseries/models/toto/_internal/backbone/attention.py +3 -4
  46. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +6 -6
  47. autogluon/timeseries/models/toto/_internal/backbone/rope.py +4 -9
  48. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  49. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +2 -3
  50. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +10 -10
  51. autogluon/timeseries/models/toto/_internal/dataset.py +2 -2
  52. autogluon/timeseries/models/toto/_internal/forecaster.py +8 -8
  53. autogluon/timeseries/models/toto/dataloader.py +4 -4
  54. autogluon/timeseries/models/toto/hf_pretrained_model.py +97 -16
  55. autogluon/timeseries/models/toto/model.py +30 -17
  56. autogluon/timeseries/predictor.py +531 -136
  57. autogluon/timeseries/regressor.py +18 -23
  58. autogluon/timeseries/splitter.py +2 -2
  59. autogluon/timeseries/trainer/ensemble_composer.py +323 -129
  60. autogluon/timeseries/trainer/model_set_builder.py +9 -9
  61. autogluon/timeseries/trainer/prediction_cache.py +16 -16
  62. autogluon/timeseries/trainer/trainer.py +235 -145
  63. autogluon/timeseries/trainer/utils.py +3 -4
  64. autogluon/timeseries/transforms/covariate_scaler.py +7 -7
  65. autogluon/timeseries/transforms/target_scaler.py +8 -8
  66. autogluon/timeseries/utils/constants.py +10 -0
  67. autogluon/timeseries/utils/datetime/lags.py +1 -3
  68. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  69. autogluon/timeseries/utils/features.py +22 -9
  70. autogluon/timeseries/utils/forecast.py +1 -2
  71. autogluon/timeseries/utils/timer.py +173 -0
  72. autogluon/timeseries/version.py +1 -1
  73. {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/METADATA +23 -21
  74. autogluon_timeseries-1.5.0b20251221.dist-info/RECORD +103 -0
  75. autogluon_timeseries-1.4.1b20251115.dist-info/RECORD +0 -96
  76. /autogluon.timeseries-1.4.1b20251115-py3.9-nspkg.pth → /autogluon.timeseries-1.5.0b20251221-py3.11-nspkg.pth +0 -0
  77. {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/WHEEL +0 -0
  78. {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/licenses/LICENSE +0 -0
  79. {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/licenses/NOTICE +0 -0
  80. {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/namespace_packages.txt +0 -0
  81. {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/top_level.txt +0 -0
  82. {autogluon_timeseries-1.4.1b20251115.dist-info → autogluon_timeseries-1.5.0b20251221.dist-info}/zip-safe +0 -0
@@ -0,0 +1,342 @@
1
+ # Source: https://github.com/lucidrains/rotary-embedding-torch
2
+ # MIT License
3
+ #
4
+ # Copyright (c) 2021 Phil Wang
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ from __future__ import annotations
25
+
26
+ from math import pi
27
+ from typing import Literal
28
+
29
+ import torch
30
+ from einops import rearrange, repeat
31
+ from torch import Tensor, broadcast_tensors, einsum, is_tensor, nn, tensor
32
+ from torch.amp import autocast
33
+ from torch.nn import Module
34
+
35
+ # helper functions
36
+
37
+
38
+ def exists(val):
39
+ return val is not None
40
+
41
+
42
+ def default(val, d):
43
+ return val if exists(val) else d
44
+
45
+
46
+ def slice_at_dim(t, dim_slice: slice, *, dim):
47
+ dim += t.ndim if dim < 0 else 0
48
+ colons = [slice(None)] * t.ndim
49
+ colons[dim] = dim_slice
50
+ return t[tuple(colons)]
51
+
52
+
53
+ # rotary embedding helper functions
54
+
55
+
56
+ def rotate_half(x):
57
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
58
+ x1, x2 = x.unbind(dim=-1)
59
+ x = torch.stack((-x2, x1), dim=-1)
60
+ return rearrange(x, "... d r -> ... (d r)")
61
+
62
+
63
+ @autocast("cuda", enabled=False)
64
+ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None):
65
+ dtype = t.dtype
66
+
67
+ if not exists(freqs_seq_dim):
68
+ if freqs.ndim == 2 or t.ndim == 3:
69
+ freqs_seq_dim = 0
70
+
71
+ if t.ndim == 3 or exists(freqs_seq_dim):
72
+ seq_len = t.shape[seq_dim]
73
+ freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)
74
+
75
+ rot_dim = freqs.shape[-1]
76
+ end_index = start_index + rot_dim
77
+
78
+ assert rot_dim <= t.shape[-1], (
79
+ f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
80
+ )
81
+
82
+ # Split t into three parts: left, middle (to be transformed), and right
83
+ t_left = t[..., :start_index]
84
+ t_middle = t[..., start_index:end_index]
85
+ t_right = t[..., end_index:]
86
+
87
+ # Apply rotary embeddings without modifying t in place
88
+ t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
89
+
90
+ out = torch.cat((t_left, t_transformed, t_right), dim=-1)
91
+
92
+ return out.type(dtype)
93
+
94
+
95
+ # learned rotation helpers
96
+
97
+
98
+ def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
99
+ if exists(freq_ranges):
100
+ rotations = einsum("..., f -> ... f", rotations, freq_ranges)
101
+ rotations = rearrange(rotations, "... r f -> ... (r f)")
102
+
103
+ rotations = repeat(rotations, "... n -> ... (n r)", r=2)
104
+ return apply_rotary_emb(rotations, t, start_index=start_index)
105
+
106
+
107
+ # classes
108
+
109
+
110
+ class RotaryEmbedding(Module):
111
+ def __init__(
112
+ self,
113
+ dim,
114
+ custom_freqs: Tensor | None = None,
115
+ freqs_for: Literal["lang", "pixel", "constant"] = "lang",
116
+ theta=10000,
117
+ max_freq=10,
118
+ num_freqs=1,
119
+ learned_freq=False,
120
+ use_xpos=False,
121
+ xpos_scale_base=512,
122
+ interpolate_factor=1.0,
123
+ theta_rescale_factor=1.0,
124
+ seq_before_head_dim=False,
125
+ cache_if_possible=True,
126
+ cache_max_seq_len=8192,
127
+ ):
128
+ super().__init__()
129
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
130
+ # has some connection to NTK literature
131
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
132
+
133
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
134
+
135
+ self.freqs_for = freqs_for
136
+
137
+ if exists(custom_freqs):
138
+ freqs = custom_freqs
139
+ elif freqs_for == "lang":
140
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
141
+ elif freqs_for == "pixel":
142
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
143
+ elif freqs_for == "constant":
144
+ freqs = torch.ones(num_freqs).float()
145
+
146
+ self.cache_if_possible = cache_if_possible
147
+ self.cache_max_seq_len = cache_max_seq_len
148
+
149
+ self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
150
+ self.cached_freqs_seq_len = 0
151
+
152
+ self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
153
+
154
+ self.learned_freq = learned_freq
155
+
156
+ # dummy for device
157
+
158
+ self.register_buffer("dummy", torch.tensor(0), persistent=False)
159
+
160
+ # default sequence dimension
161
+
162
+ self.seq_before_head_dim = seq_before_head_dim
163
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
164
+
165
+ # interpolation factors
166
+
167
+ assert interpolate_factor >= 1.0
168
+ self.interpolate_factor = interpolate_factor
169
+
170
+ # xpos
171
+
172
+ self.use_xpos = use_xpos
173
+
174
+ if not use_xpos:
175
+ return
176
+
177
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
178
+ self.scale_base = xpos_scale_base
179
+
180
+ self.register_buffer("scale", scale, persistent=False)
181
+ self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
182
+ self.cached_scales_seq_len = 0
183
+
184
+ # add apply_rotary_emb as static method
185
+
186
+ self.apply_rotary_emb = staticmethod(apply_rotary_emb)
187
+
188
+ @property
189
+ def device(self):
190
+ return self.dummy.device
191
+
192
+ def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
193
+ device = default(device, self.device)
194
+ dtype = default(dtype, self.cached_freqs.dtype)
195
+
196
+ return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
197
+
198
+ def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
199
+ seq_dim = default(seq_dim, self.default_seq_dim)
200
+
201
+ assert not self.use_xpos or exists(scale), (
202
+ "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
203
+ )
204
+
205
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
206
+
207
+ seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
208
+
209
+ freqs = self.forward(seq, seq_len=seq_len, offset=offset)
210
+
211
+ if seq_dim == -3:
212
+ freqs = rearrange(freqs, "n d -> n 1 d")
213
+
214
+ return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
215
+
216
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
217
+ dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
218
+
219
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
220
+ assert q_len <= k_len
221
+
222
+ q_scale = k_scale = 1.0
223
+
224
+ if self.use_xpos:
225
+ seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
226
+
227
+ q_scale = self.get_scale(seq[-q_len:]).type(dtype)
228
+ k_scale = self.get_scale(seq).type(dtype)
229
+
230
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
231
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
232
+
233
+ rotated_q = rotated_q.type(q.dtype)
234
+ rotated_k = rotated_k.type(k.dtype)
235
+
236
+ return rotated_q, rotated_k
237
+
238
+ def rotate_queries_and_keys(self, q, k, seq_dim=None):
239
+ seq_dim = default(seq_dim, self.default_seq_dim)
240
+
241
+ assert self.use_xpos
242
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
243
+
244
+ seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
245
+
246
+ freqs = self.forward(seq, seq_len=seq_len)
247
+ scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
248
+
249
+ if seq_dim == -3:
250
+ freqs = rearrange(freqs, "n d -> n 1 d")
251
+ scale = rearrange(scale, "n d -> n 1 d")
252
+
253
+ rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
254
+ rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
255
+
256
+ rotated_q = rotated_q.type(q.dtype)
257
+ rotated_k = rotated_k.type(k.dtype)
258
+
259
+ return rotated_q, rotated_k
260
+
261
+ def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
262
+ assert self.use_xpos
263
+
264
+ should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len
265
+
266
+ if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len:
267
+ return self.cached_scales[offset : (offset + seq_len)]
268
+
269
+ scale = 1.0
270
+ if self.use_xpos:
271
+ power = (t - len(t) // 2) / self.scale_base
272
+ scale = self.scale ** rearrange(power, "n -> n 1")
273
+ scale = repeat(scale, "n d -> n (d r)", r=2)
274
+
275
+ if should_cache and offset == 0:
276
+ self.cached_scales[:seq_len] = scale.detach()
277
+ self.cached_scales_seq_len = seq_len
278
+
279
+ return scale
280
+
281
+ def get_axial_freqs(self, *dims, offsets: (tuple[int | float, ...] | Tensor | None) = None):
282
+ Colon = slice(None)
283
+ all_freqs = []
284
+
285
+ # handle offset
286
+
287
+ if exists(offsets):
288
+ if not is_tensor(offsets):
289
+ offsets = tensor(offsets)
290
+
291
+ assert len(offsets) == len(dims)
292
+
293
+ # get frequencies for each axis
294
+
295
+ for ind, dim in enumerate(dims):
296
+ offset = 0
297
+ if exists(offsets):
298
+ offset = offsets[ind]
299
+
300
+ if self.freqs_for == "pixel":
301
+ pos = torch.linspace(-1, 1, steps=dim, device=self.device)
302
+ else:
303
+ pos = torch.arange(dim, device=self.device)
304
+
305
+ pos = pos + offset
306
+
307
+ freqs = self.forward(pos, seq_len=dim)
308
+
309
+ all_axis = [None] * len(dims)
310
+ all_axis[ind] = Colon
311
+
312
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
313
+ all_freqs.append(freqs[new_axis_slice])
314
+
315
+ # concat all freqs
316
+
317
+ all_freqs = broadcast_tensors(*all_freqs)
318
+ return torch.cat(all_freqs, dim=-1)
319
+
320
+ @autocast("cuda", enabled=False)
321
+ def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
322
+ should_cache = (
323
+ self.cache_if_possible
324
+ and not self.learned_freq
325
+ and exists(seq_len)
326
+ and self.freqs_for != "pixel"
327
+ and (offset + seq_len) <= self.cache_max_seq_len
328
+ )
329
+
330
+ if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len:
331
+ return self.cached_freqs[offset : (offset + seq_len)].detach()
332
+
333
+ freqs = self.freqs
334
+
335
+ freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
336
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
337
+
338
+ if should_cache and offset == 0:
339
+ self.cached_freqs[:seq_len] = freqs.detach()
340
+ self.cached_freqs_seq_len = seq_len
341
+
342
+ return freqs
@@ -4,7 +4,6 @@
4
4
  # Copyright 2025 Datadog, Inc.
5
5
 
6
6
  import warnings
7
- from typing import Optional
8
7
 
9
8
  import torch
10
9
  from einops import repeat
@@ -21,7 +20,7 @@ def compute_causal_statistics(
21
20
  use_bessel_correction: bool = True,
22
21
  stabilize_with_global: bool = False,
23
22
  scale_factor_exponent: float = 10.0,
24
- prefix_length: Optional[int] = None,
23
+ prefix_length: int | None = None,
25
24
  ) -> tuple[torch.Tensor, torch.Tensor]:
26
25
  """
27
26
  Compute causal mean and scale statistics along a specified dimension using
@@ -260,7 +259,7 @@ class CausalPatchStdMeanScaler(Scaler):
260
259
  data: torch.Tensor,
261
260
  padding_mask: torch.Tensor,
262
261
  weights: torch.Tensor,
263
- prefix_length: Optional[int] = None,
262
+ prefix_length: int | None = None,
264
263
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
265
264
  assert data.shape == weights.shape, "data and weights must have same shape"
266
265
  assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]"
@@ -3,12 +3,11 @@
3
3
  # This product includes software developed at Datadog (https://www.datadoghq.com/)
4
4
  # Copyright 2025 Datadog, Inc.
5
5
 
6
- from typing import Optional, Union, cast
6
+ from typing import cast
7
7
 
8
8
  import torch
9
9
  import torch.nn.functional as F
10
10
  from einops import rearrange
11
- from rotary_embedding_torch import RotaryEmbedding
12
11
 
13
12
  from .attention import (
14
13
  AttentionAxis,
@@ -18,6 +17,7 @@ from .attention import (
18
17
  )
19
18
  from .kvcache import KVCache
20
19
  from .rope import TimeAwareRotaryEmbedding
20
+ from .rotary_embedding_torch import RotaryEmbedding
21
21
 
22
22
 
23
23
  class SwiGLU(torch.nn.Module):
@@ -37,7 +37,7 @@ class RMSNorm(torch.nn.Module):
37
37
  super(RMSNorm, self).__init__()
38
38
  self.eps = eps
39
39
  if include_weight:
40
- self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter(torch.ones(dim))
40
+ self.scale: torch.nn.Parameter | None = torch.nn.Parameter(torch.ones(dim))
41
41
  else:
42
42
  self.scale = None
43
43
 
@@ -86,7 +86,7 @@ class TransformerLayer(torch.nn.Module):
86
86
  num_heads: int,
87
87
  mlp_hidden_dim: int,
88
88
  dropout: float,
89
- rotary_emb: Optional[RotaryEmbedding] = None,
89
+ rotary_emb: RotaryEmbedding | None = None,
90
90
  attention_axis: AttentionAxis = AttentionAxis.TIME,
91
91
  RMS_norm: bool = True,
92
92
  use_memory_efficient_attention: bool = True,
@@ -99,8 +99,8 @@ class TransformerLayer(torch.nn.Module):
99
99
  self.attention_axis = attention_axis
100
100
 
101
101
  if RMS_norm:
102
- self.norm1: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
103
- self.norm2: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
102
+ self.norm1: RMSNorm | torch.nn.LayerNorm = RMSNorm(embed_dim)
103
+ self.norm2: RMSNorm | torch.nn.LayerNorm = RMSNorm(embed_dim)
104
104
 
105
105
  else:
106
106
  self.norm1 = torch.nn.LayerNorm(embed_dim)
@@ -138,8 +138,8 @@ class TransformerLayer(torch.nn.Module):
138
138
  self,
139
139
  layer_idx: int,
140
140
  inputs: torch.Tensor,
141
- attention_mask: Optional[torch.Tensor] = None,
142
- kv_cache: Optional[KVCache] = None,
141
+ attention_mask: torch.Tensor | None = None,
142
+ kv_cache: KVCache | None = None,
143
143
  ) -> torch.Tensor:
144
144
  pre_norm_1 = self.norm1(inputs)
145
145
  hidden_state = inputs + self.attention(layer_idx, pre_norm_1, attention_mask, kv_cache).contiguous()
@@ -221,7 +221,7 @@ class Transformer(torch.nn.Module):
221
221
  self,
222
222
  num_heads: int,
223
223
  dtype: torch.dtype,
224
- id_mask: Optional[torch.Tensor] = None,
224
+ id_mask: torch.Tensor | None = None,
225
225
  ) -> torch.Tensor:
226
226
  """
227
227
  Unified method to create and process space-wise masks.
@@ -302,7 +302,7 @@ class Transformer(torch.nn.Module):
302
302
  self,
303
303
  inputs: torch.Tensor,
304
304
  id_mask: torch.Tensor,
305
- kv_cache: Optional[KVCache] = None,
305
+ kv_cache: KVCache | None = None,
306
306
  ) -> torch.Tensor:
307
307
  batch, _, seq_len, _ = inputs.shape
308
308
  # Get the sequence length by looking up a timewise layer in the kv cache.
@@ -4,7 +4,7 @@
4
4
  # Copyright 2025 Datadog, Inc.
5
5
 
6
6
  from functools import reduce
7
- from typing import NamedTuple, Union
7
+ from typing import NamedTuple
8
8
 
9
9
  import numpy as np
10
10
  import pandas as pd
@@ -142,7 +142,7 @@ def replace_extreme_values(t: torch.Tensor, replacement: float = 0.0) -> torch.T
142
142
  return torch.where(is_extreme_value(t), torch.tensor(replacement, dtype=t.dtype, device=t.device), t)
143
143
 
144
144
 
145
- def freq_to_seconds(freq: Union[str, pd.offsets.BaseOffset]) -> float:
145
+ def freq_to_seconds(freq: str | pd.offsets.BaseOffset) -> float:
146
146
  # Modified from: https://github.com/DataDog/toto/blob/846d599f4b8d377db3088d5cd1a736d050cef5ac/toto/inference/gluonts_predictor.py#L58
147
147
  if isinstance(freq, str):
148
148
  freq = pd.tseries.frequencies.to_offset(freq)
@@ -4,7 +4,7 @@
4
4
  # Copyright 2025 Datadog, Inc.
5
5
 
6
6
  from dataclasses import dataclass
7
- from typing import Optional, Union, cast
7
+ from typing import cast
8
8
 
9
9
  import numpy as np
10
10
  import torch
@@ -24,9 +24,9 @@ from .dataset import (
24
24
  @dataclass(frozen=True)
25
25
  class Forecast:
26
26
  mean: torch.Tensor
27
- samples: Optional[torch.Tensor]
27
+ samples: torch.Tensor | None
28
28
 
29
- def quantile(self, q: Union[float, torch.Tensor]) -> torch.Tensor:
29
+ def quantile(self, q: float | torch.Tensor) -> torch.Tensor:
30
30
  """
31
31
  Compute the quantile of the forecast samples.
32
32
  """
@@ -88,7 +88,7 @@ class TotoForecaster:
88
88
  self,
89
89
  inputs: MaskedTimeseries,
90
90
  prediction_length: int,
91
- num_samples: Optional[int] = None,
91
+ num_samples: int | None = None,
92
92
  samples_per_batch: int = 10,
93
93
  use_kv_cache: bool = True,
94
94
  ) -> Forecast:
@@ -187,8 +187,8 @@ class TotoForecaster:
187
187
  prediction_length: int,
188
188
  timestamp_seconds: torch.Tensor,
189
189
  time_interval_seconds: torch.Tensor,
190
- input_padding_mask: Optional[torch.Tensor] = None,
191
- id_mask: Optional[torch.Tensor] = None,
190
+ input_padding_mask: torch.Tensor | None = None,
191
+ id_mask: torch.Tensor | None = None,
192
192
  use_kv_cache: bool = False,
193
193
  ) -> torch.Tensor:
194
194
  """
@@ -262,8 +262,8 @@ class TotoForecaster:
262
262
  num_samples: int,
263
263
  timestamp_seconds: torch.Tensor,
264
264
  time_interval_seconds: torch.Tensor,
265
- input_padding_mask: Optional[torch.Tensor] = None,
266
- id_mask: Optional[torch.Tensor] = None,
265
+ input_padding_mask: torch.Tensor | None = None,
266
+ id_mask: torch.Tensor | None = None,
267
267
  sampling_batch_size: int = 10,
268
268
  use_kv_cache: bool = False,
269
269
  ) -> torch.Tensor:
@@ -1,6 +1,6 @@
1
1
  import functools
2
2
  import time
3
- from typing import Any, Callable, Iterator, Optional, Union
3
+ from typing import Any, Callable, Iterator
4
4
 
5
5
  import numpy as np
6
6
  import torch
@@ -44,9 +44,9 @@ class TotoDataLoader:
44
44
  def __init__(
45
45
  self,
46
46
  dataset: TotoInferenceDataset,
47
- freq: Optional[str] = None,
47
+ freq: str | None = None,
48
48
  batch_size: int = 1,
49
- time_limit: Optional[Union[int, float]] = None,
49
+ time_limit: int | float | None = None,
50
50
  device: Any = None,
51
51
  ):
52
52
  self.device = torch.device(device)
@@ -60,7 +60,7 @@ class TotoDataLoader:
60
60
  self.freq: str = freq or dataset.freq or "h"
61
61
 
62
62
  @staticmethod
63
- def _get_timeout_callback(seconds: Optional[float]) -> Callable:
63
+ def _get_timeout_callback(seconds: float | None) -> Callable:
64
64
  start_time = time.monotonic()
65
65
 
66
66
  def callback() -> None: