autogluon.timeseries 1.4.1b20250906__py3-none-any.whl → 1.4.1b20251210__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
This release of autogluon.timeseries has been flagged as potentially problematic.
- autogluon/timeseries/configs/hyperparameter_presets.py +2 -2
- autogluon/timeseries/dataset/ts_dataframe.py +97 -86
- autogluon/timeseries/learner.py +68 -35
- autogluon/timeseries/metrics/__init__.py +4 -4
- autogluon/timeseries/metrics/abstract.py +8 -8
- autogluon/timeseries/metrics/point.py +9 -9
- autogluon/timeseries/metrics/quantile.py +5 -5
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +4 -1
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -39
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +8 -8
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +58 -62
- autogluon/timeseries/models/autogluon_tabular/per_step.py +26 -15
- autogluon/timeseries/models/autogluon_tabular/transforms.py +11 -9
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +361 -0
- autogluon/timeseries/models/chronos/model.py +125 -87
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +68 -36
- autogluon/timeseries/models/ensemble/__init__.py +34 -2
- autogluon/timeseries/models/ensemble/abstract.py +5 -42
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +236 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +73 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +167 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/{greedy.py → ensemble_selection.py} +41 -61
- autogluon/timeseries/models/ensemble/per_item_greedy.py +162 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +40 -0
- autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +6 -16
- autogluon/timeseries/models/ensemble/weighted/greedy.py +57 -0
- autogluon/timeseries/models/gluonts/abstract.py +25 -25
- autogluon/timeseries/models/gluonts/dataset.py +11 -11
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +15 -18
- autogluon/timeseries/models/local/naive.py +2 -2
- autogluon/timeseries/models/local/npts.py +1 -1
- autogluon/timeseries/models/local/statsforecast.py +12 -12
- autogluon/timeseries/models/multi_window/multi_window_model.py +39 -24
- autogluon/timeseries/models/registry.py +3 -4
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +118 -0
- autogluon/timeseries/models/toto/model.py +236 -0
- autogluon/timeseries/predictor.py +301 -103
- autogluon/timeseries/regressor.py +27 -30
- autogluon/timeseries/splitter.py +3 -27
- autogluon/timeseries/trainer/ensemble_composer.py +439 -0
- autogluon/timeseries/trainer/model_set_builder.py +9 -9
- autogluon/timeseries/trainer/prediction_cache.py +16 -16
- autogluon/timeseries/trainer/trainer.py +300 -275
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/covariate_scaler.py +8 -8
- autogluon/timeseries/transforms/target_scaler.py +15 -15
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +1 -3
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/features.py +18 -14
- autogluon/timeseries/utils/forecast.py +6 -7
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.4.1b20251210-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/METADATA +39 -22
- autogluon_timeseries-1.4.1b20251210.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/WHEEL +1 -1
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -580
- autogluon.timeseries-1.4.1b20250906-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.4.1b20250906.dist-info/RECORD +0 -75
- {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.4.1b20250906.dist-info → autogluon_timeseries-1.4.1b20251210.dist-info}/zip-safe +0 -0
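At a glance, the headline changes are a new bundled Toto forecasting model (the autogluon/timeseries/models/toto/ package), a new Chronos 2 model, a reworked ensemble subpackage, and a new trainer-side ensemble composer. As a hedged sketch of how a newly registered model family is typically enabled through the public API (the "Toto" hyperparameter key is an assumption inferred from the new models/toto/ package and the registry.py change, not confirmed by this diff):

# Hedged sketch, not from the diff: enabling one model family via TimeSeriesPredictor.
# The "Toto" key is assumed from the new models/toto/ package; verify against the docs.
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

train_data = TimeSeriesDataFrame.from_path("train.csv")  # hypothetical input file
predictor = TimeSeriesPredictor(prediction_length=48)
predictor.fit(train_data, hyperparameters={"Toto": {}})  # assumed model key
predictions = predictor.predict(train_data)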
autogluon/timeseries/models/toto/_internal/backbone/distribution.py (new file, +70)
@@ -0,0 +1,70 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+from abc import ABC
+
+import torch
+import torch.nn.functional as F
+from gluonts.torch.distributions import AffineTransformed
+from gluonts.torch.distributions.studentT import StudentT
+
+
+class DistributionOutput(ABC, torch.nn.Module):
+    pass
+
+
+class StudentTOutput(DistributionOutput):
+    def __init__(self, embed_dim):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.df = torch.nn.Linear(embed_dim, 1)
+        self.loc_proj = torch.nn.Linear(embed_dim, 1)
+        self.scale_proj = torch.nn.Linear(embed_dim, 1)
+
+    def forward(self, inputs, loc=None, scale=None):
+        eps = torch.finfo(inputs.dtype).eps
+        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps).squeeze(-1)
+        base_loc = self.loc_proj(inputs).squeeze(-1)
+        base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1)
+
+        base_dist = torch.distributions.StudentT(df, base_loc, base_scale, validate_args=False)  # type: ignore
+
+        if loc is not None and scale is not None:
+            return AffineTransformed(
+                base_dist,
+                loc=loc,
+                scale=scale,
+            )
+        return base_dist
+
+
+class MixtureOfStudentTsOutput(DistributionOutput):
+    def __init__(
+        self,
+        embed_dim,
+        k_components,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.k_components = k_components
+
+        self.df = torch.nn.Linear(embed_dim, k_components)
+        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
+        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
+        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)
+
+    def forward(self, inputs, loc=None, scale=None):
+        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
+        loc = self.loc_proj(inputs)
+        scale = F.softplus(self.scale_proj(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
+        logits = self.mixture_weights(inputs)
+        probs = F.softmax(logits, dim=-1)
+        components = StudentT(df, loc, scale)
+        mixture_distribution = torch.distributions.Categorical(probs=probs)
+
+        return torch.distributions.MixtureSameFamily(
+            mixture_distribution,
+            components,
+        )
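The two heads above project a transformer embedding into a predictive distribution per time step: a single Student's t (optionally rebased via AffineTransformed when a loc/scale from an input scaler is supplied) or a k-component Student's t mixture. A minimal, hedged sketch of exercising StudentTOutput on its own (shapes are illustrative assumptions, not taken from the diff):

# Minimal sketch, assuming the module is importable from the new Toto package listed above.
import torch

from autogluon.timeseries.models.toto._internal.backbone.distribution import StudentTOutput

head = StudentTOutput(embed_dim=256)
embeddings = torch.randn(4, 128, 256)        # (batch, time, embed_dim), illustrative shapes
dist = head(embeddings)                      # StudentT with df/loc/scale of shape (4, 128)
samples = dist.rsample((100,))               # 100 Monte Carlo sample paths
quantiles = samples.quantile(torch.tensor([0.1, 0.5, 0.9]), dim=0)
print(quantiles.shape)                       # torch.Size([3, 4, 128])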
autogluon/timeseries/models/toto/_internal/backbone/kvcache.py (new file, +136)
@@ -0,0 +1,136 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+from dataclasses import dataclass, field
+
+import torch
+
+from .attention import TimeWiseMultiheadAttention
+
+K = torch.Tensor
+V = torch.Tensor
+KV = tuple[torch.Tensor, torch.Tensor]
+
+
+@dataclass
+class KVCache:
+    """
+    Key/Value cache for storing intermediate attention values
+    during multistep inference. Only stores KV cache for timewise layers, skipping spacewise layers.
+    """
+
+    batch_size: int
+    num_variates: int
+    transformer_layers: list
+    num_layers: int
+    embed_dim: int
+    num_heads: int
+    max_seq_len: int
+    device: torch.device = torch.device("cpu")
+    dtype: torch.dtype = torch.float32
+    use_memory_efficient_attention: bool = True
+
+    _keys: torch.Tensor = field(init=False)
+    _values: torch.Tensor = field(init=False)
+    _current_idx: torch.Tensor = field(init=False)
+    _layer_cache_map: torch.Tensor = field(init=False)
+
+    def __post_init__(self):
+        """
+        - Determine timewise vs. spacewise layers and allocate KV only for timewise.
+        - Create a fast tensor-based mapping from global layer_idx -> timewise layer_idx.
+        """
+        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
+        head_dim = self.embed_dim // self.num_heads
+
+        # Compute which layers are timewise
+        time_layer_indices = [
+            i
+            for i in range(self.num_layers)
+            if isinstance(self.transformer_layers[i].attention, TimeWiseMultiheadAttention)
+        ]
+
+        time_layer_count = max(1, len(time_layer_indices))  # handle edge case for no timewise layers
+        # Allocate for only the timewise layers
+        if self.use_memory_efficient_attention:
+            shape = (
+                time_layer_count,
+                self.batch_size * self.num_variates,
+                self.max_seq_len,
+                self.num_heads,
+                head_dim,
+            )
+        else:
+            shape = (
+                time_layer_count,
+                self.batch_size * self.num_variates,
+                self.num_heads,
+                self.max_seq_len,
+                head_dim,
+            )
+        self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype)
+        self._values = torch.zeros_like(self._keys)
+        self._current_idx = torch.zeros(time_layer_count, device=self.device, dtype=torch.int)
+        # Build a tensor lookup for global -> timewise layer index (default to 0)
+        self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device)
+        for cache_idx, layer_idx in enumerate(time_layer_indices):
+            self._layer_cache_map[layer_idx] = int(cache_idx)  # Assign correct indices
+
+    def __getitem__(self, layer_idx: int) -> KV:
+        cache_idx = int(self._layer_cache_map[layer_idx].item())
+        end_idx = int(self._current_idx[cache_idx].item())
+
+        if self.use_memory_efficient_attention:
+            return self._keys[cache_idx, :, :end_idx, :, :], self._values[cache_idx, :, :end_idx, :, :]
+        else:
+            return self._keys[cache_idx, :, :, :end_idx, :], self._values[cache_idx, :, :, :end_idx, :]
+
+    def current_len(self, cache_idx: int) -> int:
+        return int(self._current_idx[cache_idx].item()) if self._current_idx.numel() > 0 else 0
+
+    def seq_len(self, layer_idx: int) -> int:
+        cache_idx = int(self._layer_cache_map[layer_idx].item())
+        return self.current_len(cache_idx)
+
+    def append(self, layer_idx: int, kv: KV):
+        cache_idx = int(self._layer_cache_map[layer_idx].item())
+        keys, values = kv
+
+        # Validate dimensions
+        assert keys.shape == values.shape, "keys and values must have the same shape"
+        assert keys.shape[0] == self.batch_size * self.num_variates, (
+            "keys and values must have batch_size * num_variates as their first dimension"
+        )
+
+        if self.use_memory_efficient_attention:
+            assert keys.shape[2] == self.num_heads, "keys and values must have num_heads as their third dimension"
+        else:
+            assert keys.shape[1] == self.num_heads, "keys and values must have num_heads as their second dimension"
+        assert keys.shape[3] == self.embed_dim // self.num_heads, (
+            "keys and values must have head_dim as their fourth dimension"
+        )
+
+        start_idx = self._current_idx[cache_idx]
+        if self.use_memory_efficient_attention:
+            end_idx = start_idx + keys.shape[1]
+        else:
+            end_idx = start_idx + keys.shape[2]
+        assert end_idx <= self.max_seq_len, (
+            f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"
+        )
+
+        if self.use_memory_efficient_attention:
+            self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys
+            self._values[cache_idx, :, start_idx:end_idx, :, :] = values
+        else:
+            self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys
+            self._values[cache_idx, :, :, start_idx:end_idx, :] = values
+
+        self._current_idx[cache_idx] = end_idx
+
+    def reset(self):
+        self._keys.zero_()
+        self._values.zero_()
+        self._current_idx.zero_()
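KVCache above only allocates slots for timewise attention layers and tracks a per-slot write position. A hedged sketch of the append/read shape contract (the stand-in layers below are purely illustrative; in the real model the list comes from the Toto backbone's transformer layers):

# Hedged sketch of the append/read contract. The isinstance check is satisfied with
# uninitialized stand-ins created via __new__, only to illustrate the slot mapping.
import types

import torch

from autogluon.timeseries.models.toto._internal.backbone.attention import TimeWiseMultiheadAttention
from autogluon.timeseries.models.toto._internal.backbone.kvcache import KVCache

def stub_layer(timewise: bool) -> types.SimpleNamespace:
    attn = TimeWiseMultiheadAttention.__new__(TimeWiseMultiheadAttention) if timewise else object()
    return types.SimpleNamespace(attention=attn)

layers = [stub_layer(i % 2 == 0) for i in range(4)]  # alternate timewise / spacewise
cache = KVCache(
    batch_size=2, num_variates=3, transformer_layers=layers, num_layers=4,
    embed_dim=64, num_heads=8, max_seq_len=32,
)
# memory-efficient layout: (batch * variates, new_steps, num_heads, head_dim)
kv = torch.randn(2 * 3, 5, 8, 8)
cache.append(0, (kv, kv))                            # layer 0 is timewise -> cache slot 0
keys, values = cache[0]
print(keys.shape)                                    # torch.Size([6, 5, 8, 8])
print(cache.seq_len(0), cache.seq_len(2))            # 5 0: layers 0 and 2 use separate slots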
autogluon/timeseries/models/toto/_internal/backbone/rope.py (new file, +89)
@@ -0,0 +1,89 @@
+# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
+#
+# This product includes software developed at Datadog (https://www.datadoghq.com/)
+# Copyright 2025 Datadog, Inc.
+
+
+import torch
+from einops import rearrange
+
+from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb, default
+
+
+class TimeAwareRotaryEmbedding(RotaryEmbedding):
+    """
+    A variant of the rotary position embedding that (optionally) uses the time index
+    to compute the sinusoidal and cosine embeddings. This is useful for
+    time series data, where the time index is the most important positional
+    information.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # If the parent stored `freqs` as a Parameter, remove it and register as a buffer
+        # Register buffer is needed for sharding with FSDP
+        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
+            # Extract the underlying Tensor
+            freqs_data = self.freqs.data
+
+            # Remove `freqs` from the module's parameters
+            self._parameters.pop("freqs")
+
+            # Register as non-persistent buffer
+            self.register_buffer("freqs", freqs_data, persistent=False)
+
+    def rotate_queries_and_keys(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_dim: int | None = None,
+        seq_pos: torch.Tensor | None = None,
+        seq_pos_offset: int = 0,
+    ):
+        """
+        This method is the same as the one on the base class, except it allows you to override
+        the sequence position tensor with a custom one. It also removes the ability
+        to cache the position encodings, since we have to compute them dynamically
+        based on the timesteps in the input data.
+        """
+        if seq_dim is None:
+            seq_dim = self.default_seq_dim
+
+        assert self.use_xpos
+        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+
+        seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device))
+        seq = seq + seq_pos_offset  # type: ignore
+
+        freqs = self.forward(seq)
+
+        scale = self.get_scale(seq).to(dtype)
+
+        # used for xformers
+        if seq_dim == -3:
+            num_heads = q.shape[-2]
+            freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
+            scale = scale.unsqueeze(1).expand(-1, num_heads, -1)
+
+        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)  # type: ignore
+        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)  # type: ignore
+
+        rotated_q = rotated_q.type(q.dtype)
+        rotated_k = rotated_k.type(k.dtype)
+
+        return rotated_q, rotated_k
+
+    def get_scale(self, t: torch.Tensor, seq_len: int | None = None, offset=0):
+        """
+        This method is adapted closely from the base class, but it knows how to handle
+        when `t` has more than 1 dim (as is the case when we're using time-aware RoPE, and have a different
+        sequence position vector for each time series).
+        """
+        assert self.use_xpos
+
+        power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base
+
+        scale = self.scale ** rearrange(power, "... n -> ... n 1")  # type: ignore
+        scale = torch.cat((scale, scale), dim=-1)
+
+        return scale
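The subclass above feeds an explicit per-series position tensor into the xpos-style rotation instead of an implicit 0..seq_len-1 range. A hedged sketch with illustrative shapes (the dim value here is the per-head feature size; the backbone configures this internally):

# Hedged sketch, illustrative shapes only.
import torch

from autogluon.timeseries.models.toto._internal.backbone.rope import TimeAwareRotaryEmbedding

rope = TimeAwareRotaryEmbedding(dim=16, use_xpos=True, xpos_scale_base=512)
q = torch.randn(6, 4, 32, 16)                        # (batch * variates, heads, seq, head_dim)
k = torch.randn(6, 4, 32, 16)
# One position vector per attention row, e.g. real timestep indices (possibly with gaps).
seq_pos = torch.arange(32).float().expand(6, 4, 32)
q_rot, k_rot = rope.rotate_queries_and_keys(q, k, seq_pos=seq_pos)
print(q_rot.shape, k_rot.shape)                      # both torch.Size([6, 4, 32, 16])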
autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py (new file, +342)
@@ -0,0 +1,342 @@
+# Source: https://github.com/lucidrains/rotary-embedding-torch
+# MIT License
+#
+# Copyright (c) 2021 Phil Wang
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+from __future__ import annotations
+
+from math import pi
+from typing import Literal
+
+import torch
+from einops import rearrange, repeat
+from torch import Tensor, broadcast_tensors, einsum, is_tensor, nn, tensor
+from torch.amp import autocast
+from torch.nn import Module
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+def slice_at_dim(t, dim_slice: slice, *, dim):
+    dim += t.ndim if dim < 0 else 0
+    colons = [slice(None)] * t.ndim
+    colons[dim] = dim_slice
+    return t[tuple(colons)]
+
+
+# rotary embedding helper functions
+
+
+def rotate_half(x):
+    x = rearrange(x, "... (d r) -> ... d r", r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return rearrange(x, "... d r -> ... (d r)")
+
+
+@autocast("cuda", enabled=False)
+def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None):
+    dtype = t.dtype
+
+    if not exists(freqs_seq_dim):
+        if freqs.ndim == 2 or t.ndim == 3:
+            freqs_seq_dim = 0
+
+    if t.ndim == 3 or exists(freqs_seq_dim):
+        seq_len = t.shape[seq_dim]
+        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)
+
+    rot_dim = freqs.shape[-1]
+    end_index = start_index + rot_dim
+
+    assert rot_dim <= t.shape[-1], (
+        f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+    )
+
+    # Split t into three parts: left, middle (to be transformed), and right
+    t_left = t[..., :start_index]
+    t_middle = t[..., start_index:end_index]
+    t_right = t[..., end_index:]
+
+    # Apply rotary embeddings without modifying t in place
+    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
+
+    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
+
+    return out.type(dtype)
+
+
+# learned rotation helpers
+
+
+def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+    if exists(freq_ranges):
+        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+        rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+    return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+# classes
+
+
+class RotaryEmbedding(Module):
+    def __init__(
+        self,
+        dim,
+        custom_freqs: Tensor | None = None,
+        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+        learned_freq=False,
+        use_xpos=False,
+        xpos_scale_base=512,
+        interpolate_factor=1.0,
+        theta_rescale_factor=1.0,
+        seq_before_head_dim=False,
+        cache_if_possible=True,
+        cache_max_seq_len=8192,
+    ):
+        super().__init__()
+        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
+        # has some connection to NTK literature
+        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+
+        theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+        self.freqs_for = freqs_for
+
+        if exists(custom_freqs):
+            freqs = custom_freqs
+        elif freqs_for == "lang":
+            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        elif freqs_for == "pixel":
+            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+        elif freqs_for == "constant":
+            freqs = torch.ones(num_freqs).float()
+
+        self.cache_if_possible = cache_if_possible
+        self.cache_max_seq_len = cache_max_seq_len
+
+        self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
+        self.cached_freqs_seq_len = 0
+
+        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
+
+        self.learned_freq = learned_freq
+
+        # dummy for device
+
+        self.register_buffer("dummy", torch.tensor(0), persistent=False)
+
+        # default sequence dimension
+
+        self.seq_before_head_dim = seq_before_head_dim
+        self.default_seq_dim = -3 if seq_before_head_dim else -2
+
+        # interpolation factors
+
+        assert interpolate_factor >= 1.0
+        self.interpolate_factor = interpolate_factor
+
+        # xpos
+
+        self.use_xpos = use_xpos
+
+        if not use_xpos:
+            return
+
+        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+        self.scale_base = xpos_scale_base
+
+        self.register_buffer("scale", scale, persistent=False)
+        self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
+        self.cached_scales_seq_len = 0
+
+        # add apply_rotary_emb as static method
+
+        self.apply_rotary_emb = staticmethod(apply_rotary_emb)
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
+        device = default(device, self.device)
+        dtype = default(dtype, self.cached_freqs.dtype)
+
+        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
+
+    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
+        seq_dim = default(seq_dim, self.default_seq_dim)
+
+        assert not self.use_xpos or exists(scale), (
+            "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
+        )
+
+        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
+
+        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
+
+        freqs = self.forward(seq, seq_len=seq_len, offset=offset)
+
+        if seq_dim == -3:
+            freqs = rearrange(freqs, "n d -> n 1 d")
+
+        return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
+
+    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
+        dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
+
+        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
+        assert q_len <= k_len
+
+        q_scale = k_scale = 1.0
+
+        if self.use_xpos:
+            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
+
+            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
+            k_scale = self.get_scale(seq).type(dtype)
+
+        rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
+        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
+
+        rotated_q = rotated_q.type(q.dtype)
+        rotated_k = rotated_k.type(k.dtype)
+
+        return rotated_q, rotated_k
+
+    def rotate_queries_and_keys(self, q, k, seq_dim=None):
+        seq_dim = default(seq_dim, self.default_seq_dim)
+
+        assert self.use_xpos
+        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+
+        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
+
+        freqs = self.forward(seq, seq_len=seq_len)
+        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
+
+        if seq_dim == -3:
+            freqs = rearrange(freqs, "n d -> n 1 d")
+            scale = rearrange(scale, "n d -> n 1 d")
+
+        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
+        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
+
+        rotated_q = rotated_q.type(q.dtype)
+        rotated_k = rotated_k.type(k.dtype)
+
+        return rotated_q, rotated_k
+
+    def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
+        assert self.use_xpos
+
+        should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len
+
+        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len:
+            return self.cached_scales[offset : (offset + seq_len)]
+
+        scale = 1.0
+        if self.use_xpos:
+            power = (t - len(t) // 2) / self.scale_base
+            scale = self.scale ** rearrange(power, "n -> n 1")
+            scale = repeat(scale, "n d -> n (d r)", r=2)
+
+        if should_cache and offset == 0:
+            self.cached_scales[:seq_len] = scale.detach()
+            self.cached_scales_seq_len = seq_len
+
+        return scale
+
+    def get_axial_freqs(self, *dims, offsets: (tuple[int | float, ...] | Tensor | None) = None):
+        Colon = slice(None)
+        all_freqs = []
+
+        # handle offset
+
+        if exists(offsets):
+            if not is_tensor(offsets):
+                offsets = tensor(offsets)
+
+            assert len(offsets) == len(dims)
+
+        # get frequencies for each axis
+
+        for ind, dim in enumerate(dims):
+            offset = 0
+            if exists(offsets):
+                offset = offsets[ind]
+
+            if self.freqs_for == "pixel":
+                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
+            else:
+                pos = torch.arange(dim, device=self.device)
+
+            pos = pos + offset
+
+            freqs = self.forward(pos, seq_len=dim)
+
+            all_axis = [None] * len(dims)
+            all_axis[ind] = Colon
+
+            new_axis_slice = (Ellipsis, *all_axis, Colon)
+            all_freqs.append(freqs[new_axis_slice])
+
+        # concat all freqs
+
+        all_freqs = broadcast_tensors(*all_freqs)
+        return torch.cat(all_freqs, dim=-1)
+
+    @autocast("cuda", enabled=False)
+    def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
+        should_cache = (
+            self.cache_if_possible
+            and not self.learned_freq
+            and exists(seq_len)
+            and self.freqs_for != "pixel"
+            and (offset + seq_len) <= self.cache_max_seq_len
+        )
+
+        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len:
+            return self.cached_freqs[offset : (offset + seq_len)].detach()
+
+        freqs = self.freqs
+
+        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+        if should_cache and offset == 0:
+            self.cached_freqs[:seq_len] = freqs.detach()
+            self.cached_freqs_seq_len = seq_len
+
+        return freqs
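The file above is a vendored copy of lucidrains' rotary-embedding-torch (MIT licensed, as its header states). A hedged sketch of standalone usage of the vendored class, with illustrative shapes:

# Hedged sketch: plain rotary embedding applied to query/key tensors.
import torch

from autogluon.timeseries.models.toto._internal.backbone.rotary_embedding_torch import RotaryEmbedding

rope = RotaryEmbedding(dim=32)                   # rotates the first 32 of the 64 feature dims
q = torch.randn(1, 8, 256, 64)                   # (batch, heads, seq, head_dim)
k = torch.randn(1, 8, 256, 64)
q = rope.rotate_queries_or_keys(q)               # positions 0..255, frequencies cached for reuse
k = rope.rotate_queries_or_keys(k)
print(q.shape, k.shape)                          # shapes are unchanged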