autogluon.timeseries 1.4.1b20250906__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.timeseries might be problematic. Click here for more details.

Files changed (93) hide show
  1. autogluon/timeseries/configs/hyperparameter_presets.py +2 -2
  2. autogluon/timeseries/dataset/ts_dataframe.py +97 -86
  3. autogluon/timeseries/learner.py +68 -35
  4. autogluon/timeseries/metrics/__init__.py +4 -4
  5. autogluon/timeseries/metrics/abstract.py +8 -8
  6. autogluon/timeseries/metrics/point.py +9 -9
  7. autogluon/timeseries/metrics/quantile.py +5 -5
  8. autogluon/timeseries/metrics/utils.py +4 -4
  9. autogluon/timeseries/models/__init__.py +4 -1
  10. autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -39
  11. autogluon/timeseries/models/abstract/model_trial.py +2 -1
  12. autogluon/timeseries/models/abstract/tunable.py +8 -8
  13. autogluon/timeseries/models/autogluon_tabular/mlforecast.py +58 -62
  14. autogluon/timeseries/models/autogluon_tabular/per_step.py +26 -15
  15. autogluon/timeseries/models/autogluon_tabular/transforms.py +11 -9
  16. autogluon/timeseries/models/chronos/__init__.py +2 -1
  17. autogluon/timeseries/models/chronos/chronos2.py +361 -0
  18. autogluon/timeseries/models/chronos/model.py +125 -87
  19. autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +68 -36
  20. autogluon/timeseries/models/ensemble/__init__.py +34 -2
  21. autogluon/timeseries/models/ensemble/abstract.py +5 -42
  22. autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
  23. autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
  24. autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
  25. autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
  26. autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
  27. autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
  28. autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
  29. autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
  30. autogluon/timeseries/models/ensemble/{greedy.py → ensemble_selection.py} +41 -61
  31. autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
  32. autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
  33. autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
  34. autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +6 -16
  35. autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
  36. autogluon/timeseries/models/gluonts/abstract.py +25 -25
  37. autogluon/timeseries/models/gluonts/dataset.py +11 -11
  38. autogluon/timeseries/models/local/__init__.py +0 -7
  39. autogluon/timeseries/models/local/abstract_local_model.py +15 -18
  40. autogluon/timeseries/models/local/naive.py +2 -2
  41. autogluon/timeseries/models/local/npts.py +1 -1
  42. autogluon/timeseries/models/local/statsforecast.py +12 -12
  43. autogluon/timeseries/models/multi_window/multi_window_model.py +39 -24
  44. autogluon/timeseries/models/registry.py +3 -4
  45. autogluon/timeseries/models/toto/__init__.py +3 -0
  46. autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
  47. autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
  48. autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
  49. autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
  50. autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
  51. autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
  52. autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
  53. autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
  54. autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
  55. autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
  56. autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
  57. autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
  58. autogluon/timeseries/models/toto/dataloader.py +108 -0
  59. autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
  60. autogluon/timeseries/models/toto/model.py +236 -0
  61. autogluon/timeseries/predictor.py +301 -103
  62. autogluon/timeseries/regressor.py +27 -30
  63. autogluon/timeseries/splitter.py +3 -27
  64. autogluon/timeseries/trainer/ensemble_composer.py +439 -0
  65. autogluon/timeseries/trainer/model_set_builder.py +9 -9
  66. autogluon/timeseries/trainer/prediction_cache.py +16 -16
  67. autogluon/timeseries/trainer/trainer.py +300 -275
  68. autogluon/timeseries/trainer/utils.py +17 -0
  69. autogluon/timeseries/transforms/covariate_scaler.py +8 -8
  70. autogluon/timeseries/transforms/target_scaler.py +15 -15
  71. autogluon/timeseries/utils/constants.py +10 -0
  72. autogluon/timeseries/utils/datetime/lags.py +1 -3
  73. autogluon/timeseries/utils/datetime/seasonality.py +1 -3
  74. autogluon/timeseries/utils/features.py +18 -14
  75. autogluon/timeseries/utils/forecast.py +6 -7
  76. autogluon/timeseries/utils/timer.py +173 -0
  77. autogluon/timeseries/version.py +1 -1
  78. autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
  79. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +39 -22
  80. autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
  81. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
  82. autogluon/timeseries/evaluator.py +0 -6
  83. autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
  84. autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
  85. autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
  86. autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -580
  87. autogluon.timeseries-1.4.1b20250906-py3.9-nspkg.pth +0 -1
  88. autogluon.timeseries-1.4.1b20250906.dist-info/RECORD +0 -75
  89. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
  90. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
  91. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
  92. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
  93. {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
@@ -0,0 +1,70 @@
1
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
2
+ #
3
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
4
+ # Copyright 2025 Datadog, Inc.
5
+
6
+ from abc import ABC
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from gluonts.torch.distributions import AffineTransformed
11
+ from gluonts.torch.distributions.studentT import StudentT
12
+
13
+
14
class DistributionOutput(ABC, torch.nn.Module):
    """Common base type for the output heads below, which map embeddings to a torch distribution."""
16
+
17
+
18
class StudentTOutput(DistributionOutput):
    """
    Output head that turns embeddings into one Student's t distribution per position.

    Three ``Linear(embed_dim, 1)`` projections produce the degrees of freedom, location
    and scale; the trailing singleton axis is squeezed away, so the distribution's batch
    shape is ``inputs.shape[:-1]``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Registration order (df, loc_proj, scale_proj) is kept as-is.
        self.df = torch.nn.Linear(embed_dim, 1)
        self.loc_proj = torch.nn.Linear(embed_dim, 1)
        self.scale_proj = torch.nn.Linear(embed_dim, 1)

    def forward(self, inputs, loc=None, scale=None):
        """
        Build the predictive distribution for ``inputs`` (last axis of size ``embed_dim``).

        Args:
            inputs: embeddings whose last axis has size ``embed_dim``.
            loc, scale: optional affine parameters; only when BOTH are given is the
                Student's t wrapped in ``AffineTransformed(base, loc=loc, scale=scale)``.

        Returns:
            ``torch.distributions.StudentT`` (argument validation disabled), or the
            ``AffineTransformed`` wrapper around it.
        """
        tiny = torch.finfo(inputs.dtype).eps

        def _positive(raw):
            # softplus, floored at machine epsilon, with the singleton output axis removed
            return F.softplus(raw).clamp_min(tiny).squeeze(-1)

        # Offset by 2 so the degrees of freedom always exceed 2.
        degrees_of_freedom = 2.0 + _positive(self.df(inputs))
        center = self.loc_proj(inputs).squeeze(-1)
        spread = _positive(self.scale_proj(inputs))

        base_dist = torch.distributions.StudentT(degrees_of_freedom, center, spread, validate_args=False)  # type: ignore

        if loc is None or scale is None:
            return base_dist
        return AffineTransformed(
            base_dist,
            loc=loc,
            scale=scale,
        )
41
+
42
+
43
class MixtureOfStudentTsOutput(DistributionOutput):
    """
    Output head producing a mixture of ``k_components`` Student's t distributions per position.

    Four ``Linear(embed_dim, k_components)`` projections produce, per component, the
    degrees of freedom, location, scale and mixture logit. The distribution's batch shape
    is ``inputs.shape[:-1]``.
    """

    def __init__(
        self,
        embed_dim,
        k_components,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_components = k_components

        self.df = torch.nn.Linear(embed_dim, k_components)
        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)

    def forward(self, inputs, loc=None, scale=None):
        """
        Build the mixture distribution for ``inputs`` (last axis of size ``embed_dim``).

        Args:
            inputs: embeddings whose last axis has size ``embed_dim``.
            loc, scale: optional affine parameters. When BOTH are given, the mixture is
                wrapped in ``AffineTransformed(mixture, loc=loc, scale=scale)``, mirroring
                ``StudentTOutput``. When either is omitted, the raw mixture is returned.

        Returns:
            ``torch.distributions.MixtureSameFamily`` over Student's t components, or the
            ``AffineTransformed`` wrapper around it.
        """
        eps = torch.finfo(inputs.dtype).eps
        # Offset by 2 so every component's degrees of freedom exceed 2.
        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps)
        # Distinct names: the `loc`/`scale` arguments must not be overwritten by the
        # per-component parameters.
        component_loc = self.loc_proj(inputs)
        component_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps)
        probs = F.softmax(self.mixture_weights(inputs), dim=-1)

        components = StudentT(df, component_loc, component_scale)
        mixture_distribution = torch.distributions.Categorical(probs=probs)
        mixture = torch.distributions.MixtureSameFamily(
            mixture_distribution,
            components,
        )

        if loc is not None and scale is not None:
            return AffineTransformed(
                mixture,
                loc=loc,
                scale=scale,
            )
        return mixture
@@ -0,0 +1,136 @@
1
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
2
+ #
3
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
4
+ # Copyright 2025 Datadog, Inc.
5
+
6
+ from dataclasses import dataclass, field
7
+
8
+ import torch
9
+
10
+ from .attention import TimeWiseMultiheadAttention
11
+
12
# Short type aliases; KV annotates the (keys, values) pairs handled by KVCache.
K = torch.Tensor  # keys
V = torch.Tensor  # values
KV = tuple[K, V]  # (keys, values)
15
+
16
+
17
@dataclass
class KVCache:
    """
    Pre-allocated key/value buffers reused across the steps of multistep inference.

    Storage is reserved only for layers whose ``attention`` module is a
    ``TimeWiseMultiheadAttention``; other (space-wise) layers get no slot of their own.
    Every global layer index is looked up in ``_layer_cache_map``. Layers without a slot
    keep the default entry 0, so they alias slot 0.
    NOTE(review): nothing here prevents calling the cache with a space-wise layer index.

    Per-slot buffer layout (``rows = batch_size * num_variates``):
      * ``use_memory_efficient_attention=True``: ``(rows, max_seq_len, num_heads, head_dim)``
      * otherwise:                               ``(rows, num_heads, max_seq_len, head_dim)``
    """

    batch_size: int
    num_variates: int
    transformer_layers: list
    num_layers: int
    embed_dim: int
    num_heads: int
    max_seq_len: int
    device: torch.device = torch.device("cpu")
    dtype: torch.dtype = torch.float32
    use_memory_efficient_attention: bool = True

    _keys: torch.Tensor = field(init=False)
    _values: torch.Tensor = field(init=False)
    _current_idx: torch.Tensor = field(init=False)
    _layer_cache_map: torch.Tensor = field(init=False)

    def __post_init__(self):
        """
        Allocate the key/value buffers and build the layer -> slot lookup table.

        Raises:
            AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        head_dim = self.embed_dim // self.num_heads

        time_layer_indices = []
        for layer_idx in range(self.num_layers):
            attention = self.transformer_layers[layer_idx].attention
            if isinstance(attention, TimeWiseMultiheadAttention):
                time_layer_indices.append(layer_idx)

        # At least one slot is always allocated, even when no layer is time-wise.
        slot_count = len(time_layer_indices) or 1
        rows = self.batch_size * self.num_variates
        if self.use_memory_efficient_attention:
            inner_dims = (self.max_seq_len, self.num_heads)
        else:
            inner_dims = (self.num_heads, self.max_seq_len)
        buffer_shape = (slot_count, rows, *inner_dims, head_dim)

        self._keys = torch.zeros(buffer_shape, device=self.device, dtype=self.dtype)
        self._values = torch.zeros_like(self._keys)
        self._current_idx = torch.zeros(slot_count, device=self.device, dtype=torch.int)
        # Global layer index -> slot index; unlisted layers keep the default 0.
        self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device)
        for slot, layer_idx in enumerate(time_layer_indices):
            self._layer_cache_map[layer_idx] = slot

    def _slot_for(self, layer_idx: int) -> int:
        """Slot index assigned to global layer ``layer_idx``."""
        return int(self._layer_cache_map[layer_idx].item())

    def _seq_index(self, slot, start, stop):
        """Index tuple selecting sequence positions ``[start, stop)`` of ``slot`` in the active layout."""
        positions = slice(start, stop)
        if self.use_memory_efficient_attention:
            # (slot, rows, seq, heads, head_dim)
            return (slot, slice(None), positions)
        # (slot, rows, heads, seq, head_dim)
        return (slot, slice(None), slice(None), positions)

    def __getitem__(self, layer_idx: int) -> KV:
        """Return views of the filled part of the (keys, values) buffers for ``layer_idx``."""
        slot = self._slot_for(layer_idx)
        filled = int(self._current_idx[slot].item())
        index = self._seq_index(slot, None, filled)
        return self._keys[index], self._values[index]

    def current_len(self, cache_idx: int) -> int:
        """Number of positions written so far into slot ``cache_idx`` (0 if there are no slots)."""
        if self._current_idx.numel() == 0:
            return 0
        return int(self._current_idx[cache_idx].item())

    def seq_len(self, layer_idx: int) -> int:
        """Number of positions written so far for global layer ``layer_idx``."""
        return self.current_len(self._slot_for(layer_idx))

    def append(self, layer_idx: int, kv: KV):
        """
        Write ``kv`` into the slot of ``layer_idx`` right after the already-filled positions.

        Raises:
            AssertionError: on a shape mismatch with the configured layout, or if the write
                would run past ``max_seq_len``.
        """
        slot = self._slot_for(layer_idx)
        keys, values = kv

        assert keys.shape == values.shape, "keys and values must have the same shape"
        assert keys.shape[0] == self.batch_size * self.num_variates, (
            "keys and values must have batch_size * num_variates as their first dimension"
        )
        if self.use_memory_efficient_attention:
            seq_axis = 1
            assert keys.shape[2] == self.num_heads, "keys and values must have num_heads as their third dimension"
        else:
            seq_axis = 2
            assert keys.shape[1] == self.num_heads, "keys and values must have num_heads as their second dimension"
        assert keys.shape[3] == self.embed_dim // self.num_heads, (
            "keys and values must have head_dim as their fourth dimension"
        )

        start_idx = self._current_idx[slot]
        end_idx = start_idx + keys.shape[seq_axis]
        assert end_idx <= self.max_seq_len, (
            f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"
        )

        index = self._seq_index(slot, start_idx, end_idx)
        self._keys[index] = keys
        self._values[index] = values

        self._current_idx[slot] = end_idx

    def reset(self):
        """Zero both buffers and rewind every slot's fill counter to 0."""
        for buffer in (self._keys, self._values, self._current_idx):
            buffer.zero_()
