autogluon.timeseries 1.4.1b20250924__tar.gz → 1.4.1b20251001__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.timeseries might be problematic. Click here for more details.
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/PKG-INFO +2 -1
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/setup.py +5 -1
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/__init__.py +2 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/__init__.py +3 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/__init__.py +9 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/__init__.py +3 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/attention.py +197 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/backbone.py +262 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/distribution.py +70 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/kvcache.py +136 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/rope.py +94 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/scaler.py +306 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/backbone/transformer.py +333 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/dataset.py +165 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/_internal/forecaster.py +423 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/dataloader.py +108 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/hf_pretrained_model.py +119 -0
- autogluon.timeseries-1.4.1b20251001/src/autogluon/timeseries/models/toto/model.py +234 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/version.py +1 -1
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon.timeseries.egg-info/PKG-INFO +2 -1
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon.timeseries.egg-info/SOURCES.txt +15 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon.timeseries.egg-info/requires.txt +10 -4
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/setup.cfg +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/configs/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/configs/hyperparameter_presets.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/configs/predictor_presets.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/dataset/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/dataset/ts_dataframe.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/evaluator.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/learner.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/metrics/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/metrics/abstract.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/metrics/point.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/metrics/quantile.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/metrics/utils.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/abstract/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/abstract/abstract_timeseries_model.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/abstract/model_trial.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/abstract/tunable.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/autogluon_tabular/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/autogluon_tabular/mlforecast.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/autogluon_tabular/per_step.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/autogluon_tabular/transforms.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/autogluon_tabular/utils.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/chronos/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/chronos/model.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/chronos/pipeline/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/chronos/pipeline/base.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/chronos/pipeline/chronos.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/chronos/pipeline/chronos_bolt.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/chronos/pipeline/utils.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/ensemble/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/ensemble/abstract.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/ensemble/basic.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/ensemble/greedy.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/gluonts/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/gluonts/abstract.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/gluonts/dataset.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/gluonts/models.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/local/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/local/abstract_local_model.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/local/naive.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/local/npts.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/local/statsforecast.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/multi_window/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/multi_window/multi_window_model.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/models/registry.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/predictor.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/regressor.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/splitter.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/trainer/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/trainer/model_set_builder.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/trainer/prediction_cache.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/trainer/trainer.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/transforms/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/transforms/covariate_scaler.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/transforms/target_scaler.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/datetime/__init__.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/datetime/base.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/datetime/lags.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/datetime/seasonality.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/datetime/time_features.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/features.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/forecast.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon/timeseries/utils/warning_filters.py +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon.timeseries.egg-info/dependency_links.txt +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon.timeseries.egg-info/namespace_packages.txt +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon.timeseries.egg-info/top_level.txt +0 -0
- {autogluon.timeseries-1.4.1b20250924 → autogluon.timeseries-1.4.1b20251001}/src/autogluon.timeseries.egg-info/zip-safe +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: autogluon.timeseries
|
|
3
|
-
Version: 1.4.
|
|
3
|
+
Version: 1.4.1b20251001
|
|
4
4
|
Summary: Fast and Accurate ML in 3 Lines of Code
|
|
5
5
|
Home-page: https://github.com/autogluon/autogluon
|
|
6
6
|
Author: AutoGluon Community
|
|
@@ -35,6 +35,7 @@ Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
|
35
35
|
Requires-Python: >=3.9, <3.13
|
|
36
36
|
Description-Content-Type: text/markdown
|
|
37
37
|
Provides-Extra: tests
|
|
38
|
+
Provides-Extra: toto
|
|
38
39
|
Provides-Extra: all
|
|
39
40
|
License-File: ../LICENSE
|
|
40
41
|
License-File: ../NOTICE
|
|
@@ -54,10 +54,14 @@ extras_require = {
|
|
|
54
54
|
"flaky>=3.7,<4",
|
|
55
55
|
"pytest-timeout>=2.1,<3",
|
|
56
56
|
],
|
|
57
|
+
"toto": [
|
|
58
|
+
"einops>=0.7,<1",
|
|
59
|
+
"rotary-embedding-torch>=0.8,<1",
|
|
60
|
+
],
|
|
57
61
|
}
|
|
58
62
|
|
|
59
63
|
# chronos-openvino and chronos-onnx are deprecated, and will be removed in a future version
|
|
60
|
-
extras_require["all"] = []
|
|
64
|
+
extras_require["all"] = list(set.union(*(set(extras_require[extra]) for extra in ["toto"])))
|
|
61
65
|
install_requires = ag.get_dependency_version_ranges(install_requires)
|
|
62
66
|
|
|
63
67
|
if __name__ == "__main__":
|
|
@@ -28,6 +28,7 @@ from .local import (
|
|
|
28
28
|
ZeroModel,
|
|
29
29
|
)
|
|
30
30
|
from .registry import ModelRegistry
|
|
31
|
+
from .toto import TotoModel
|
|
31
32
|
|
|
32
33
|
__all__ = [
|
|
33
34
|
"ADIDAModel",
|
|
@@ -56,6 +57,7 @@ __all__ = [
|
|
|
56
57
|
"TemporalFusionTransformerModel",
|
|
57
58
|
"ThetaModel",
|
|
58
59
|
"TiDEModel",
|
|
60
|
+
"TotoModel",
|
|
59
61
|
"WaveNetModel",
|
|
60
62
|
"ZeroModel",
|
|
61
63
|
]
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Optional, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from einops import rearrange
|
|
12
|
+
from torch.nn.functional import scaled_dot_product_attention
|
|
13
|
+
|
|
14
|
+
from .rope import TimeAwareRotaryEmbedding
|
|
15
|
+
|
|
16
|
+
log = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AttentionAxis(Enum):
    """Axis of a ``(batch, variate, seq_len, embed_dim)`` tensor along which attention runs."""

    # Attend across sequence positions within each variate.
    TIME = 1
    # Attend across variates at each sequence position.
    SPACE = 2
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class BaseMultiheadAttention(torch.nn.Module):
    """
    Multi-head self-attention over one axis of a ``(batch, variate, seq_len, embed_dim)`` tensor.

    Concrete subclasses choose the axis through the class attribute ``attention_axis``:

    - ``AttentionAxis.TIME``: variates are folded into the batch dimension and attention runs over
      ``seq_len``. Rotary embeddings and the KV cache, when supplied, are only used on this axis.
    - ``AttentionAxis.SPACE``: ``seq_len`` is folded into the batch dimension and attention runs over
      ``variate`` with ``is_causal=False``.

    Attention is computed with ``torch.nn.functional.scaled_dot_product_attention``.

    Parameters
    ----------
    embed_dim
        Model dimension. Must be divisible by ``num_heads``.
    num_heads
        Number of attention heads.
    dropout
        Attention dropout probability; only applied while the module is in training mode.
    rotary_emb
        Optional rotary position embedding applied to queries and keys for time-wise attention.
    use_memory_efficient_attention
        Must be ``False``: the xFormers path is not available. The layout branches for ``True`` are
        retained but cannot be reached through the constructor.

    Raises
    ------
    ValueError
        If ``embed_dim`` is not divisible by ``num_heads``, if ``use_memory_efficient_attention`` is
        truthy, or if the subclass does not define a valid ``attention_axis``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float,
        rotary_emb: Optional[TimeAwareRotaryEmbedding],
        use_memory_efficient_attention: bool,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Explicit checks (not ``assert``) so configuration errors are still caught under ``python -O``.
        if embed_dim % num_heads != 0:
            raise ValueError("Embedding dimension must be divisible by number of heads.")
        self.head_dim = embed_dim // num_heads
        self.rotary_emb = rotary_emb

        # We allocate a single tensor for the q, k, and v projection matrices,
        # multiply them with the inputs, and then split the projected tensors into q, k, and v using unbind.
        # This reduces overhead a bit vs. having multiple separate Linear layers,
        # which need to be initialized, tracked by the optimizer, etc.
        # Submodule names (wQKV, wO) determine state_dict keys, so they must not be renamed.
        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
        self.dropout = dropout
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.wO = torch.nn.Linear(embed_dim, embed_dim)

        if self.use_memory_efficient_attention:
            # Without this guard the memory-efficient layouts below would be fed to torch SDPA,
            # which would then attend over the wrong dimension.
            raise ValueError("xformers is not available, so use_memory_efficient_attention must be False")

        if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE):
            raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.")

    def rearrange_inputs(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Fold the non-attended axis into the batch dimension.

        ``(batch, variate, seq_len, embed_dim)`` becomes ``(batch * variate, seq_len, embed_dim)`` for
        time-wise attention, or ``(batch * seq_len, variate, embed_dim)`` for space-wise attention.
        """
        pattern = (
            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
            if self.attention_axis == AttentionAxis.TIME
            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
        )

        return rearrange(inputs, pattern)

    def get_qkv(
        self,
        inputs: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        """
        Project folded inputs with ``wQKV`` and split them into per-head ``(q, k, v)``.

        With ``use_memory_efficient_attention=False`` each tensor has layout
        ``(folded_batch, n_heads, attended_len, head_dim)``, as expected by SDPA.

        Raises
        ------
        ValueError
            If ``attention_axis`` is not a valid ``AttentionAxis``.
        """
        # NOTE: the channel order "(qkv head_dim n_heads)" presumably has to match the trained wQKV
        # weights, so these patterns must not be reordered.
        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"
        else:
            raise ValueError("Invalid attention axis")

        qkv = self.wQKV(inputs.contiguous())
        return rearrange(qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads).unbind(dim=0)

    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
        """
        Apply rotary embeddings and KV caching (time axis only) and normalise contiguity/dtype.

        Parameters
        ----------
        q, k, v
            Per-head tensors from ``get_qkv``.
        kv_cache
            Optional cache exposing ``seq_len(layer_idx)``, ``append(layer_idx, (k, v))`` and
            ``__getitem__(layer_idx)`` returning the full cached ``(k, v)``.
        layer_idx
            Index of this layer in the cache.

        Returns
        -------
        tuple
            ``(q, k, v, seq_pos_offset)``. For time-wise attention with a cache, ``seq_pos_offset`` is
            the value of ``kv_cache.seq_len(layer_idx)`` read before the current keys/values are
            appended; otherwise it is 0. ``k`` and ``v`` are cast to ``q.dtype``.
        """
        is_time_axis = self.attention_axis == AttentionAxis.TIME

        # The offset is the start index of the current queries within the full (cached) sequence.
        # It is needed both for rotary positions and for mask slicing / causality in run_attention,
        # so it must be read from the cache even when no rotary embedding is configured.
        seq_pos_offset = 0
        if kv_cache is not None and is_time_axis:
            seq_pos_offset = kv_cache.seq_len(layer_idx)

        if self.rotary_emb is not None and is_time_axis:
            # With use_memory_efficient_attention=False (the only supported setting) the layout is
            # (..., n_heads, seq_len, head_dim): the sequence dimension is already second-to-last,
            # which is what the rotary embedding expects, so no permutation is needed.
            q, k = self.rotary_emb.rotate_queries_and_keys(q, k, seq_pos_offset=seq_pos_offset)

        if kv_cache is not None and is_time_axis:
            # First, we append the current input key and value tensors to the cache.
            # This concatenates the current key and value tensors to the existing key and value tensors
            kv_cache.append(layer_idx, (k, v))
            # Then, we retrieve the key and value tensors from the cache.
            # This includes all the key and value tensors from previous time steps
            # as well as the current time step.
            k, v = kv_cache[layer_idx]

        q = q.contiguous()
        k = k.contiguous().to(q.dtype)  # Ensure k is the same dtype as q; this is necessary when using mixed precision
        v = v.contiguous().to(q.dtype)  # Ensure v is the same dtype as q; this is necessary when using mixed precision

        return q, k, v, seq_pos_offset

    def rearrange_output(self, output: torch.Tensor, batch: int, variate: int, seq_len: int) -> torch.Tensor:
        """
        Unfold the attention output back to ``(batch, variate, seq_len, n_heads * head_dim)``.

        Raises
        ------
        ValueError
            If ``attention_axis`` is not a valid ``AttentionAxis`` (previously an ``UnboundLocalError``).
        """
        if self.attention_axis == AttentionAxis.TIME and self.use_memory_efficient_attention:
            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif self.attention_axis == AttentionAxis.TIME and not self.use_memory_efficient_attention:
            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
        elif self.attention_axis == AttentionAxis.SPACE and self.use_memory_efficient_attention:
            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif self.attention_axis == AttentionAxis.SPACE and not self.use_memory_efficient_attention:
            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"
        else:
            raise ValueError("Invalid attention axis")

        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)  # type: ignore

    def run_attention(self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate):
        """
        Run scaled dot-product attention along the configured axis.

        For time-wise attention the mask (if a tensor) is sliced to rows
        ``[seq_pos_offset, seq_pos_offset + seq_len)`` and columns ``[0, kv_len)``. ``is_causal`` is
        only set when there is no mask and no cached prefix (``seq_pos_offset == 0``). For space-wise
        attention the mask is sliced to ``[0, kv_len)`` on both of its last two dimensions and
        ``is_causal`` is always False. ``variate`` is accepted for interface compatibility but unused.
        """
        # Determine dimension ranges for attention
        # Ensure the last query vector index is used from the cache
        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
        kv_dim_start, kv_dim_end = 0, (v.shape[1] if self.use_memory_efficient_attention else v.shape[2])
        if self.attention_axis == AttentionAxis.TIME:
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout,
                is_causal=(attention_mask is None and seq_pos_offset == 0),
            )
        elif self.attention_axis == AttentionAxis.SPACE:
            # We don't use causal masking for space-wise attention
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False)
        else:
            raise ValueError("Invalid attention axis")

    def forward(
        self,
        layer_idx: int,
        inputs: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache=None,
    ) -> torch.Tensor:
        """
        Apply multi-head self-attention along ``attention_axis``.

        Parameters
        ----------
        layer_idx
            Index of this layer, used to address ``kv_cache``.
        inputs
            Tensor of shape ``(batch, variate, seq_len, embed_dim)``.
        attention_mask
            Optional mask passed (after slicing) as ``attn_mask`` to SPDA; assumed broadcastable to the
            attention score shape — TODO confirm against callers.
        kv_cache
            Optional KV cache; only used for time-wise attention.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(batch, variate, seq_len, embed_dim)`` after the ``wO`` projection.
        """
        batch_size, variate, seq_len, _ = inputs.shape
        dropout = self.dropout if self.training else 0.0

        rearranged_inputs = self.rearrange_inputs(inputs)
        q, k, v = self.get_qkv(rearranged_inputs)

        q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx)

        output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate)

        output = self.rearrange_output(output, batch_size, variate, seq_len)
        return self.wO(output)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class TimeWiseMultiheadAttention(BaseMultiheadAttention):
    """
    Multi-head attention along the time (sequence) axis.

    Variates are folded into the batch dimension, so each variate attends only over its own
    sequence positions. Masking follows ``BaseMultiheadAttention.run_attention``: causal when no
    mask is given and nothing is cached, otherwise the caller-supplied mask. When a rotary
    embedding is configured, queries and keys are rotated to inject relative positional information.
    """

    # Selects the TIME-axis tensor layouts, rotary embedding, and KV-cache handling in the base class.
    attention_axis = AttentionAxis.TIME
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
    """
    Bidirectional multi-head attention along the space (variate) axis.

    The time axis is folded into the batch dimension, so at every sequence position each variate of
    a multivariate series can attend to the other variates. Alternating these layers with time-wise
    attention lets the model learn both temporal and cross-variate dependencies.

    Unlike time-wise attention, no rotary embeddings are applied here, so cross-variate attention
    does not depend on the order in which the variates are listed.
    """

    # Selects the SPACE-axis tensor layouts (non-causal, no rotary embedding, no KV cache) in the base class.
    attention_axis = AttentionAxis.SPACE
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Type alias covering either concrete attention layer (time-wise or space-wise).
MultiHeadAttention = Union[TimeWiseMultiheadAttention, SpaceWiseMultiheadAttention]
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
import math
|
|
7
|
+
from typing import NamedTuple, Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from .distribution import MixtureOfStudentTsOutput
|
|
12
|
+
from .kvcache import KVCache
|
|
13
|
+
from .scaler import CausalPatchStdMeanScaler
|
|
14
|
+
from .transformer import Transformer
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TotoOutput(NamedTuple):
    """
    Bundle returned by the Toto model.

    Fields
    ------
    distribution
        Predictive output distribution.
    loc
        Location parameters.
    scale
        Scale parameters.
    """

    distribution: torch.distributions.Distribution
    loc: torch.Tensor
    scale: torch.Tensor
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def patchify_id_mask(id_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    """
    Collapse a per-timestep ID mask into one ID per non-overlapping patch.

    The last dimension of ``id_mask`` is split into windows of ``patch_size`` consecutive entries
    with step ``patch_size``. As with ``Tensor.unfold``, trailing entries that do not fill a whole
    window are dropped. Every window must contain a single ID.

    Parameters
    ----------
    id_mask
        Tensor of IDs whose last dimension is time.
    patch_size
        Number of time steps per patch.

    Returns
    -------
    torch.Tensor
        Tensor with the same leading dimensions and last dimension ``id_mask.shape[-1] // patch_size``,
        holding the ID of each patch.

    Raises
    ------
    ValueError
        If any patch contains more than one distinct ID.
    """
    patched_id_mask = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size)
    patched_id_mask_min = patched_id_mask.min(-1).values
    patched_id_mask_max = patched_id_mask.max(-1).values
    # Explicit check (not ``assert``) so the validation still runs under ``python -O``.
    if not torch.equal(patched_id_mask_min, patched_id_mask_max):
        raise ValueError("Patches cannot span multiple datasets")
    return patched_id_mask_min
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PatchEmbedding(torch.nn.Module):
    """
    Multivariate time series patch embedding.

    Each variate is patchified independently along its last (time) dimension using windows of
    ``patch_size`` steps taken every ``stride`` steps. A single shared linear layer then projects
    every patch to ``embed_dim``.

    Parameters
    ----------
    patch_size
        Number of time steps per patch.
    stride
        Step between the starts of consecutive patches.
    embed_dim
        Output embedding dimension.
    """

    def __init__(self, patch_size: int, stride: int, embed_dim: int):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.stride = stride
        # The submodule name "projection" is part of the state_dict keys; keep it stable.
        self.projection = torch.nn.Linear(self.patch_size, self.embed_dim)

    def _patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension of ``x`` into ``patch_size`` windows taken every ``stride`` steps."""
        return x.unfold(dimension=-1, size=self.patch_size, step=self.stride)

    def forward(
        self,
        x: torch.Tensor,
        id_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Embed the patches of ``x`` and compute one ID per patch from ``id_mask``.

        Parameters
        ----------
        x
            Series values; the last dimension is time and its length must be divisible by
            ``patch_size``.
        id_mask
            Per-timestep IDs with the same time dimension as ``x``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The projected patches, with shape ``(*x.shape[:-1], num_patches, embed_dim)``, and the
            per-patch IDs, with shape ``(*id_mask.shape[:-1], num_patches)``.

        Raises
        ------
        ValueError
            If the series length is not divisible by ``patch_size``, or if a patch spans more than one ID.
        """
        # Explicit checks (not ``assert``) so validation still runs under ``python -O``.
        if x.shape[-1] % self.patch_size != 0:
            raise ValueError(f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})")
        x_patched: torch.Tensor = self._patchify(x)
        id_mask_patched: torch.Tensor = self._patchify(id_mask)

        # Compute the per-patch minimum once; it is both half of the check and the returned ID.
        patch_ids = id_mask_patched.min(-1).values
        if not torch.equal(patch_ids, id_mask_patched.max(-1).values):
            raise ValueError("Patches cannot span multiple datasets")

        return (
            self.projection(x_patched),
            patch_ids,
        )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class TotoBackbone(torch.nn.Module):
