autogluon.timeseries 1.4.1b20250923__py3-none-any.whl → 1.4.1b20250929__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/timeseries/models/__init__.py +2 -0
- autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
- autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
- autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
- autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
- autogluon/timeseries/models/toto/model.py +234 -0
- autogluon/timeseries/version.py +1 -1
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/METADATA +10 -5
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/RECORD +26 -11
- /autogluon.timeseries-1.4.1b20250923-py3.9-nspkg.pth → /autogluon.timeseries-1.4.1b20250929-py3.9-nspkg.pth +0 -0
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/LICENSE +0 -0
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/NOTICE +0 -0
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/WHEEL +0 -0
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/top_level.txt +0 -0
- {autogluon.timeseries-1.4.1b20250923.dist-info → autogluon.timeseries-1.4.1b20250929.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from .attention import TimeWiseMultiheadAttention
|
|
11
|
+
|
|
12
|
+
# Short aliases used in type hints: a key tensor, a value tensor, and a (key, value) pair.
K = torch.Tensor
V = torch.Tensor
KV = tuple[K, V]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class KVCache:
    """
    Preallocated key/value buffers used during multistep inference.

    Storage is allocated only for transformer layers whose ``attention`` module is a
    ``TimeWiseMultiheadAttention``; other (space-wise) layers get no slot of their own.
    A global layer index is translated to a cache slot through ``_layer_cache_map``.
    Layers without a dedicated slot map to slot 0.

    NOTE(review): this means indexing a space-wise layer silently reads/writes slot 0;
    callers presumably only use the cache for time-wise layers — confirm.
    """

    batch_size: int
    num_variates: int
    transformer_layers: list
    num_layers: int
    embed_dim: int
    num_heads: int
    max_seq_len: int
    device: torch.device = torch.device("cpu")
    dtype: torch.dtype = torch.float32
    use_memory_efficient_attention: bool = True

    # Internal buffers, created in __post_init__.
    _keys: torch.Tensor = field(init=False)
    _values: torch.Tensor = field(init=False)
    _current_idx: torch.Tensor = field(init=False)  # number of filled positions per slot
    _layer_cache_map: torch.Tensor = field(init=False)  # global layer idx -> slot idx

    def __post_init__(self):
        """Allocate zeroed key/value buffers for time-wise layers and build the layer -> slot lookup."""
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        head_dim = self.embed_dim // self.num_heads

        timewise_layers = [
            layer_idx
            for layer_idx in range(self.num_layers)
            if isinstance(self.transformer_layers[layer_idx].attention, TimeWiseMultiheadAttention)
        ]
        # Always keep at least one slot so a model with no time-wise layers still has valid buffers.
        n_slots = max(1, len(timewise_layers))

        # The memory-efficient layout puts the sequence axis before the head axis;
        # the standard layout puts heads first.
        if self.use_memory_efficient_attention:
            middle_axes = (self.max_seq_len, self.num_heads)
        else:
            middle_axes = (self.num_heads, self.max_seq_len)
        buffer_shape = (n_slots, self.batch_size * self.num_variates, *middle_axes, head_dim)

        self._keys = torch.zeros(buffer_shape, device=self.device, dtype=self.dtype)
        self._values = torch.zeros_like(self._keys)
        self._current_idx = torch.zeros(n_slots, device=self.device, dtype=torch.int)

        self._layer_cache_map = torch.zeros((self.num_layers,), dtype=torch.int, device=self.device)
        for slot, layer_idx in enumerate(timewise_layers):
            self._layer_cache_map[layer_idx] = slot

    def _slot_index(self, slot, start, stop):
        """Index tuple selecting positions [start, stop) on the sequence axis of one cache slot."""
        window = slice(start, stop)
        if self.use_memory_efficient_attention:
            return (slot, slice(None), window)
        return (slot, slice(None), slice(None), window)

    def __getitem__(self, layer_idx: int) -> KV:
        """Return views of the filled part of the key and value buffers for ``layer_idx``."""
        slot = int(self._layer_cache_map[layer_idx].item())
        filled = int(self._current_idx[slot].item())
        window = self._slot_index(slot, 0, filled)
        return self._keys[window], self._values[window]

    def current_len(self, cache_idx: int) -> int:
        """Number of positions currently stored in cache slot ``cache_idx``."""
        if self._current_idx.numel() == 0:
            return 0
        return int(self._current_idx[cache_idx].item())

    def seq_len(self, layer_idx: int) -> int:
        """Number of positions currently stored for global layer ``layer_idx``."""
        return self.current_len(int(self._layer_cache_map[layer_idx].item()))

    def append(self, layer_idx: int, kv: KV):
        """
        Write a block of new keys/values after the filled region of ``layer_idx``'s slot.

        ``kv`` tensors have no slot axis: (B*V, S, H, D) in the memory-efficient layout,
        (B*V, H, S, D) otherwise, where S is the number of new positions.
        Raises AssertionError on shape mismatch or when ``max_seq_len`` would be exceeded.
        """
        slot = int(self._layer_cache_map[layer_idx].item())
        key_block, value_block = kv
        efficient = self.use_memory_efficient_attention

        assert key_block.shape == value_block.shape, "keys and values must have the same shape"
        assert key_block.shape[0] == self.batch_size * self.num_variates, (
            "keys and values must have batch_size * num_variates as their first dimension"
        )
        if efficient:
            assert key_block.shape[2] == self.num_heads, "keys and values must have num_heads as their third dimension"
        else:
            assert key_block.shape[1] == self.num_heads, "keys and values must have num_heads as their second dimension"
        assert key_block.shape[3] == self.embed_dim // self.num_heads, (
            "keys and values must have head_dim as their fourth dimension"
        )

        # Bounds stay as 0-dim tensors (start_idx is a view into _current_idx).
        start_idx = self._current_idx[slot]
        end_idx = start_idx + key_block.shape[1 if efficient else 2]
        assert end_idx <= self.max_seq_len, (
            f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {key_block.shape}"
        )

        window = self._slot_index(slot, start_idx, end_idx)
        self._keys[window] = key_block
        self._values[window] = value_block
        self._current_idx[slot] = end_idx

    def reset(self):
        """Zero every buffer and mark all slots as empty."""
        for buffer in (self._keys, self._values, self._current_idx):
            buffer.zero_()
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from einops import rearrange
|
|
10
|
+
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
|
11
|
+
from rotary_embedding_torch.rotary_embedding_torch import default
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def exists(val):
    """Return True when *val* is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TimeAwareRotaryEmbedding(RotaryEmbedding):
    """
    A variant of the rotary position embedding that (optionally) uses the time index
    to compute the sinusoidal and cosine embeddings. This is useful for
    time series data, where the time index is the most important positional
    information.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the parent stored `freqs` as a Parameter, remove it and register it as a
        # non-persistent buffer instead (a buffer is needed for sharding with FSDP).
        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
            freqs_data = self.freqs.data
            self._parameters.pop("freqs")
            self.register_buffer("freqs", freqs_data, persistent=False)

    def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: Optional[int] = None,
        seq_pos: Optional[torch.Tensor] = None,
        seq_pos_offset: int = 0,
    ):
        """
        Rotate queries and keys with xPos scaling.

        Same as the base-class method, except that ``seq_pos`` may override the
        position tensor. Position encodings are never cached, because they are
        computed from the timesteps in the input data.

        Parameters
        ----------
        q, k
            Query and key tensors.
        seq_dim
            Sequence axis of ``q``/``k``; defaults to ``self.default_seq_dim``.
        seq_pos
            Optional explicit positions; when None, ``self.get_seq_pos`` is used.
        seq_pos_offset
            Added to every position before computing the encodings.

        Returns
        -------
        The rotated ``(q, k)``, cast back to the dtypes of ``q`` and ``k``.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # Build the default positions only when none are supplied, to avoid
        # allocating a throwaway tensor on every call.
        if seq_pos is None:
            seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
        else:
            seq = seq_pos
        seq = seq + seq_pos_offset  # type: ignore

        freqs = self.forward(seq)
        scale = self.get_scale(seq).to(dtype)

        # xformers layout (sequence axis at -3, heads at -2):
        # broadcast freqs/scale across the head axis.
        if seq_dim == -3:
            num_heads = q.shape[-2]
            freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
            scale = scale.unsqueeze(1).expand(-1, num_heads, -1)

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)  # type: ignore
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)  # type: ignore

        return rotated_q.type(q.dtype), rotated_k.type(k.dtype)

    def get_scale(self, t: torch.Tensor, seq_len: Optional[int] = None, offset=0):
        """
        Compute xPos scale factors for positions ``t``.

        Adapted from the base class. It also handles a ``t`` with more than one
        dimension, e.g. when time-aware RoPE uses a different position vector per
        time series. Positions are centred on half of the largest position along
        the last axis.

        ``seq_len`` and ``offset`` are unused here; they are kept so the signature
        matches the overridden method.
        """
        assert self.use_xpos

        # Parenthesised for clarity: `//` already binds tighter than `-`.
        center = t.max(-1).values.unsqueeze(-1) // 2
        power = (t - center) / self.scale_base

        scale = self.scale ** rearrange(power, "... n -> ... n 1")  # type: ignore
        return torch.cat((scale, scale), dim=-1)
|
|
@@ -0,0 +1,306 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
import warnings
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from einops import repeat
|
|
11
|
+
from gluonts.core.component import validated
|
|
12
|
+
from gluonts.torch.scaler import Scaler
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compute_causal_statistics(
    data: torch.Tensor,
    weights: torch.Tensor,
    padding_mask: torch.Tensor,
    dim: int,
    minimum_scale: float,
    use_bessel_correction: bool = True,
    stabilize_with_global: bool = False,
    scale_factor_exponent: float = 10.0,
    prefix_length: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute causal mean and scale statistics along a specified dimension using
    a vectorized implementation of Welford's algorithm for numerical stability.

    This implementation avoids explicit loops while keeping the numerical stability
    of Welford's algorithm. It is faster than a loop and equally robust against
    overflow. Computation happens in float64 where the device supports it, and in
    float32 (with a RuntimeWarning) otherwise.

    Can optionally use global statistics to stabilize the causal statistics by
    clamping extreme values. This prevents instability while preserving a relaxed
    version of the causal property: a controlled amount of future information
    leaks in, which is an explicit tradeoff between causality and stability.

    Parameters
    ----------
    data
        The input data tensor
    weights
        The weight tensor (same shape as data)
    padding_mask
        The padding mask tensor (same shape as data); masked positions get zero weight
    dim
        The dimension along which to compute statistics (must be -1, the time dimension)
    minimum_scale
        Added to the variance before taking the square root. When stabilizing, it is
        also used as a lower bound on the scale.
    use_bessel_correction
        Whether to use Bessel's correction to get an unbiased estimator
    stabilize_with_global
        Whether to use global statistics to stabilize the causal statistics by clamping
        extreme values
    scale_factor_exponent
        Exponent that controls the allowed range of deviation from the global scale.
        For example, with exponent=1.0, the causal scale must be between 0.1x and 10x
        the global scale. With exponent=2.0, the range would be 0.01x to 100x.
    prefix_length
        If specified, the global statistics will be computed using only the requested
        prefix length. This is used for multistep decoding, where only the initial
        historical data should be used to compute the global statistics. If
        stabilize_with_global is False, this parameter is ignored.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Causal mean and scale tensors (cast back to ``data.dtype``), potentially
        stabilized with global statistics.
    """
    assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)"

    with torch.no_grad():
        # Padded positions contribute nothing to any statistic.
        weights = weights * padding_mask

        # Try to use higher precision for numerical stability.
        try:
            high_precision_data = data.to(torch.float64)
            high_precision_weights = weights.to(torch.float64)
        except TypeError:
            # Fallback for devices that don't support float64.
            warnings.warn(
                f"Float64 is not supported by device {data.device}. "
                "Using float32 instead for causal scaler calculations. "
                "This may lead to numerical issues if the data contains extreme values.",
                RuntimeWarning,
            )
            high_precision_data = data.to(torch.float32)
            high_precision_weights = weights.to(torch.float32)

        # Cumsum operations do not support deterministic mode on CUDA, so deterministic
        # algorithms are disabled for this section only. The previous mode is restored
        # afterwards, including its warn_only flag.
        toggle_determinism = torch.are_deterministic_algorithms_enabled() and data.device.type == "cuda"
        prev_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
        if toggle_determinism:
            torch.use_deterministic_algorithms(False)

        try:
            weighted_data = high_precision_weights * high_precision_data

            # Running totals of weights and weighted values along time.
            cum_weights = torch.cumsum(high_precision_weights, dim=dim)
            cum_values = torch.cumsum(weighted_data, dim=dim)

            # Avoid division by zero at the first step or before any valid value.
            denominator = cum_weights.clamp_min(1.0)
            causal_means = cum_values / denominator

            # Welford: M2_t = M2_{t-1} + w_t * (x_t - mean_{t-1}) * (x_t - mean_t).
            # mean_{-1} is taken as zero. This avoids the unstable E[X^2] - E[X]^2 form.
            shifted_means = torch.zeros_like(causal_means)
            shifted_means[..., 1:] = causal_means[..., :-1]
            delta = high_precision_data - shifted_means
            increment = delta * (high_precision_data - causal_means) * high_precision_weights
            m_2 = torch.cumsum(increment, dim=dim)

            if use_bessel_correction:
                causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0)
            else:
                causal_variance = m_2 / denominator

            causal_scale = torch.sqrt(causal_variance + minimum_scale)

            if stabilize_with_global:
                causal_scale = _clamp_to_global_scale(
                    causal_scale,
                    high_precision_data,
                    high_precision_weights,
                    dim,
                    minimum_scale,
                    scale_factor_exponent,
                    prefix_length,
                )

            # Convert back to the caller's dtype only after all numerical work.
            causal_means = causal_means.to(data.dtype)
            causal_scale = causal_scale.to(data.dtype)
        finally:
            if toggle_determinism:
                torch.use_deterministic_algorithms(True, warn_only=prev_warn_only)

    return causal_means, causal_scale