@@ -0,0 +1,89 @@
1
+ # Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
2
+ #
3
+ # This product includes software developed at Datadog (https://www.datadoghq.com/)
4
+ # Copyright 2025 Datadog, Inc.
5
+
6
+
7
+ import torch
8
+ from einops import rearrange
9
+
10
+ from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb, default
11
+
12
+
13
class TimeAwareRotaryEmbedding(RotaryEmbedding):
    """
    A variant of the rotary position embedding that (optionally) uses the time index
    to compute the sinusoidal and cosine embeddings. This is useful for
    time series data, where the time index is the most important positional
    information.

    Differences from the base class:
      * ``freqs`` is held as a non-persistent buffer instead of a parameter.
      * ``rotate_queries_and_keys`` accepts explicit positions (``seq_pos``), which may
        carry leading batch dimensions, i.e. one position vector per series.
      * Nothing is cached, because positions are computed per call.
    """

    def __init__(self, *args, **kwargs):
        """
        Build the base embedding, then re-register ``freqs`` as a non-persistent buffer.

        A buffer, rather than a parameter, is needed for sharding with FSDP.
        """
        super().__init__(*args, **kwargs)
        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
            # Extract the underlying Tensor
            freqs_data = self.freqs.data

            # Remove `freqs` from the module's parameters
            self._parameters.pop("freqs")

            # Register as non-persistent buffer (excluded from state_dict)
            self.register_buffer("freqs", freqs_data, persistent=False)

    def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: int | None = None,
        seq_pos: torch.Tensor | None = None,
        seq_pos_offset: int = 0,
    ):
        """
        Apply xPos-scaled rotary embeddings to ``q`` and ``k``.

        Same as the base-class method, except that the sequence positions can be
        overridden with ``seq_pos`` and the position encodings are never cached.

        Args:
            q, k: query and key tensors; their sequence axis is ``seq_dim``.
            seq_dim: sequence axis, defaulting to ``self.default_seq_dim``. ``-3`` selects
                the xformers-style ``(..., seq, heads, head_dim)`` layout, for which
                freqs/scales are broadcast over the heads axis (``q.shape[-2]``).
            seq_pos: explicit positions. When omitted, the base class's ``get_seq_pos``
                positions for ``q.shape[seq_dim]`` are used. May have leading batch dims.
            seq_pos_offset: added to every position.

        Returns:
            ``(rotated_q, rotated_k)``, cast back to the dtypes of ``q`` and ``k``.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # Only build default positions when the caller did not supply any.
        seq = seq_pos if seq_pos is not None else self.get_seq_pos(seq_len, dtype=dtype, device=device)
        seq = seq + seq_pos_offset  # type: ignore

        # No seq_len is passed, so the base-class cache is bypassed.
        freqs = self.forward(seq)

        scale = self.get_scale(seq).to(dtype)

        # used for xformers: (..., seq, heads, head_dim)
        if seq_dim == -3:
            num_heads = q.shape[-2]
            # Insert a heads axis right before the feature axis. For 2-D (seq, dim) freqs
            # this yields (seq, heads, dim); leading batch dims of per-series positions
            # are preserved instead of breaking the expand.
            freqs = freqs.unsqueeze(-2).expand(*freqs.shape[:-1], num_heads, freqs.shape[-1])
            scale = scale.unsqueeze(-2).expand(*scale.shape[:-1], num_heads, scale.shape[-1])

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)  # type: ignore
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)  # type: ignore

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: torch.Tensor, seq_len: int | None = None, offset=0):
        """
        xPos scale factors for positions ``t``, which may have leading batch dimensions.

        Adapted from the base class to handle ``t`` with more than one dim, e.g. a
        different position vector per time series. ``seq_len`` and ``offset`` are accepted
        for signature compatibility but not used (no caching).

        Returns:
            a tensor of shape ``(*t.shape, 2 * len(self.scale))``.
        """
        assert self.use_xpos

        # `//` binds tighter than `-`: this is t - (max // 2), centring on half the
        # largest position of each row.
        power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base

        scale = self.scale ** rearrange(power, "... n -> ... n 1")  # type: ignore
        # NOTE(review): this concatenates ([s0..sk, s0..sk]) whereas the base class
        # interleaves ([s0, s0, s1, s1, ...]). Pretrained weights presumably depend on
        # this layout, so confirm before changing it.
        scale = torch.cat((scale, scale), dim=-1)

        return scale
@@ -0,0 +1,342 @@
1
+ # Source: https://github.com/lucidrains/rotary-embedding-torch
2
+ # MIT License
3
+ #
4
+ # Copyright (c) 2021 Phil Wang
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ from __future__ import annotations
25
+
26
+ from math import pi
27
+ from typing import Literal
28
+
29
+ import torch
30
+ from einops import rearrange, repeat
31
+ from torch import Tensor, broadcast_tensors, einsum, is_tensor, nn, tensor
32
+ from torch.amp import autocast
33
+ from torch.nn import Module
34
+
35
+ # helper functions
36
+
37
+
38
def exists(val):
    """Report whether ``val`` was provided, i.e. is not ``None`` (falsy values such as 0 count)."""
    missing = val is None
    return not missing
40
+
41
+
42
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``.

    Note that ``d`` is evaluated by the caller regardless of ``val``.
    """
    if exists(val):
        return val
    return d