|
|
74
|
+
"""
|
|
75
|
+
Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model for multivariate
|
|
76
|
+
time series forecasting. It applies a patch embedding to the input data, followed by a transformer
|
|
77
|
+
that alternates between time-wise and space-wise attention. The transformer is followed by a linear projection
|
|
78
|
+
that maps the transformer output to the output distribution.
|
|
79
|
+
|
|
80
|
+
The output distribution can be a single distribution (e.g. Gaussian) or a mixture of distributions.
|
|
81
|
+
If a mixture of distributions is used, the model will learn to predict the mixture weights
|
|
82
|
+
as well as the parameters of the individual distributions.
|
|
83
|
+
|
|
84
|
+
Parameters
|
|
85
|
+
----------
|
|
86
|
+
patch_size
|
|
87
|
+
Size of the patch to use for the patch embedding.
|
|
88
|
+
stride
|
|
89
|
+
Stride to use for the patch embedding.
|
|
90
|
+
embed_dim
|
|
91
|
+
Dimension of the model's latent space.
|
|
92
|
+
num_layers
|
|
93
|
+
Number of transformer layers to use.
|
|
94
|
+
num_heads
|
|
95
|
+
Number of attention heads to use in each self-attention layer.
|
|
96
|
+
mlp_hidden_dim
|
|
97
|
+
Dimension of the hidden layer in the feedforward network.
|
|
98
|
+
dropout
|
|
99
|
+
Dropout rate to use in the model.
|
|
100
|
+
spacewise_every_n_layers
|
|
101
|
+
How many time-wise transformer layers to apply between each space-wise transformer layer.
|
|
102
|
+
spacewise_first
|
|
103
|
+
Whether to apply space-wise attention before time-wise attention.
|
|
104
|
+
scaler_cls
|
|
105
|
+
Class to use for scaling the input data.
|
|
106
|
+
output_distribution_classes
|
|
107
|
+
List of classes to use for the output distribution. If a single class is provided, the model
|
|
108
|
+
will output a single distribution. If multiple classes are provided, the model will output a
|
|
109
|
+
learned mixture of distributions.
|
|
110
|
+
output_distribution_kwargs
|
|
111
|
+
Keyword arguments to pass to the output distribution class. Note: this currently only works
|
|
112
|
+
with a single output distribution class.
|
|
113
|
+
use_memory_efficient_attention:
|
|
114
|
+
Whether to use memory-efficient attention. If True, the model will use the memory-efficient from xFormers.
|
|
115
|
+
stabilize_with_global:
|
|
116
|
+
Whether to use global statistics to stabilize causal statistics by clamping extreme values. Only applies to causal scalers.
|
|
117
|
+
scale_factor_exponent:
|
|
118
|
+
Exponent that controls the allowed range of deviation from global scale for causal scalers.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
def __init__(
|
|
122
|
+
self,
|
|
123
|
+
patch_size: int,
|
|
124
|
+
stride: int,
|
|
125
|
+
embed_dim: int,
|
|
126
|
+
num_layers: int,
|
|
127
|
+
num_heads: int,
|
|
128
|
+
mlp_hidden_dim: int,
|
|
129
|
+
dropout: float,
|
|
130
|
+
spacewise_every_n_layers: int,
|
|
131
|
+
scaler_cls: str,
|
|
132
|
+
output_distribution_classes: list[str],
|
|
133
|
+
spacewise_first: bool = True,
|
|
134
|
+
output_distribution_kwargs: Optional[dict] = None,
|
|
135
|
+
use_memory_efficient_attention: bool = True,
|
|
136
|
+
stabilize_with_global: bool = True,
|
|
137
|
+
scale_factor_exponent: float = 10.0,
|
|
138
|
+
):
|
|
139
|
+
super().__init__()
|
|
140
|
+
self.embed_dim = embed_dim
|
|
141
|
+
# strings are used when loading a safetensors checkpoint
|
|
142
|
+
# Initialize patch-based scalers with the correct patch_size
|
|
143
|
+
|
|
144
|
+
self.scaler = CausalPatchStdMeanScaler(
|
|
145
|
+
patch_size=patch_size,
|
|
146
|
+
stabilize_with_global=stabilize_with_global,
|
|
147
|
+
scale_factor_exponent=scale_factor_exponent,
|
|
148
|
+
)
|
|
149
|
+
self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim)
|
|
150
|
+
self.dropout = dropout
|
|
151
|
+
self.num_layers = num_layers
|
|
152
|
+
self.use_memory_efficient_attention = use_memory_efficient_attention
|
|
153
|
+
self.transformer = Transformer(
|
|
154
|
+
embed_dim=embed_dim,
|
|
155
|
+
num_heads=num_heads,
|
|
156
|
+
num_layers=self.num_layers,
|
|
157
|
+
mlp_hidden_dim=mlp_hidden_dim,
|
|
158
|
+
dropout=dropout,
|
|
159
|
+
spacewise_every_n_layers=spacewise_every_n_layers,
|
|
160
|
+
spacewise_first=spacewise_first,
|
|
161
|
+
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
|
162
|
+
)
|
|
163
|
+
self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)
|
|
164
|
+
|
|
165
|
+
# TODO[BEN] this doesn't need to be a list
|
|
166
|
+
output_distribution_classes_ = [MixtureOfStudentTsOutput]
|
|
167
|
+
self.output_distribution = output_distribution_classes_[0](embed_dim, **(output_distribution_kwargs or {}))
|
|
168
|
+
|
|
169
|
+
def allocate_kv_cache(
    self,
    batch_size: int,
    num_variates: int,
    max_time_steps: int,
    device: torch.device,
    dtype: torch.dtype,
) -> KVCache:
    """Create an empty ``KVCache`` sized for this backbone's transformer.

    Args:
        batch_size: number of series in the batch.
        num_variates: number of variates per series.
        max_time_steps: maximum number of raw time steps that will be decoded;
            converted to a patch count by ceil-dividing by the patch stride.
        device: device on which the cache tensors are allocated.
        dtype: dtype of the cache tensors.

    Returns:
        A ``KVCache`` wired to this model's transformer layers.
    """
    layers = list(self.transformer.layers)
    # The cache length is counted in patches (``backbone`` multiplies
    # ``current_len`` by the stride to get time steps), so round the
    # time-step budget up to whole patches.
    patch_budget = math.ceil(max_time_steps / self.patch_embed.stride)
    return KVCache(
        batch_size=batch_size,
        num_variates=num_variates,
        transformer_layers=layers,
        num_layers=self.num_layers,
        embed_dim=self.embed_dim,
        # Head count is read from the first layer only (assumes every
        # layer uses the same count — TODO confirm).
        num_heads=layers[0].num_heads,  # type: ignore
        max_seq_len=patch_budget,
        device=device,
        dtype=dtype,
        use_memory_efficient_attention=self.use_memory_efficient_attention,
    )
|
|
189
|
+
|
|
190
|
+
def backbone(
    self,
    inputs: torch.Tensor,
    input_padding_mask: torch.Tensor,
    id_mask: torch.Tensor,
    kv_cache: Optional[KVCache] = None,
    scaling_prefix_length: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Scale, patch-embed and transform ``inputs``; return flattened embeddings.

    Args:
        inputs: raw series; axis 2 is sliced as the time axis, so presumably
            shaped (batch, variates, time) — TODO confirm against callers.
        input_padding_mask: padding mask passed to the scaler only.
        id_mask: mask sliced alongside ``inputs`` and reduced by the patch
            embedding before being handed to the transformer.
        kv_cache: optional cache. When it already holds patches, only the
            new, not-yet-cached time steps are embedded and transformed, and
            exactly one stride of new data is required.
        scaling_prefix_length: forwarded to the scaler as ``prefix_length``.

    Returns:
        ``(flattened, loc, scale)``. ``flattened`` has shape
        ``(batch, variates, seq_len * patch_size, embed_dim)``. ``loc`` and
        ``scale`` come from the scaler and are always computed over the full,
        untruncated input.

    Raises:
        AssertionError: if a non-empty ``kv_cache`` is given and more (or less)
            than one stride of new time steps remains after the cached prefix.
    """
    # All-ones weights: every observation contributes equally to the scaling
    # statistics; padding is handled via ``padding_mask`` instead.
    scaled_inputs, loc, scale = self.scaler(
        inputs,
        weights=torch.ones_like(inputs),
        padding_mask=input_padding_mask,
        prefix_length=scaling_prefix_length,
    )

    if kv_cache is not None:
        # Number of time steps already represented in the cache.
        prefix_len = self.patch_embed.stride * kv_cache.current_len(0)

        # Only feed the transformer what the cache has not seen yet; earlier
        # patches are served from the KV cache.
        scaled_inputs = scaled_inputs[:, :, prefix_len:]

        # As a simplification, after the initial forward pass we only allow
        # decoding one step (one stride) at a time.
        assert (prefix_len == 0) or (scaled_inputs.shape[-1] == self.patch_embed.stride), (
            "Must decode one step at a time."
        )

        # ``input_padding_mask`` was already consumed by the scaler above and
        # is not used past this point, so only ``id_mask`` needs truncating.
        id_mask = id_mask[:, :, prefix_len:]

    embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask)

    transformed: torch.Tensor = self.transformer(embeddings, reduced_id_mask, kv_cache)

    # ``unembed`` maps each patch embedding to ``patch_size`` embeddings
    # (its output width is embed_dim * patch_size); unfold them back into
    # one embedding per time step.
    unembedded = self.unembed(transformed)
    batch_size, num_variates, seq_len = unembedded.shape[:3]
    patch_size = unembedded.shape[-1] // self.embed_dim
    flattened = unembedded.view(batch_size, num_variates, seq_len * patch_size, self.embed_dim)
    return flattened, loc, scale
