tirex-mirror 2025.8.28.tar.gz → 2025.9.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tirex_mirror-2025.8.28/src/tirex_mirror.egg-info → tirex_mirror-2025.9.9}/PKG-INFO +8 -5
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/pyproject.toml +5 -4
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/api_adapter/forecast.py +0 -1
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/base.py +35 -8
- tirex_mirror-2025.9.9/src/tirex/models/patcher.py +84 -0
- tirex_mirror-2025.9.9/src/tirex/models/slstm/block.py +60 -0
- tirex_mirror-2025.9.9/src/tirex/models/slstm/cell.py +188 -0
- tirex_mirror-2025.9.9/src/tirex/models/slstm/layer.py +67 -0
- tirex_mirror-2025.9.9/src/tirex/models/tirex.py +225 -0
- tirex_mirror-2025.9.9/src/tirex/util.py +13 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9/src/tirex_mirror.egg-info}/PKG-INFO +8 -5
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex_mirror.egg-info/SOURCES.txt +5 -3
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex_mirror.egg-info/requires.txt +8 -4
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/tests/test_chronos_zs.py +9 -5
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/tests/test_forecast.py +10 -8
- tirex_mirror-2025.8.28/src/tirex/models/components.py +0 -147
- tirex_mirror-2025.8.28/src/tirex/models/mixed_stack.py +0 -143
- tirex_mirror-2025.8.28/src/tirex/models/predict_utils.py +0 -72
- tirex_mirror-2025.8.28/src/tirex/models/tirex.py +0 -231
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/LICENSE +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/MANIFEST.in +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/NOTICE.txt +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/README.md +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/setup.cfg +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/__init__.py +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/api_adapter/__init__.py +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/api_adapter/gluon.py +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/api_adapter/hf_data.py +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/api_adapter/standard_adapter.py +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex/models/__init__.py +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/src/tirex_mirror.egg-info/top_level.txt +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/tests/test_forecast_adapter.py +0 -0
- {tirex_mirror-2025.8.28 → tirex_mirror-2025.9.9}/tests/test_standard_adapter.py +0 -0
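
In short, this release vendors the sLSTM implementation: the old `components.py`, `mixed_stack.py`, and `predict_utils.py` modules are removed; a rewritten `tirex.py`, a new `models/slstm/` package (`block.py`, `cell.py`, `layer.py`), `models/patcher.py`, and `util.py` take their place. The hard dependencies on `xlstm`, `ninja`, `lightning`, and `dacite` are dropped: `xlstm` and `ninja` move into a new `cuda` extra (needed only for the custom CUDA kernels), and `matplotlib` is added to the `notebooks` extra. The reconstructed diffs below cover the hunks shown by the registry viewer.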
```diff
--- tirex_mirror-2025.8.28/src/tirex_mirror.egg-info/PKG-INFO
+++ tirex_mirror-2025.9.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.8.28
+Version: 2025.9.9
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -65,17 +65,17 @@ License-File: LICENSE_MIRROR.txt
 License-File: NOTICE.txt
 Requires-Dist: torch
 Requires-Dist: torchvision
-Requires-Dist: xlstm
 Requires-Dist: einops
-Requires-Dist: ninja
 Requires-Dist: huggingface-hub
-Requires-Dist: lightning
 Requires-Dist: numpy
 Requires-Dist: pandas
-Requires-Dist: dacite
 Requires-Dist: tqdm
+Provides-Extra: cuda
+Requires-Dist: xlstm; extra == "cuda"
+Requires-Dist: ninja; extra == "cuda"
 Provides-Extra: notebooks
 Requires-Dist: ipykernel; extra == "notebooks"
+Requires-Dist: matplotlib; extra == "notebooks"
 Provides-Extra: gluonts
 Requires-Dist: gluonts; extra == "gluonts"
 Provides-Extra: hfdataset
@@ -83,7 +83,10 @@ Requires-Dist: datasets; extra == "hfdataset"
 Provides-Extra: test
 Requires-Dist: fev; extra == "test"
 Provides-Extra: all
+Requires-Dist: xlstm; extra == "all"
+Requires-Dist: ninja; extra == "all"
 Requires-Dist: ipykernel; extra == "all"
+Requires-Dist: matplotlib; extra == "all"
 Requires-Dist: gluonts; extra == "all"
 Requires-Dist: datasets; extra == "all"
 Requires-Dist: fev; extra == "all"
```

```diff
--- tirex_mirror-2025.8.28/pyproject.toml
+++ tirex_mirror-2025.9.9/pyproject.toml
@@ -1,12 +1,12 @@
 [project]
 name = "tirex-mirror"
-version = "2025.8.28"
+version = "2025.09.09"
 description = "Unofficial mirror of NX-AI/tirex for packaging"
 readme = "README.md"
 requires-python = ">=3.11"
 classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent",]
 keywords = [ "TiRex", "xLSTM", "Time Series", "Zero-shot", "Deep Learning",]
-dependencies = [ "torch", "torchvision", "xlstm", "einops", "ninja", "huggingface-hub", "lightning", "numpy", "pandas", "dacite", "tqdm",]
+dependencies = [ "torch", "torchvision", "einops", "huggingface-hub", "numpy", "pandas", "tqdm",]
 [[project.authors]]
 name = "Arpad Rozsas"
 email = "rozsasarpi@gmail.com"
@@ -23,11 +23,12 @@ Repository = "https://github.com/rozsasarpi/tirex-mirror"
 Issues = "https://github.com/rozsasarpi/tirex-mirror/issues"
 
 [project.optional-dependencies]
-notebooks = [ "ipykernel",]
+cuda = [ "xlstm", "ninja",]
+notebooks = [ "ipykernel", "matplotlib",]
 gluonts = [ "gluonts",]
 hfdataset = [ "datasets",]
 test = [ "fev",]
-all = [ "ipykernel", "gluonts", "datasets", "fev",]
+all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "gluonts", "datasets", "fev",]
 
 [tool.docformatter]
 diff = false
```

```diff
--- tirex_mirror-2025.8.28/src/tirex/base.py
+++ tirex_mirror-2025.9.9/src/tirex/base.py
@@ -3,13 +3,18 @@
 
 import os
 from abc import ABC, abstractmethod
-from typing import TypeVar
+from typing import Literal, TypeVar
 
+import torch
 from huggingface_hub import hf_hub_download
 
 T = TypeVar("T", bound="PretrainedModel")
 
 
+def skip_cuda():
+    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
+
+
 def parse_hf_repo_id(path):
     parts = path.split("/")
     return "/".join(parts[0:2])
@@ -23,19 +28,30 @@ class PretrainedModel(ABC):
         cls.REGISTRY[cls.register_name()] = cls
 
     @classmethod
-    def from_pretrained(
+    def from_pretrained(
+        cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+    ) -> T:
         if hf_kwargs is None:
             hf_kwargs = {}
         if ckp_kwargs is None:
             ckp_kwargs = {}
+        if device is None:
+            device = "cuda:0" if backend == "cuda" else "cpu"
         if os.path.exists(path):
             print("Loading weights from local directory")
             checkpoint_path = path
         else:
             repo_id = parse_hf_repo_id(path)
             checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
-
-
+
+        # load lightning checkpoint
+        checkpoint = torch.load(checkpoint_path, map_location=device, **ckp_kwargs, weights_only=True)
+        model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
+        model.on_load_checkpoint(checkpoint)
+        model.load_state_dict(checkpoint["state_dict"])
+
+        if backend == "cuda":
+            model = model.to(device)
         return model
 
     @classmethod
@@ -43,17 +59,22 @@ class PretrainedModel(ABC):
     def register_name(cls) -> str:
         pass
 
-    def
+    def on_load_checkpoint(self):
         pass
 
 
-def load_model(
+def load_model(
+    path: str,
+    device: str | None = None,
+    backend: Literal["torch", "cuda"] | None = None,
+    hf_kwargs=None,
+    ckp_kwargs=None,
+) -> PretrainedModel:
     """Loads a TiRex model. This function attempts to load the specified model.
 
     Args:
         path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
         device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
-            If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check repository FAQ!).
         hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
         ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
 
@@ -63,6 +84,11 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None
     Examples:
         model: ForecastModel = load_model("NX-AI/TiRex")
     """
+
+    if backend is None:
+        backend = "torch" if skip_cuda() else "cuda"
+    assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
+
     try:
         _, model_id = parse_hf_repo_id(path).split("/")
     except:
@@ -70,4 +96,5 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None
     model_cls = PretrainedModel.REGISTRY.get(model_id, None)
     if model_cls is None:
         raise ValueError(f"Invalid model id {model_id}")
-
+
+    return model_cls.from_pretrained(path, device=device, backend=backend, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
```

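As a quick orientation, here is a minimal usage sketch of the new backend selection in `base.py`. It only uses names visible in the diff above (`load_model`, `skip_cuda`, the `TIREX_NO_CUDA` variable); downloading the `NX-AI/TiRex` checkpoint and the concrete model class live outside this diff, so treat it as illustrative rather than verified against the released wheel.

```python
# Sketch: the pure-torch backend can be selected explicitly, or it becomes the
# default when TIREX_NO_CUDA is set (that is what skip_cuda() checks).
import os

os.environ["TIREX_NO_CUDA"] = "1"   # backend=None would now resolve to "torch"

from tirex.base import load_model

# device defaults to "cpu" for the torch backend and "cuda:0" for the cuda backend
model = load_model("NX-AI/TiRex", backend="torch")
```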
```diff
--- /dev/null
+++ tirex_mirror-2025.9.9/src/tirex/models/patcher.py
@@ -0,0 +1,84 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+from dataclasses import dataclass
+
+import torch
+
+
+class StandardScaler:
+    def __init__(self, eps: float = 1e-5, nan_loc: float = 0.0):
+        self.eps = eps
+        self.nan_loc = nan_loc
+
+    def scale(
+        self,
+        x: torch.Tensor,
+        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if loc_scale is None:
+            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
+            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
+            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
+        else:
+            loc, scale = loc_scale
+
+        return ((x - loc) / scale), (loc, scale)
+
+    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        loc, scale = loc_scale
+        return x * scale + loc
+
+
+class Patcher:
+    def __init__(self, patch_size: int, patch_stride: int, left_pad: bool):
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+        self.left_pad = left_pad
+        assert self.patch_size % self.patch_stride == 0
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.ndim == 2
+        length = x.shape[-1]
+
+        if length < self.patch_size or (length % self.patch_stride != 0):
+            if length < self.patch_size:
+                padding_size = (
+                    *x.shape[:-1],
+                    self.patch_size - (length % self.patch_size),
+                )
+            else:
+                padding_size = (
+                    *x.shape[:-1],
+                    self.patch_stride - (length % self.patch_stride),
+                )
+            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
+            if self.left_pad:
+                x = torch.concat((padding, x), dim=-1)
+            else:
+                x = torch.concat((x, padding), dim=-1)
+
+        return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
+
+
+@dataclass
+class PatchedUniTokenizerState:
+    scale_state: float
+
+
+class PatchedUniTokenizer:
+    def __init__(self, patch_size: int, patch_stride: int | None = None, scaler: StandardScaler | None = None):
+        self.patch_size = patch_size
+        self.patch_stride = patch_size if patch_stride is None else patch_stride
+        self.scaler = StandardScaler() if scaler is None else scaler
+        self.patcher = Patcher(self.patch_size, self.patch_stride, left_pad=True)
+
+    def context_input_transform(self, data: torch.Tensor):
+        assert data.ndim == 2
+        data, scale_state = self.scaler.scale(data)
+        return self.patcher(data), PatchedUniTokenizerState(scale_state)
+
+    def output_transform(self, data: torch.Tensor, tokenizer_state: PatchedUniTokenizerState):
+        data_shape = data.shape
+        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state.scale_state).view(*data_shape)
+        return data
```

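A small sketch of how the new tokenizer pieces fit together: the context is standard-scaled, left-padded with NaNs to a multiple of the stride, and unfolded into patches, and model outputs are rescaled back. The shapes below follow from the code in the diff; the module path `tirex.models.patcher` is taken from the file listing above.

```python
import torch

from tirex.models.patcher import PatchedUniTokenizer

tok = PatchedUniTokenizer(patch_size=32)   # stride defaults to patch_size -> non-overlapping patches
context = torch.randn(2, 100)              # (batch, time); 100 is not a multiple of 32

patches, state = tok.context_input_transform(context)
print(patches.shape)                        # torch.Size([2, 4, 32]): left-padded to 128, then unfolded

# A (scaled) model output is mapped back to the original location/scale of each series:
forecast = tok.output_transform(torch.zeros(2, 1, 32), state)
```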
```diff
--- /dev/null
+++ tirex_mirror-2025.9.9/src/tirex/models/slstm/block.py
@@ -0,0 +1,60 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tirex.models.slstm.layer import sLSTMBlockConfig, sLSTMLayer
+from tirex.util import round_up_to_next_multiple_of
+
+
+class sLSTMBlock(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+        self.norm_slstm = RMSNorm(config.embedding_dim)
+        self.slstm_layer = sLSTMLayer(config, backend)
+        self.norm_ffn = RMSNorm(config.embedding_dim)
+
+        up_proj_dim = round_up_to_next_multiple_of(config.embedding_dim * config.ffn_proj_factor, 64)
+        self.ffn = FeedForward(config.embedding_dim, up_proj_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_slstm = self.norm_slstm(x)
+
+        x_slstm = self.slstm_layer(x_slstm, slstm_state=None)
+        x = x + x_slstm
+
+        x_ffn = self.norm_ffn(x)
+        x_ffn = self.ffn(x_ffn)
+        x = x + x_ffn
+        return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, embedding_dim: int, up_proj_dim: int):
+        super().__init__()
+        self.proj_up_gate = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+        self.proj_up = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+        self.proj_down = nn.Linear(up_proj_dim, embedding_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.proj_up_gate(x)) * self.proj_up(x)
+        y = self.proj_down(x)
+        return y
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, num_features: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self._rms_normalize(x.float()).to(x.dtype)
+        x = x * self.weight
+        return x
+
+    def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
```

```diff
--- /dev/null
+++ tirex_mirror-2025.9.9/src/tirex/models/slstm/cell.py
@@ -0,0 +1,188 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import warnings
+from dataclasses import asdict, dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tirex.util import dataclass_from_dict
+
+
+@dataclass
+class sLSTMBlockConfig:
+    embedding_dim: int
+    num_heads: int
+    num_blocks: int
+    ffn_proj_factor: float = 2.6667
+
+    num_states: int = 4  # this is for the sLSTM, a standard LSTM has 2
+    num_gates: int = 4
+
+    @property
+    def head_dim(self):
+        return self.embedding_dim // self.num_heads
+
+
+class sLSTMCell(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+        self.backend = backend
+
+        self._recurrent_kernel_ = nn.Parameter(
+            torch.empty((config.num_heads, config.head_dim, config.num_gates * config.head_dim), dtype=None)
+        )
+
+        self._bias_ = nn.Parameter(torch.empty((config.num_heads * config.num_gates * config.head_dim), dtype=None))
+
+    def forward(self, input: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        input = self._get_input(input)
+        state = self._get_state(input, state)
+
+        if self.backend == "torch":
+            all_states = self._impl_torch(input, state)
+        elif self.backend == "cuda":
+            all_states = self._impl_cuda(input, state)
+
+        state = all_states[:, -1]
+        output = self._permute_output(all_states[0][1:])
+        return output.to(input.dtype), state.to(input.dtype)
+
+    def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+        input = input.to(dtype=torch.bfloat16)
+        state = state.to(dtype=torch.bfloat16)
+        recurrent_kernel = self._recurrent_kernel_.to(dtype=torch.bfloat16)
+        bias = self._bias_.to(dtype=torch.float32)
+
+        input = input.view(input.shape[0], input.shape[1], -1)
+        bias = (
+            bias.reshape(self.config.num_heads, self.config.num_gates, self.config.head_dim)
+            .permute(1, 0, 2)
+            .reshape(-1)
+        )
+
+        return slstm_forward(input, state, recurrent_kernel, bias)[0]
+
+    def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+        if input.device.type != "cuda":
+            warnings.warn(
+                f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE DEVICE ON A CUDA DEVICE (device type is {input.device.type})!"
+                "This is not supported and calls to the model will likely lead to an error if you dont move your model to a CUDA device!"
+                "If you want to run TiRex on CPU you need to disable sLSTM CUDA kernels but be aware of the downsides (see FAQ)"
+            )
+
+        if not hasattr(self, "func"):
+            try:
+                from xlstm.blocks.slstm.cell import sLSTMCellConfig as sLSTMCellConfigCuda, sLSTMCellFuncGenerator
+            except ModuleNotFoundError:
+                raise ValueError(
+                    'xlstm package not found! To use the custom cuda backend, install the additional dependencies with: pip install -e ".[cuda]"'
+                )
+            cuda_config = dataclass_from_dict(
+                sLSTMCellConfigCuda, {**asdict(self.config), "hidden_size": self.config.embedding_dim}
+            )
+            self.func = sLSTMCellFuncGenerator(False, cuda_config)
+
+        input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
+
+        return self.func.apply(
+            False,
+            input.contiguous(),
+            state.contiguous(),
+            self._recurrent_kernel_.contiguous(),
+            self._bias_.contiguous(),
+        )
+
+    def _get_input(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
+            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
+        )
+        return x.view(x.shape[0], x.shape[1], self.config.num_gates, self.config.num_heads, -1).permute(1, 0, 2, 3, 4)
+
+    def _get_state(self, input: torch.Tensor, state: torch.Tensor | None) -> torch.Tensor:
+        B = input.shape[1]
+        if state is None:
+            state = torch.zeros(
+                (self.config.num_states, B, self.config.embedding_dim),
+                dtype=input.dtype,
+                device=input.device,
+            )
+
+        assert state.shape == (self.config.num_states, B, self.config.embedding_dim)
+        return state
+
+    def _permute_output(self, output: torch.Tensor) -> torch.Tensor:
+        output = output.view(output.shape[0], output.shape[1], self.config.num_heads, self.config.head_dim)
+        return output.permute(1, 2, 0, 3)
+
+
+def slstm_forward(
+    x: torch.Tensor,  # [S, B, G*I]
+    states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+    R: torch.Tensor,  # [K, R*H, H] - K num_heads
+    b: torch.Tensor,  # [T*H]
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_states = states.shape[0]
+    sequence_dim = x.shape[0]
+    # this only works for a fully-connected RNN, for a hin change this
+    num_gates_r = R.shape[2] // R.shape[1]
+    hidden_dim = R.shape[1] * R.shape[0]
+    batch_dim = x.shape[1]
+    num_heads = R.shape[0]
+
+    assert batch_dim == states.shape[1]
+    assert hidden_dim == states.shape[2]
+
+    states_all = torch.zeros(
+        [num_states, sequence_dim + 1, batch_dim, hidden_dim],
+        device=x.device,
+        dtype=x.dtype,
+    )
+    states_all[:, 0] = states
+    for i, Wx_t in enumerate(x.unbind(dim=0)):
+        Ry = (
+            states[0]
+            .reshape(batch_dim, num_heads, 1, -1)
+            .matmul(R.unsqueeze(0))
+            .reshape(batch_dim, num_heads, num_gates_r, -1)
+            .transpose(1, 2)
+            .reshape(batch_dim, -1)
+        )
+        sdtype = states.dtype
+        Wx_t, Ry, b, states = Wx_t.float(), Ry.float(), b.float(), states.float()
+        states, gates = slstm_forward_pointwise(Wx_t, Ry, b, states)
+        states = states.to(dtype=sdtype)
+        states_all[:, i + 1] = states
+
+    # shapes ([S, B, H], ([B,H], [B,H], [B,H])
+    return states_all, states
+
+
+def slstm_forward_pointwise(
+    Wx: torch.Tensor,  # dim [B, 4*H]
+    Ry: torch.Tensor,  # dim [B, 4*H]
+    b: torch.Tensor,  # dim [1, 4*H]
+    states: torch.Tensor,  # dim [4, B, H]
+) -> tuple[torch.Tensor, torch.Tensor]:
+    raw = Wx + Ry + b
+    y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
+
+    iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+    # with torch.no_grad():  # THE difference to maxg aka max_gradient (here max / max_static)
+    logfplusm = m + F.logsigmoid(fraw)
+    if torch.all(n == 0.0):
+        mnew = iraw
+    else:
+        mnew = torch.max(iraw, logfplusm)
+    ogate = torch.sigmoid(oraw)
+    igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))
+    fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))
+    cnew = fgate * c + igate * torch.tanh(zraw)
+    nnew = fgate * n + igate
+    ynew = ogate * cnew / nnew
+
+    # shapes ([B,H], [B,H], [B,H]), ([B,H],[B,H],[B,H],[B,H])
+    return torch.stack((ynew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
```

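For orientation, a toy call of the pure-PyTorch recurrence added above. The tensor layouts follow the shape comments in `slstm_forward` (an `[S, B, G*H]` input, a `[4, B, H]` state stack, per-head recurrent kernels); the module path assumes `tirex.models.slstm.cell` is importable as laid out in the file listing, and the random values only demonstrate shapes, not model behaviour.

```python
import torch

from tirex.models.slstm.cell import slstm_forward

S, B, H, num_heads, num_gates = 5, 2, 8, 2, 4
head_dim = H // num_heads

x = torch.randn(S, B, num_gates * H)                         # pre-projected gate inputs [S, B, G*H]
states = torch.zeros(4, B, H)                                # stacked (y, c, n, m); y drives the recurrence
R = torch.randn(num_heads, head_dim, num_gates * head_dim)   # per-head recurrent kernel
b = torch.zeros(num_gates * H)

states_all, last_state = slstm_forward(x, states, R, b)
print(states_all.shape, last_state.shape)                    # torch.Size([4, 6, 2, 8]) torch.Size([4, 2, 8])
```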
```diff
--- /dev/null
+++ tirex_mirror-2025.9.9/src/tirex/models/slstm/layer.py
@@ -0,0 +1,67 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .cell import sLSTMBlockConfig, sLSTMCell
+
+
+class sLSTMLayer(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+
+        in_features, num_heads = self.config.embedding_dim, self.config.num_heads
+        self.fgate = LinearHeadwiseExpand(in_features, num_heads)
+        self.igate = LinearHeadwiseExpand(in_features, num_heads)
+        self.zgate = LinearHeadwiseExpand(in_features, num_heads)
+        self.ogate = LinearHeadwiseExpand(in_features, num_heads)
+
+        self.slstm_cell = sLSTMCell(self.config, backend)
+        self.group_norm = MultiHeadLayerNorm(ndim=in_features)
+
+    def forward(self, x: torch.Tensor, slstm_state: torch.Tensor | None = None) -> torch.Tensor:
+        x_g = torch.cat((self.fgate(x), self.igate(x), self.zgate(x), self.ogate(x)), dim=-1)
+
+        y, slstm_state = self.slstm_cell(x_g, state=slstm_state)
+
+        return self.group_norm(y).transpose(1, 2).view(x.shape[0], x.shape[1], -1)
+
+
+class LinearHeadwiseExpand(nn.Module):
+    def __init__(self, in_features, num_heads, expand_factor_up: float = 1):
+        super().__init__()
+        assert num_heads <= in_features, "num_heads must be <= in_features"
+        assert in_features % num_heads == 0, "in_features must be a multiple of num_heads"
+        self.num_heads = num_heads
+
+        out_features = round(expand_factor_up * in_features)
+        out_features_per_head = out_features // num_heads
+        self.weight = nn.Parameter(torch.empty(num_heads, out_features_per_head, in_features // num_heads))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shape = x.shape
+        x = x.view(*shape[:-1], self.num_heads, -1)
+        x = torch.einsum("...hd,hod->...ho", x, self.weight)
+        x = x.reshape(*shape[:-1], -1)
+        return x
+
+
+class MultiHeadLayerNorm(nn.Module):
+    def __init__(self, ndim: int):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(ndim))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
+        B, NH, S, DH = input.shape
+
+        gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
+        gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
+        residual_weight = 1.0 + self.weight
+        out = F.group_norm(gn_in_2, num_groups=NH, weight=residual_weight)
+        # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
+        out = out.view(B, S, NH, DH).transpose(1, 2)
+        return out
```

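Putting the three new sLSTM modules together, a minimal forward-pass sketch. The config values here are made up; the real hyperparameters and weights come from the published checkpoint, so the `torch.empty` parameters are re-initialised only to get finite numbers out of a dry run, and the module paths follow the new file layout.

```python
import torch

from tirex.models.slstm.block import sLSTMBlock
from tirex.models.slstm.cell import sLSTMBlockConfig

config = sLSTMBlockConfig(embedding_dim=64, num_heads=4, num_blocks=1)
block = sLSTMBlock(config, backend="torch")   # "cuda" would route through the optional xlstm kernels

# The cell/gate weights are created with torch.empty / torch.zeros; give them values for a dry run.
for p in block.parameters():
    torch.nn.init.normal_(p, std=0.02)

x = torch.randn(2, 10, 64)                    # (batch, sequence, embedding_dim)
y = block(x)
print(y.shape)                                # torch.Size([2, 10, 64]): the residual block preserves shape
```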