44
+
45
+
46
def slice_at_dim(t, dim_slice: slice, *, dim):
    """
    Index ``t`` with ``dim_slice`` along axis ``dim``, taking every element of the other axes.

    A negative ``dim`` counts from the end, as in normal indexing.
    """
    axis = dim + t.ndim if dim < 0 else dim
    selectors = [slice(None) for _ in range(t.ndim)]
    selectors[axis] = dim_slice
    return t[tuple(selectors)]
51
+
52
+
53
+ # rotary embedding helper functions
54
+
55
+
56
def rotate_half(x):
    """
    Rotate adjacent channel pairs of the last axis by 90 degrees.

    The last axis is read as pairs ``(x[2i], x[2i + 1])``, and each pair is mapped to
    ``(-x[2i + 1], x[2i])``. The last axis length must therefore be even.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    swapped = torch.stack((-second, first), dim=-1)
    return rearrange(swapped, "... d r -> ... (d r)")
61
+
62
+
63
@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None):
    """
    Rotate channels ``[start_index, start_index + freqs.shape[-1])`` of ``t`` by the angles in ``freqs``.

    Runs with CUDA autocast disabled. If ``freqs_seq_dim`` is not given, it defaults to 0
    when ``freqs`` is 2-D or ``t`` is 3-D. Whenever it is set (or ``t`` is 3-D), ``freqs``
    is trimmed to its last ``t.shape[seq_dim]`` positions along that axis. Channels
    outside the rotated window pass through unchanged. Both rotated terms are multiplied
    by ``scale``.

    Returns:
        a new tensor with ``t``'s dtype; ``t`` itself is not modified.
    """
    original_dtype = t.dtype

    if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3):
        freqs_seq_dim = 0

    if t.ndim == 3 or freqs_seq_dim is not None:
        # Align freqs with the trailing positions of t's sequence axis.
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    )

    untouched_head = t[..., :start_index]
    window = t[..., start_index:end_index]
    untouched_tail = t[..., end_index:]

    cos_term = window * freqs.cos() * scale
    sin_term = rotate_half(window) * freqs.sin() * scale

    rotated = torch.cat((untouched_head, cos_term + sin_term, untouched_tail), dim=-1)
    return rotated.type(original_dtype)
93
+
94
+
95
+ # learned rotation helpers
96
+
97
+
98
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """
    Apply rotation angles ``rotations`` to ``t``, starting at channel ``start_index``.

    When ``freq_ranges`` is given, the angles are first multiplied by each frequency in it
    (outer product on a new last axis), and the last two axes are then flattened together.
    Each angle is duplicated so it covers one (even, odd) channel pair before delegating
    to ``apply_rotary_emb``.
    """
    if freq_ranges is not None:
        per_freq = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(per_freq, "... r f -> ... (r f)")

    pairwise = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(pairwise, t, start_index=start_index)
105
+
106
+
107
+ # classes
108
+
109
+
110
class RotaryEmbedding(Module):
    """
    Rotary position embedding with optional xPos (length-extrapolatable) scaling.

    ``forward`` maps positions to rotation angles. Each frequency is repeated twice so it
    lines up with the (even, odd) channel pairs rotated by ``rotate_half``. The
    ``rotate_*`` methods apply those angles to query/key tensors.

    Args:
        dim: channel count the frequency schedule is built for; also the row width of the
            ``cached_freqs`` / ``cached_scales`` buffers.
        custom_freqs: explicit frequency tensor, overriding ``freqs_for``.
        freqs_for: ``"lang"`` (theta-based geometric schedule), ``"pixel"`` (linspace up to
            ``max_freq / 2``, times pi) or ``"constant"`` (``num_freqs`` ones).
        theta: base of the ``"lang"`` schedule, rescaled by ``theta_rescale_factor``.
        max_freq: upper end of the ``"pixel"`` schedule.
        num_freqs: number of frequencies for ``"constant"``.
        learned_freq: make ``freqs`` trainable; this also disables caching in ``forward``.
        use_xpos: enable xPos scaling, which is required by ``rotate_queries_and_keys``.
        xpos_scale_base: divisor of the xPos exponent.
        interpolate_factor: positions are divided by this; must be >= 1.
        theta_rescale_factor: NTK-style rescale of ``theta``.
        seq_before_head_dim: default sequence axis is -3 if True, else -2.
        cache_if_possible: allow caching of computed freqs/scales.
        cache_max_seq_len: number of rows in the caches.

    Raises:
        ValueError: if ``freqs_for`` is unknown and no ``custom_freqs`` is given.
    """

    def __init__(
        self,
        dim,
        custom_freqs: Tensor | None = None,
        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
        cache_max_seq_len=8192,
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            # Fail fast instead of hitting an unbound `freqs` below.
            raise ValueError(f"unknown freqs_for {freqs_for!r}; expected 'lang', 'pixel' or 'constant'")

        self.cache_if_possible = cache_if_possible
        self.cache_max_seq_len = cache_max_seq_len

        self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.cached_freqs_seq_len = 0

        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy for device

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # add apply_rotary_emb as static method; set before the xpos early return so
        # that every instance exposes it, not only xpos-enabled ones

        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

        # xpos

        self.use_xpos = use_xpos

        if not use_xpos:
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base

        self.register_buffer("scale", scale, persistent=False)
        self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.cached_scales_seq_len = 0

    @property
    def device(self):
        """Device of the module's buffers (tracked through the ``dummy`` buffer)."""
        return self.dummy.device

    def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
        """Positions ``offset .. offset + seq_len - 1`` divided by ``interpolate_factor``."""
        device = default(device, self.device)
        dtype = default(dtype, self.cached_freqs.dtype)

        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
        """
        Rotate a single tensor ``t`` using positions starting at ``offset``.

        With xPos enabled, an explicit ``scale`` is required; otherwise use
        ``rotate_queries_and_keys``.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert not self.use_xpos or exists(scale), (
            "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
        )

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)

        freqs = self.forward(seq, seq_len=seq_len, offset=offset)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """
        Rotate queries ``q`` that correspond to the last ``q_len`` positions of keys ``k``.

        Raises:
            AssertionError: if ``q`` is longer than ``k`` along ``seq_dim``.
        """
        dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        q_scale = k_scale = 1.0

        if self.use_xpos:
            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)

            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
            k_scale = self.get_scale(seq).type(dtype)

        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        """Rotate ``q`` and ``k`` together with xPos scaling (``scale`` for q, ``1/scale`` for k)."""
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
        """
        xPos scale factors for the 1-D positions ``t``, interleaved per channel pair.

        When ``seq_len`` is given and the range fits in the cache, results computed with
        ``offset == 0`` are stored in ``cached_scales``, and later covered ranges are
        served from it.
        """
        assert self.use_xpos

        should_cache = (
            self.cache_if_possible
            and exists(seq_len)
            and (offset + seq_len) <= self.cache_max_seq_len
            # cache rows are `dim` wide while the repeated scales are 2 * ceil(dim / 2)
            # wide; skip caching (instead of failing on the write) when they differ
            and 2 * self.scale.shape[-1] == self.cached_scales.shape[-1]
        )

        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len:
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = repeat(scale, "n d -> n (d r)", r=2)

        if should_cache and offset == 0:
            self.cached_scales[:seq_len] = scale.detach()
            self.cached_scales_seq_len = seq_len

        return scale

    def get_axial_freqs(self, *dims, offsets: (tuple[int | float, ...] | Tensor | None) = None):
        """
        Frequencies for an N-D grid with axis sizes ``dims``.

        Each axis's frequencies are broadcast across the other axes, and the results are
        concatenated on the last axis. ``offsets``, if given, shifts each axis's positions
        and must have one entry per axis.
        """
        Colon = slice(None)
        all_freqs = []

        # handle offset

        if exists(offsets):
            if not is_tensor(offsets):
                offsets = tensor(offsets)

            assert len(offsets) == len(dims)

        # get frequencies for each axis

        for ind, dim in enumerate(dims):
            offset = 0
            if exists(offsets):
                offset = offsets[ind]

            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            pos = pos + offset

            freqs = self.forward(pos, seq_len=dim)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        # concat all freqs

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
        """
        Rotation angles for positions ``t``: ``t`` times each frequency, each repeated twice.

        Caching is used only when ``seq_len`` is given, freqs are not learned, ``freqs_for``
        is not ``"pixel"``, the range fits in the cache, and the output width matches the
        cache width. Only ``offset == 0`` results are written to the cache.
        """
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for != "pixel"
            and (offset + seq_len) <= self.cache_max_seq_len
            # cache rows are `dim` wide; repeated freqs are 2 * len(freqs) wide, which
            # differs for "constant"/custom freqs or odd `dim`, so the write would fail
            and 2 * self.freqs.shape[-1] == self.cached_freqs.shape[-1]
        )

        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len:
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
            self.cached_freqs_seq_len = seq_len

        return freqs