|
|
241
|
+
|
|
242
|
+
def forward(
    self,
    inputs: torch.Tensor,
    input_padding_mask: torch.Tensor,
    id_mask: torch.Tensor,
    kv_cache: Optional[KVCache] = None,
    scaling_prefix_length: Optional[int] = None,
) -> TotoOutput:
    """Run the backbone and wrap its output distribution in a ``TotoOutput``.

    All arguments are forwarded positionally to ``self.backbone``. The output
    distribution head is applied to the flattened embeddings without
    ``loc``/``scale``; those two tensors are returned alongside it inside the
    ``TotoOutput`` instead.
    """
    backbone_result = self.backbone(
        inputs, input_padding_mask, id_mask, kv_cache, scaling_prefix_length
    )
    flattened, loc, scale = backbone_result
    distribution = self.output_distribution(flattened)
    return TotoOutput(distribution, loc, scale)
|
|
259
|
+
|
|
260
|
+
@property
def device(self):
    """Device of the module's first parameter.

    Only the first parameter is inspected. A module with no parameters
    raises ``StopIteration``.
    """
    params = self.parameters()
    first_param = next(params)
    return first_param.device
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License.
|
|
2
|
+
#
|
|
3
|
+
# This product includes software developed at Datadog (https://www.datadoghq.com/)
|
|
4
|
+
# Copyright 2025 Datadog, Inc.
|
|
5
|
+
|
|
6
|
+
from abc import ABC
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from gluonts.torch.distributions import AffineTransformed
|
|
11
|
+
from gluonts.torch.distributions.studentT import StudentT
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DistributionOutput(ABC, torch.nn.Module):
    """Common base for modules that turn embeddings into a distribution.

    Subclasses in this file implement ``forward(inputs, loc=None, scale=None)``
    and return a ``torch.distributions`` object. No abstract methods are
    declared here, so the base class itself is instantiable.
    """
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class StudentTOutput(DistributionOutput):
    """Project embeddings onto a single Student-t distribution per position.

    Three linear heads map each ``embed_dim`` vector to the degrees of
    freedom, location and scale. If both ``loc`` and ``scale`` are supplied to
    ``forward``, the result is wrapped in an ``AffineTransformed``.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # One scalar per head; the trailing singleton axis is squeezed in forward.
        self.df = torch.nn.Linear(embed_dim, 1)
        self.loc_proj = torch.nn.Linear(embed_dim, 1)
        self.scale_proj = torch.nn.Linear(embed_dim, 1)

    def forward(self, inputs, loc=None, scale=None):
        """Build the Student-t distribution for ``inputs``.

        Args:
            inputs: tensor whose last dimension is ``embed_dim``.
            loc, scale: optional affine parameters. They are applied only when
                BOTH are given; otherwise the untransformed distribution is
                returned.
        """
        tiny = torch.finfo(inputs.dtype).eps

        def _positive(raw):
            # softplus keeps the value positive; the clamp prevents it from
            # underflowing to exactly zero.
            return F.softplus(raw).clamp_min(tiny).squeeze(-1)

        base_dist = torch.distributions.StudentT(  # type: ignore
            2.0 + _positive(self.df(inputs)),  # df kept strictly above 2
            self.loc_proj(inputs).squeeze(-1),
            _positive(self.scale_proj(inputs)),
            validate_args=False,
        )

        if loc is None or scale is None:
            return base_dist
        return AffineTransformed(base_dist, loc=loc, scale=scale)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MixtureOfStudentTsOutput(DistributionOutput):
    """Map embeddings to a mixture of ``k_components`` Student-t distributions.

    Four linear heads project each ``embed_dim`` vector to per-component
    degrees of freedom, locations, scales and mixture logits. The components
    sit on the last axis, which ``MixtureSameFamily`` treats as the mixture
    dimension.
    """

    def __init__(
        self,
        embed_dim,
        k_components,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_components = k_components

        self.df = torch.nn.Linear(embed_dim, k_components)
        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)

    def forward(self, inputs, loc=None, scale=None):
        """Build the mixture distribution for ``inputs``.

        Args:
            inputs: tensor whose last dimension is ``embed_dim``.
            loc, scale: accepted but NOT used. The returned distribution is
                always in the space of the raw projections, so any affine
                rescaling must be applied by the caller. (Presumably kept so
                the signature matches ``StudentTOutput.forward`` — confirm
                before relying on it.)

        Returns:
            ``torch.distributions.MixtureSameFamily`` over the component axis.
        """
        # NOTE: the previous implementation reassigned ``loc``/``scale`` to the
        # component projections, silently discarding the caller's values while
        # looking as if they were used. Distinct names make the contract explicit.
        eps = torch.finfo(inputs.dtype).eps
        # df kept strictly above 2 (where the Student-t variance is finite);
        # the clamp stops softplus from underflowing to zero.
        component_df = 2.0 + F.softplus(self.df(inputs)).clamp_min(eps)
        component_loc = self.loc_proj(inputs)
        component_scale = F.softplus(self.scale_proj(inputs)).clamp_min(eps)

        probs = F.softmax(self.mixture_weights(inputs), dim=-1)

        components = StudentT(component_df, component_loc, component_scale)
        mixture_distribution = torch.distributions.Categorical(probs=probs)

        return torch.distributions.MixtureSameFamily(
            mixture_distribution,
            components,
        )
|