autogluon.timeseries 1.4.1b20250907__py3-none-any.whl → 1.5.1b20260122__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- autogluon/timeseries/configs/hyperparameter_presets.py +13 -28
- autogluon/timeseries/configs/predictor_presets.py +23 -39
- autogluon/timeseries/dataset/ts_dataframe.py +97 -86
- autogluon/timeseries/learner.py +70 -35
- autogluon/timeseries/metrics/__init__.py +4 -4
- autogluon/timeseries/metrics/abstract.py +8 -8
- autogluon/timeseries/metrics/point.py +9 -9
- autogluon/timeseries/metrics/quantile.py +5 -5
- autogluon/timeseries/metrics/utils.py +4 -4
- autogluon/timeseries/models/__init__.py +4 -1
- autogluon/timeseries/models/abstract/abstract_timeseries_model.py +52 -50
- autogluon/timeseries/models/abstract/model_trial.py +2 -1
- autogluon/timeseries/models/abstract/tunable.py +8 -8
- autogluon/timeseries/models/autogluon_tabular/mlforecast.py +58 -62
- autogluon/timeseries/models/autogluon_tabular/per_step.py +27 -16
- autogluon/timeseries/models/autogluon_tabular/transforms.py +11 -9
- autogluon/timeseries/models/chronos/__init__.py +2 -1
- autogluon/timeseries/models/chronos/chronos2.py +395 -0
- autogluon/timeseries/models/chronos/model.py +127 -89
- autogluon/timeseries/models/chronos/{pipeline/utils.py → utils.py} +69 -37
- autogluon/timeseries/models/ensemble/__init__.py +36 -2
- autogluon/timeseries/models/ensemble/abstract.py +14 -46
- autogluon/timeseries/models/ensemble/array_based/__init__.py +3 -0
- autogluon/timeseries/models/ensemble/array_based/abstract.py +240 -0
- autogluon/timeseries/models/ensemble/array_based/models.py +185 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/__init__.py +12 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/abstract.py +88 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/linear_stacker.py +186 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/per_quantile_tabular.py +94 -0
- autogluon/timeseries/models/ensemble/array_based/regressor/tabular.py +107 -0
- autogluon/timeseries/models/ensemble/{greedy.py → ensemble_selection.py} +41 -61
- autogluon/timeseries/models/ensemble/per_item_greedy.py +172 -0
- autogluon/timeseries/models/ensemble/weighted/__init__.py +8 -0
- autogluon/timeseries/models/ensemble/weighted/abstract.py +45 -0
- autogluon/timeseries/models/ensemble/{basic.py → weighted/basic.py} +25 -22
- autogluon/timeseries/models/ensemble/weighted/greedy.py +64 -0
- autogluon/timeseries/models/gluonts/abstract.py +32 -31
- autogluon/timeseries/models/gluonts/dataset.py +11 -11
- autogluon/timeseries/models/gluonts/models.py +0 -7
- autogluon/timeseries/models/local/__init__.py +0 -7
- autogluon/timeseries/models/local/abstract_local_model.py +15 -18
- autogluon/timeseries/models/local/naive.py +2 -2
- autogluon/timeseries/models/local/npts.py +7 -1
- autogluon/timeseries/models/local/statsforecast.py +13 -13
- autogluon/timeseries/models/multi_window/multi_window_model.py +39 -24
- autogluon/timeseries/models/registry.py +3 -4
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +196 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +89 -0
- autogluon/timeseries/models/toto/_internal/backbone/rotary_embedding_torch.py +342 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +305 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +200 -0
- autogluon/timeseries/models/toto/model.py +249 -0
- autogluon/timeseries/predictor.py +541 -162
- autogluon/timeseries/regressor.py +27 -30
- autogluon/timeseries/splitter.py +3 -27
- autogluon/timeseries/trainer/ensemble_composer.py +444 -0
- autogluon/timeseries/trainer/model_set_builder.py +9 -9
- autogluon/timeseries/trainer/prediction_cache.py +16 -16
- autogluon/timeseries/trainer/trainer.py +300 -279
- autogluon/timeseries/trainer/utils.py +17 -0
- autogluon/timeseries/transforms/covariate_scaler.py +8 -8
- autogluon/timeseries/transforms/target_scaler.py +15 -15
- autogluon/timeseries/utils/constants.py +10 -0
- autogluon/timeseries/utils/datetime/lags.py +1 -3
- autogluon/timeseries/utils/datetime/seasonality.py +1 -3
- autogluon/timeseries/utils/features.py +31 -14
- autogluon/timeseries/utils/forecast.py +6 -7
- autogluon/timeseries/utils/timer.py +173 -0
- autogluon/timeseries/version.py +1 -1
- autogluon.timeseries-1.5.1b20260122-py3.11-nspkg.pth +1 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/METADATA +39 -22
- autogluon_timeseries-1.5.1b20260122.dist-info/RECORD +103 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/WHEEL +1 -1
- autogluon/timeseries/evaluator.py +0 -6
- autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -10
- autogluon/timeseries/models/chronos/pipeline/base.py +0 -160
- autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -544
- autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -580
- autogluon.timeseries-1.4.1b20250907-py3.9-nspkg.pth +0 -1
- autogluon.timeseries-1.4.1b20250907.dist-info/RECORD +0 -75
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/LICENSE +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info/licenses}/NOTICE +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.4.1b20250907.dist-info → autogluon_timeseries-1.5.1b20260122.dist-info}/zip-safe +0 -0
autogluon/timeseries/models/toto/_internal/backbone/backbone.py
@@ -0,0 +1,262 @@
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
#
# This product includes software developed at Datadog (https://www.datadoghq.com/)
# Copyright 2025 Datadog, Inc.

import math
from typing import NamedTuple

import torch

from .distribution import MixtureOfStudentTsOutput
from .kvcache import KVCache
from .scaler import CausalPatchStdMeanScaler
from .transformer import Transformer


class TotoOutput(NamedTuple):
    """
    Output of the Toto model. Contains the output distribution, the location parameters,
    and the scale parameters.
    """

    distribution: torch.distributions.Distribution
    loc: torch.Tensor
    scale: torch.Tensor


def patchify_id_mask(id_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size)
    patched_id_mask_min = patched_id_mask.min(-1).values
    patched_id_mask_max = patched_id_mask.max(-1).values
    assert torch.eq(patched_id_mask_min, patched_id_mask_max).all(), "Patches cannot span multiple datasets"
    return patched_id_mask_min


class PatchEmbedding(torch.nn.Module):
    """
    Multivariate time series patch embedding.
    Patchifies each variate separately.
    """

    def __init__(self, patch_size: int, stride: int, embed_dim: int):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.stride = stride
        self.projection = torch.nn.Linear(self.patch_size, self.embed_dim)

    def _patchify(self, x: torch.Tensor) -> torch.Tensor:
        return x.unfold(dimension=-1, size=self.patch_size, step=self.stride)

    def forward(
        self,
        x: torch.Tensor,
        id_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        assert x.shape[-1] % self.patch_size == 0, (
            f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})"
        )
        x_patched: torch.Tensor = self._patchify(x)
        id_mask_patched: torch.Tensor = self._patchify(id_mask)

        assert torch.eq(id_mask_patched.min(-1).values, id_mask_patched.max(-1).values).all(), (
            "Patches cannot span multiple datasets"
        )

        return (
            self.projection(x_patched),
            id_mask_patched.min(-1).values,
        )


class TotoBackbone(torch.nn.Module):
    """
    Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model for multivariate
    time series forecasting. It applies a patch embedding to the input data, followed by a transformer
    that alternates between time-wise and space-wise attention. The transformer is followed by a linear projection
    that maps the transformer output to the output distribution.

    The output distribution can be a single distribution (e.g. Gaussian) or a mixture of distributions.
    If a mixture of distributions is used, the model will learn to predict the mixture weights
    as well as the parameters of the individual distributions.

    Parameters
    ----------
    patch_size
        Size of the patch to use for the patch embedding.
    stride
        Stride to use for the patch embedding.
    embed_dim
        Dimension of the model's latent space.
    num_layers
        Number of transformer layers to use.
    num_heads
        Number of attention heads to use in each self-attention layer.
    mlp_hidden_dim
        Dimension of the hidden layer in the feedforward network.
    dropout
        Dropout rate to use in the model.
    spacewise_every_n_layers
        How many time-wise transformer layers to apply between each space-wise transformer layer.
    spacewise_first
        Whether to apply space-wise attention before time-wise attention.
    scaler_cls
        Class to use for scaling the input data.
    output_distribution_classes
        List of classes to use for the output distribution. If a single class is provided, the model
        will output a single distribution. If multiple classes are provided, the model will output a
        learned mixture of distributions.
    output_distribution_kwargs
        Keyword arguments to pass to the output distribution class. Note: this currently only works
        with a single output distribution class.
    use_memory_efficient_attention
        Whether to use memory-efficient attention. If True, the model will use the memory-efficient
        attention implementation from xFormers.
    stabilize_with_global
        Whether to use global statistics to stabilize causal statistics by clamping extreme values.
        Only applies to causal scalers.
    scale_factor_exponent
        Exponent that controls the allowed range of deviation from the global scale for causal scalers.
    """

    def __init__(
        self,
        patch_size: int,
        stride: int,
        embed_dim: int,
        num_layers: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        spacewise_every_n_layers: int,
        scaler_cls: str,
        output_distribution_classes: list[str],
        spacewise_first: bool = True,
        output_distribution_kwargs: dict | None = None,
        use_memory_efficient_attention: bool = True,
        stabilize_with_global: bool = True,
        scale_factor_exponent: float = 10.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # strings are used when loading a safetensors checkpoint
        # Initialize patch-based scalers with the correct patch_size

        self.scaler = CausalPatchStdMeanScaler(
            patch_size=patch_size,
            stabilize_with_global=stabilize_with_global,
            scale_factor_exponent=scale_factor_exponent,
        )
        self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim)
        self.dropout = dropout
        self.num_layers = num_layers
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.transformer = Transformer(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=self.num_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            dropout=dropout,
            spacewise_every_n_layers=spacewise_every_n_layers,
            spacewise_first=spacewise_first,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
        )
        self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)

        # TODO[BEN] this doesn't need to be a list
        output_distribution_classes_ = [MixtureOfStudentTsOutput]
        self.output_distribution = output_distribution_classes_[0](embed_dim, **(output_distribution_kwargs or {}))

    def allocate_kv_cache(
        self,
        batch_size: int,
        num_variates: int,
        max_time_steps: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> KVCache:
        return KVCache(
            batch_size=batch_size,
            num_variates=num_variates,
            transformer_layers=list(self.transformer.layers),
            num_layers=self.num_layers,
            embed_dim=self.embed_dim,
            num_heads=self.transformer.layers[0].num_heads,  # type: ignore
            max_seq_len=math.ceil(max_time_steps / self.patch_embed.stride),
            device=device,
            dtype=dtype,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
        )

    def backbone(
        self,
        inputs: torch.Tensor,
        input_padding_mask: torch.Tensor,
        id_mask: torch.Tensor,
        kv_cache: KVCache | None = None,
        scaling_prefix_length: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scaled_inputs: torch.Tensor
        loc: torch.Tensor
        scale: torch.Tensor

        # Standard scaling operation, same API but without ID mask.
        scaled_inputs, loc, scale = self.scaler(
            inputs,
            weights=torch.ones_like(inputs, device=inputs.device),
            padding_mask=input_padding_mask,
            prefix_length=scaling_prefix_length,
        )

        if kv_cache is not None:
            prefix_len = self.patch_embed.stride * kv_cache.current_len(0)

            # Truncate inputs so that the transformer only processes
            # the last patch in the sequence. We'll use the KVCache
            # for the earlier patches.
            scaled_inputs = scaled_inputs[:, :, prefix_len:]

            # As a simplification, when using kv cache we only allow decoding
            # one step at a time after the initial forward pass.
            assert (prefix_len == 0) or (scaled_inputs.shape[-1] == self.patch_embed.stride), (
                "Must decode one step at a time."
            )

            input_padding_mask = input_padding_mask[:, :, prefix_len:]
            id_mask = id_mask[:, :, prefix_len:]

        embeddings: torch.Tensor
        reduced_id_mask: torch.Tensor

        embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask)

        # Apply the transformer on the embeddings
        transformed: torch.Tensor = self.transformer(embeddings, reduced_id_mask, kv_cache)

        # Unembed and flatten the sequence
        unembedded = self.unembed(transformed)
        batch_size, num_variates, seq_len = unembedded.shape[:3]
        patch_size = unembedded.shape[-1] // self.embed_dim
        flattened = unembedded.view(batch_size, num_variates, seq_len * patch_size, self.embed_dim)
        return flattened, loc, scale

    def forward(
        self,
        inputs: torch.Tensor,
        input_padding_mask: torch.Tensor,
        id_mask: torch.Tensor,
        kv_cache: KVCache | None = None,
        scaling_prefix_length: int | None = None,
    ) -> TotoOutput:
        flattened, loc, scale = self.backbone(
            inputs,
            input_padding_mask,
            id_mask,
            kv_cache,
            scaling_prefix_length,
        )

        return TotoOutput(self.output_distribution(flattened), loc, scale)

    @property
    def device(self):
        return next(self.parameters()).device
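For orientation, here is a minimal smoke-test sketch of the backbone API. The hyperparameter values and mask choices are illustrative assumptions, not package defaults, and the sketch presumes the sibling modules (Transformer, CausalPatchStdMeanScaler) accept exactly the arguments the call sites above pass them.

import torch

from autogluon.timeseries.models.toto._internal.backbone.backbone import TotoBackbone

# Illustrative sizes: batch of 2, 3 variates, 64 steps in 16-step patches.
batch_size, num_variates, num_steps, patch_size = 2, 3, 64, 16

model = TotoBackbone(
    patch_size=patch_size,
    stride=patch_size,  # non-overlapping patches
    embed_dim=64,
    num_layers=4,
    num_heads=4,
    mlp_hidden_dim=128,
    dropout=0.0,
    spacewise_every_n_layers=2,
    scaler_cls="CausalPatchStdMeanScaler",  # accepted but currently ignored: the scaler is hardcoded above
    output_distribution_classes=["MixtureOfStudentTsOutput"],  # likewise hardcoded above
    output_distribution_kwargs={"k_components": 3},
    use_memory_efficient_attention=False,  # avoids the xFormers dependency for this sketch
)

inputs = torch.randn(batch_size, num_variates, num_steps)
padding_mask = torch.ones_like(inputs, dtype=torch.bool)  # every step observed
id_mask = torch.zeros_like(inputs)  # all variates belong to one dataset

output = model(inputs, padding_mask, id_mask)
# forward() does not fold loc/scale into the distribution, so samples live in
# the scaler's normalized space; loc and scale are returned for mapping back.
sample = output.distribution.sample()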
autogluon/timeseries/models/toto/_internal/backbone/distribution.py
@@ -0,0 +1,70 @@
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
#
# This product includes software developed at Datadog (https://www.datadoghq.com/)
# Copyright 2025 Datadog, Inc.

from abc import ABC

import torch
import torch.nn.functional as F
from gluonts.torch.distributions import AffineTransformed
from gluonts.torch.distributions.studentT import StudentT


class DistributionOutput(ABC, torch.nn.Module):
    pass


class StudentTOutput(DistributionOutput):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.df = torch.nn.Linear(embed_dim, 1)
        self.loc_proj = torch.nn.Linear(embed_dim, 1)
        self.scale_proj = torch.nn.Linear(embed_dim, 1)

    def forward(self, inputs, loc=None, scale=None):
        eps = torch.finfo(inputs.dtype).eps
        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps).squeeze(-1)
        base_loc = self.loc_proj(inputs).squeeze(-1)
        base_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps).squeeze(-1)

        base_dist = torch.distributions.StudentT(df, base_loc, base_scale, validate_args=False)  # type: ignore

        if loc is not None and scale is not None:
            return AffineTransformed(
                base_dist,
                loc=loc,
                scale=scale,
            )
        return base_dist


class MixtureOfStudentTsOutput(DistributionOutput):
    def __init__(
        self,
        embed_dim,
        k_components,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_components = k_components

        self.df = torch.nn.Linear(embed_dim, k_components)
        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)

    def forward(self, inputs, loc=None, scale=None):
        df = 2.0 + F.softplus(self.df(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
        loc = self.loc_proj(inputs)
        scale = F.softplus(self.scale_proj(inputs)).clamp_min(torch.finfo(inputs.dtype).eps)
        logits = self.mixture_weights(inputs)
        probs = F.softmax(logits, dim=-1)
        components = StudentT(df, loc, scale)
        mixture_distribution = torch.distributions.Categorical(probs=probs)

        return torch.distributions.MixtureSameFamily(
            mixture_distribution,
            components,
        )
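For readers unfamiliar with torch.distributions.MixtureSameFamily, which does the heavy lifting here: the trailing batch dimension of the component distribution indexes the k mixture components, and log_prob marginalizes over it using the Categorical weights. A standalone illustration with made-up sizes, independent of the classes above:

import torch

# The last batch dimension (k) of the components indexes mixture members.
batch, k = 4, 3
df = 2.0 + torch.rand(batch, k)      # degrees of freedom, kept > 2 as above
loc = torch.randn(batch, k)          # per-component location
scale = torch.rand(batch, k) + 0.1   # per-component scale, strictly positive
logits = torch.randn(batch, k)       # unnormalized mixture weights

components = torch.distributions.StudentT(df, loc, scale)
weights = torch.distributions.Categorical(logits=logits)
mixture = torch.distributions.MixtureSameFamily(weights, components)

x = torch.randn(batch)
print(mixture.log_prob(x).shape)  # torch.Size([4]); the k dimension is marginalized out
print(mixture.sample().shape)     # torch.Size([4])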
autogluon/timeseries/models/toto/_internal/backbone/kvcache.py
@@ -0,0 +1,136 @@
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
#
# This product includes software developed at Datadog (https://www.datadoghq.com/)
# Copyright 2025 Datadog, Inc.

from dataclasses import dataclass, field

import torch

from .attention import TimeWiseMultiheadAttention

K = torch.Tensor
V = torch.Tensor
KV = tuple[torch.Tensor, torch.Tensor]


@dataclass
class KVCache:
    """
    Key/Value cache for storing intermediate attention values
    during multistep inference. Only stores KV cache for timewise layers, skipping spacewise layers.
    """

    batch_size: int
    num_variates: int
    transformer_layers: list
    num_layers: int
    embed_dim: int
    num_heads: int
    max_seq_len: int
    device: torch.device = torch.device("cpu")
    dtype: torch.dtype = torch.float32
    use_memory_efficient_attention: bool = True

    _keys: torch.Tensor = field(init=False)
    _values: torch.Tensor = field(init=False)
    _current_idx: torch.Tensor = field(init=False)
    _layer_cache_map: torch.Tensor = field(init=False)

    def __post_init__(self):
        """
        - Determine timewise vs. spacewise layers and allocate KV only for timewise.
        - Create a fast tensor-based mapping from global layer_idx -> timewise layer_idx.
        """
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        head_dim = self.embed_dim // self.num_heads

        # Compute which layers are timewise
        time_layer_indices = [
            i
            for i in range(self.num_layers)
            if isinstance(self.transformer_layers[i].attention, TimeWiseMultiheadAttention)
        ]

        time_layer_count = max(1, len(time_layer_indices))  # handle edge case for no timewise layers
        # Allocate for only the timewise layers
        if self.use_memory_efficient_attention:
            shape = (
                time_layer_count,
                self.batch_size * self.num_variates,
                self.max_seq_len,
                self.num_heads,
                head_dim,
            )
        else:
            shape = (
                time_layer_count,
                self.batch_size * self.num_variates,
                self.num_heads,
                self.max_seq_len,
                head_dim,
            )
        self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype)
        self._values = torch.zeros_like(self._keys)
        self._current_idx = torch.zeros(time_layer_count, device=self.device, dtype=torch.int)
        # Build a tensor lookup for global -> timewise layer index (default to 0)
        self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device)
        for cache_idx, layer_idx in enumerate(time_layer_indices):
            self._layer_cache_map[layer_idx] = int(cache_idx)  # Assign correct indices

    def __getitem__(self, layer_idx: int) -> KV:
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        end_idx = int(self._current_idx[cache_idx].item())

        if self.use_memory_efficient_attention:
            return self._keys[cache_idx, :, :end_idx, :, :], self._values[cache_idx, :, :end_idx, :, :]
        else:
            return self._keys[cache_idx, :, :, :end_idx, :], self._values[cache_idx, :, :, :end_idx, :]

    def current_len(self, cache_idx: int) -> int:
        return int(self._current_idx[cache_idx].item()) if self._current_idx.numel() > 0 else 0

    def seq_len(self, layer_idx: int) -> int:
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        return self.current_len(cache_idx)

    def append(self, layer_idx: int, kv: KV):
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        keys, values = kv

        # Validate dimensions
        assert keys.shape == values.shape, "keys and values must have the same shape"
        assert keys.shape[0] == self.batch_size * self.num_variates, (
            "keys and values must have batch_size * num_variates as their first dimension"
        )

        if self.use_memory_efficient_attention:
            assert keys.shape[2] == self.num_heads, "keys and values must have num_heads as their third dimension"
        else:
            assert keys.shape[1] == self.num_heads, "keys and values must have num_heads as their second dimension"
        assert keys.shape[3] == self.embed_dim // self.num_heads, (
            "keys and values must have head_dim as their fourth dimension"
        )

        start_idx = self._current_idx[cache_idx]
        if self.use_memory_efficient_attention:
            end_idx = start_idx + keys.shape[1]
        else:
            end_idx = start_idx + keys.shape[2]
        assert end_idx <= self.max_seq_len, (
            f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"
        )

        if self.use_memory_efficient_attention:
            self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys
            self._values[cache_idx, :, start_idx:end_idx, :, :] = values
        else:
            self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys
            self._values[cache_idx, :, :, start_idx:end_idx, :] = values

        self._current_idx[cache_idx] = end_idx

    def reset(self):
        self._keys.zero_()
        self._values.zero_()
        self._current_idx.zero_()
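The two storage layouts differ only in where the sequence axis sits: with memory-efficient (xFormers) attention the cache is indexed as (batch, seq, heads, head_dim), otherwise as (batch, heads, seq, head_dim). A self-contained sketch of the cursor-based append/read cycle for one timewise layer, with hypothetical sizes; the TinyKVCache class is illustrative and not part of the package:

import torch

class TinyKVCache:
    """Minimal stand-in for KVCache's bookkeeping, xFormers layout only."""

    def __init__(self, batch: int, max_seq_len: int, heads: int, head_dim: int):
        self.keys = torch.zeros(batch, max_seq_len, heads, head_dim)
        self.values = torch.zeros_like(self.keys)
        self.current = 0  # write cursor on the sequence axis

    def append(self, k: torch.Tensor, v: torch.Tensor) -> None:
        # Write new timesteps at the cursor, as KVCache.append does on axis 1.
        end = self.current + k.shape[1]
        assert end <= self.keys.shape[1], "max_seq_len exceeded"
        self.keys[:, self.current:end] = k
        self.values[:, self.current:end] = v
        self.current = end

    def read(self):
        # Like KVCache.__getitem__: return only the filled prefix.
        return self.keys[:, : self.current], self.values[:, : self.current]

cache = TinyKVCache(batch=6, max_seq_len=32, heads=4, head_dim=16)
cache.append(torch.randn(6, 5, 4, 16), torch.randn(6, 5, 4, 16))  # prefill 5 patches
cache.append(torch.randn(6, 1, 4, 16), torch.randn(6, 1, 4, 16))  # decode one step
k, v = cache.read()
print(k.shape)  # torch.Size([6, 6, 4, 16]): the valid prefix only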
autogluon/timeseries/models/toto/_internal/backbone/rope.py
@@ -0,0 +1,89 @@
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
#
# This product includes software developed at Datadog (https://www.datadoghq.com/)
# Copyright 2025 Datadog, Inc.


import torch
from einops import rearrange

from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb, default


class TimeAwareRotaryEmbedding(RotaryEmbedding):
    """
    A variant of the rotary position embedding that (optionally) uses the time index
    to compute the sine and cosine embeddings. This is useful for
    time series data, where the time index is the most important positional
    information.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the parent stored `freqs` as a Parameter, remove it and register it as a buffer.
        # Registering it as a buffer is needed for sharding with FSDP.
        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
            # Extract the underlying Tensor
            freqs_data = self.freqs.data

            # Remove `freqs` from the module's parameters
            self._parameters.pop("freqs")

            # Register as non-persistent buffer
            self.register_buffer("freqs", freqs_data, persistent=False)

    def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: int | None = None,
        seq_pos: torch.Tensor | None = None,
        seq_pos_offset: int = 0,
    ):
        """
        This method is the same as the one on the base class, except it allows you to override
        the sequence position tensor with a custom one. It also removes the ability
        to cache the position encodings, since we have to compute them dynamically
        based on the timesteps in the input data.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device))
        seq = seq + seq_pos_offset  # type: ignore

        freqs = self.forward(seq)

        scale = self.get_scale(seq).to(dtype)

        # used for xformers
        if seq_dim == -3:
            num_heads = q.shape[-2]
            freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
            scale = scale.unsqueeze(1).expand(-1, num_heads, -1)

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)  # type: ignore
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)  # type: ignore

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: torch.Tensor, seq_len: int | None = None, offset=0):
        """
        This method is adapted closely from the base class, but it knows how to handle
        the case where `t` has more than 1 dim (as when we're using time-aware RoPE
        and have a different sequence position vector for each time series).
        """
        assert self.use_xpos

        power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base

        scale = self.scale ** rearrange(power, "... n -> ... n 1")  # type: ignore
        scale = torch.cat((scale, scale), dim=-1)

        return scale
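The essential idea of the time-aware variant is that the rotation angles come from an explicit position tensor (for example, real per-series timestamps) rather than the default 0..seq_len-1 ramp. A self-contained sketch of plain rotary rotation with explicit positions (interleaved-pair convention, no xpos scaling; the rotary_rotate helper is illustrative and not part of the package or of rotary-embedding-torch):

import torch

def rotary_rotate(x: torch.Tensor, positions: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Rotate pairs of feature dims of x by angles derived from explicit positions.

    x: (..., seq, dim) with even dim; positions: (..., seq) real-valued time indices.
    """
    dim = x.shape[-1]
    freqs = theta ** (-torch.arange(0, dim, 2, dtype=x.dtype) / dim)  # (dim/2,)
    angles = positions.unsqueeze(-1) * freqs                           # (..., seq, dim/2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(2, 8, 64)                    # (batch, seq, dim)
t = torch.cumsum(torch.rand(2, 8), dim=-1)   # irregular, per-series timestamps
q_rot = rotary_rotate(q, t)
# Dot products between queries and keys rotated this way depend only on time
# differences, which is what makes the position encoding relative.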