def _clamp_to_global_scale(
    causal_scale: torch.Tensor,
    data: torch.Tensor,
    weights: torch.Tensor,
    dim: int,
    minimum_scale: float,
    scale_factor_exponent: float,
    prefix_length: Optional[int],
) -> torch.Tensor:
    """Clamp ``causal_scale`` into [10**-e, 10**e] x the weighted global scale, with a floor of ``minimum_scale``.

    ``weights`` must already include the padding mask. If ``prefix_length`` is given,
    only the first ``prefix_length`` steps contribute to the global statistics.
    """
    if prefix_length is not None:
        prefix_mask = torch.zeros_like(weights)
        prefix_mask[..., :prefix_length] = 1.0
        weights = weights * prefix_mask

    # Weighted mean/variance in the same (high) precision as the causal statistics.
    global_denominator = weights.sum(dim, keepdim=True).clamp_min(1.0)
    global_means = torch.nan_to_num((weights * data).sum(dim, keepdim=True) / global_denominator)
    # Multiplying by sqrt(w) before squaring gives w * (x - mean)^2, and zeroes
    # masked entries before squaring (so large masked values cannot overflow).
    global_variance = (((data - global_means) * weights.sqrt()) ** 2).sum(dim, keepdim=True) / global_denominator
    global_scale = torch.sqrt(global_variance + minimum_scale).expand_as(causal_scale)

    min_allowed_scale = (global_scale * 10.0 ** (-scale_factor_exponent)).clamp_min(minimum_scale)
    max_allowed_scale = global_scale * 10.0**scale_factor_exponent
    return torch.clamp(causal_scale, min=min_allowed_scale, max=max_allowed_scale)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class CausalPatchStdMeanScaler(Scaler):
    """
    Causally scales data in patches, where each patch uses statistics computed
    from all data up to and including that patch. Within each patch, all timesteps
    use the same scaling values: the running statistics at the patch's last timestep.

    This approach provides more stability than per-timestep causal scaling, and
    the statistics of a patch never use data from later patches.

    Can optionally stabilize the causal statistics by clamping them to a range around
    global statistics. This relaxes strict causality. The global statistics are
    computed over the whole series, or over its first ``prefix_length`` steps when
    ``prefix_length`` is passed to ``__call__``. A controlled amount of future
    information can therefore reach earlier patches.

    The statistics are computed using Welford's algorithm, which is numerically
    more stable than computing the variance directly, especially with large values
    or many data points.

    Note: This scaler only works with the following constraints:
    - The input must have shape [batch, variates, time_steps]
    - It only operates on the last dimension (-1)
    - The time_steps must be divisible by patch_size

    Parameters
    ----------
    dim
        dimension along which to compute the causal scale. Must be -1 (the last dimension).
    patch_size
        number of timesteps in each patch
    minimum_scale
        added to the variance before taking the square root, so a constant series
        gets a scale of ``sqrt(minimum_scale)``. It is also used as a lower bound on
        the scale when ``stabilize_with_global`` is enabled.
    use_bessel_correction
        whether to use Bessel's correction to get an unbiased estimator
    stabilize_with_global
        whether to use global statistics to stabilize extreme causal statistics
    scale_factor_exponent
        exponent that controls the allowed range of deviation from the global scale.
        For example, with exponent=1.0, the causal scale must be between 0.1x and 10x
        the global scale. With exponent=2.0, the range would be 0.01x to 100x.
    """

    @validated()
    def __init__(
        self,
        dim: int = -1,
        patch_size: int = 32,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert dim == -1, "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.patch_size = patch_size
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(  # type: ignore[override]
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalize ``data`` patch-by-patch with causal statistics.

        Parameters
        ----------
        data
            Tensor of shape [batch, variates, time_steps]; time_steps must be a multiple
            of ``patch_size``.
        padding_mask
            Mask of padded positions (presumably the same shape as ``data``; it is
            multiplied into ``weights``).
        weights
            Per-position weights, same shape as ``data``.
        prefix_length
            Forwarded to the global stabilization; see ``compute_causal_statistics``.

        Returns
        -------
        ``(scaled_data, patch_means, patch_scales)``, where
        ``scaled_data = (data - patch_means) / patch_scales``. The means and scales are
        repeated across each patch so they cover the full time axis.
        """
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert len(data.shape) == 3, "Input data must have shape [batch, variates, time_steps]"

        with torch.no_grad():
            time_steps = data.shape[-1]
            assert time_steps % self.patch_size == 0, (
                f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})"
            )

            # Per-timestep causal statistics, optionally stabilized with global ones.
            causal_means, causal_scale = compute_causal_statistics(
                data,
                weights,
                padding_mask,
                -1,
                self.minimum_scale,
                self.use_bessel_correction,
                self.stabilize_with_global,
                self.scale_factor_exponent,
                prefix_length,
            )

            # Statistic at the last timestep of each patch: indices patch_size-1, 2*patch_size-1, ...
            last_in_patch = slice(self.patch_size - 1, None, self.patch_size)
            patch_stats_means = causal_means[..., last_in_patch]
            patch_stats_scales = causal_scale[..., last_in_patch]

            # Tile each patch statistic across its patch_size timesteps.
            patch_means = repeat(patch_stats_means, "b v p -> b v (p s)", s=self.patch_size)
            patch_scales = repeat(patch_stats_scales, "b v p -> b v (p s)", s=self.patch_size)

            scaled_data = (data - patch_means) / patch_scales

        return scaled_data, patch_means, patch_scales
|