tirex-mirror 2025.9.2.tar.gz → 2025.9.10.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {tirex_mirror-2025.9.2/src/tirex_mirror.egg-info → tirex_mirror-2025.9.10}/PKG-INFO +10 -6
  2. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/pyproject.toml +7 -6
  3. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/forecast.py +0 -1
  4. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/base.py +43 -8
  5. tirex_mirror-2025.9.10/src/tirex/models/patcher.py +84 -0
  6. tirex_mirror-2025.9.10/src/tirex/models/slstm/block.py +60 -0
  7. tirex_mirror-2025.9.10/src/tirex/models/slstm/cell.py +188 -0
  8. tirex_mirror-2025.9.10/src/tirex/models/slstm/layer.py +67 -0
  9. tirex_mirror-2025.9.10/src/tirex/models/tirex.py +223 -0
  10. tirex_mirror-2025.9.10/src/tirex/util.py +13 -0
  11. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10/src/tirex_mirror.egg-info}/PKG-INFO +10 -6
  12. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/SOURCES.txt +6 -3
  13. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/requires.txt +10 -5
  14. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/tests/test_chronos_zs.py +13 -6
  15. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/tests/test_forecast.py +10 -8
  16. tirex_mirror-2025.9.10/tests/test_slstm_torch_vs_cuda.py +82 -0
  17. tirex_mirror-2025.9.2/src/tirex/models/components.py +0 -147
  18. tirex_mirror-2025.9.2/src/tirex/models/mixed_stack.py +0 -143
  19. tirex_mirror-2025.9.2/src/tirex/models/predict_utils.py +0 -72
  20. tirex_mirror-2025.9.2/src/tirex/models/tirex.py +0 -231
  21. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/LICENSE +0 -0
  22. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/LICENSE_MIRROR.txt +0 -0
  23. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/MANIFEST.in +0 -0
  24. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/NOTICE.txt +0 -0
  25. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/README.md +0 -0
  26. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/setup.cfg +0 -0
  27. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/__init__.py +0 -0
  28. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/__init__.py +0 -0
  29. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/gluon.py +0 -0
  30. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/hf_data.py +0 -0
  31. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/standard_adapter.py +0 -0
  32. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex/models/__init__.py +0 -0
  33. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  34. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  35. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/tests/test_forecast_adapter.py +0 -0
  36. {tirex_mirror-2025.9.2 → tirex_mirror-2025.9.10}/tests/test_standard_adapter.py +0 -0
File: PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tirex-mirror
- Version: 2025.9.2
+ Version: 2025.9.10
  Summary: Unofficial mirror of NX-AI/tirex for packaging
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -64,28 +64,32 @@ License-File: LICENSE
  License-File: LICENSE_MIRROR.txt
  License-File: NOTICE.txt
  Requires-Dist: torch
- Requires-Dist: torchvision
- Requires-Dist: xlstm
  Requires-Dist: einops
- Requires-Dist: ninja
  Requires-Dist: huggingface-hub
- Requires-Dist: lightning
  Requires-Dist: numpy
  Requires-Dist: pandas
- Requires-Dist: dacite
  Requires-Dist: tqdm
+ Provides-Extra: cuda
+ Requires-Dist: xlstm; extra == "cuda"
+ Requires-Dist: ninja; extra == "cuda"
  Provides-Extra: notebooks
  Requires-Dist: ipykernel; extra == "notebooks"
+ Requires-Dist: matplotlib; extra == "notebooks"
  Provides-Extra: gluonts
  Requires-Dist: gluonts; extra == "gluonts"
  Provides-Extra: hfdataset
  Requires-Dist: datasets; extra == "hfdataset"
  Provides-Extra: test
  Requires-Dist: fev; extra == "test"
+ Requires-Dist: pytest; extra == "test"
  Provides-Extra: all
+ Requires-Dist: xlstm; extra == "all"
+ Requires-Dist: ninja; extra == "all"
  Requires-Dist: ipykernel; extra == "all"
+ Requires-Dist: matplotlib; extra == "all"
  Requires-Dist: gluonts; extra == "all"
  Requires-Dist: datasets; extra == "all"
+ Requires-Dist: pytest; extra == "all"
  Requires-Dist: fev; extra == "all"
  Dynamic: license-file
File: pyproject.toml
@@ -1,18 +1,18 @@
  [project]
  name = "tirex-mirror"
- version = "2025.09.02"
+ version = "2025.09.10"
  description = "Unofficial mirror of NX-AI/tirex for packaging"
  readme = "README.md"
  requires-python = ">=3.11"
  classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent",]
  keywords = [ "TiRex", "xLSTM", "Time Series", "Zero-shot", "Deep Learning",]
- dependencies = [ "torch", "torchvision", "xlstm", "einops", "ninja", "huggingface-hub", "lightning", "numpy", "pandas", "dacite", "tqdm",]
+ dependencies = [ "torch", "einops", "huggingface-hub", "numpy", "pandas", "tqdm",]
  [[project.authors]]
  name = "Arpad Rozsas"
  email = "rozsasarpi@gmail.com"

  [build-system]
- requires = [ "setuptools>=42", "wheel",]
+ requires = [ "setuptools>=77.0.3", "wheel",]
  build-backend = "setuptools.build_meta"

  [project.license]
@@ -23,11 +23,12 @@ Repository = "https://github.com/rozsasarpi/tirex-mirror"
  Issues = "https://github.com/rozsasarpi/tirex-mirror/issues"

  [project.optional-dependencies]
- notebooks = [ "ipykernel",]
+ cuda = [ "xlstm", "ninja",]
+ notebooks = [ "ipykernel", "matplotlib",]
  gluonts = [ "gluonts",]
  hfdataset = [ "datasets",]
- test = [ "fev",]
- all = [ "ipykernel", "gluonts", "datasets", "fev",]
+ test = [ "fev", "pytest",]
+ all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "gluonts", "datasets", "pytest", "fev",]

  [tool.docformatter]
  diff = false
File: src/tirex/api_adapter/forecast.py
@@ -8,7 +8,6 @@ import torch

  from .standard_adapter import ContextType, get_batches

-
  DEF_TARGET_COLUMN = "target"
  DEF_META_COLUMNS = ("start", "item_id")

File: src/tirex/base.py
@@ -3,13 +3,27 @@

  import os
  from abc import ABC, abstractmethod
- from typing import TypeVar
+ from typing import Literal, TypeVar

+ import torch
  from huggingface_hub import hf_hub_download

  T = TypeVar("T", bound="PretrainedModel")


+ def skip_cuda():
+     return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
+
+
+ def xlstm_available():
+     try:
+         from xlstm.blocks.slstm.cell import sLSTMCellConfig, sLSTMCellFuncGenerator
+
+         return True
+     except ModuleNotFoundError:
+         return False
+
+
  def parse_hf_repo_id(path):
      parts = path.split("/")
      return "/".join(parts[0:2])
@@ -23,19 +37,30 @@ class PretrainedModel(ABC):
23
37
  cls.REGISTRY[cls.register_name()] = cls
24
38
 
25
39
  @classmethod
26
- def from_pretrained(cls: type[T], path, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> T:
40
+ def from_pretrained(
41
+ cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
42
+ ) -> T:
27
43
  if hf_kwargs is None:
28
44
  hf_kwargs = {}
29
45
  if ckp_kwargs is None:
30
46
  ckp_kwargs = {}
47
+ if device is None:
48
+ device = "cuda:0" if backend == "cuda" else "cpu"
31
49
  if os.path.exists(path):
32
50
  print("Loading weights from local directory")
33
51
  checkpoint_path = path
34
52
  else:
35
53
  repo_id = parse_hf_repo_id(path)
36
54
  checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
37
- model = cls.load_from_checkpoint(checkpoint_path, map_location=device, **ckp_kwargs)
38
- model.after_load_from_checkpoint()
55
+
56
+ # load lightning checkpoint
57
+ checkpoint = torch.load(checkpoint_path, map_location=device, **ckp_kwargs, weights_only=True)
58
+ model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
59
+ model.on_load_checkpoint(checkpoint)
60
+ model.load_state_dict(checkpoint["state_dict"])
61
+
62
+ if backend == "cuda":
63
+ model = model.to(device)
39
64
  return model
40
65
 
41
66
  @classmethod
@@ -43,17 +68,22 @@ class PretrainedModel(ABC):
      def register_name(cls) -> str:
          pass

-     def after_load_from_checkpoint(self):
+     def on_load_checkpoint(self, checkpoint):
          pass


- def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> PretrainedModel:
+ def load_model(
+     path: str,
+     device: str | None = None,
+     backend: Literal["torch", "cuda"] | None = None,
+     hf_kwargs=None,
+     ckp_kwargs=None,
+ ) -> PretrainedModel:
      """Loads a TiRex model. This function attempts to load the specified model.

      Args:
          path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
          device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
-             If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check repository FAQ!).
          hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
          ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
@@ -63,6 +93,10 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=Non
      Examples:
          model: ForecastModel = load_model("NX-AI/TiRex")
      """
+
+     if backend is None:
+         backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
+
      try:
          _, model_id = parse_hf_repo_id(path).split("/")
      except:
@@ -70,4 +104,5 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=Non
      model_cls = PretrainedModel.REGISTRY.get(model_id, None)
      if model_cls is None:
          raise ValueError(f"Invalid model id {model_id}")
-     return model_cls.from_pretrained(path, device=device, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
+
+     return model_cls.from_pretrained(path, device=device, backend=backend, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
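Taken together, the base.py changes replace Lightning's `load_from_checkpoint` with a plain `torch.load(..., weights_only=True)` and thread an explicit `backend` through `load_model` and `from_pretrained`. A minimal usage sketch of the new entry point, assuming the top-level `tirex` re-export that the docstring's `load_model("NX-AI/TiRex")` example implies:

```python
import os

from tirex import load_model  # assumed top-level re-export, as in the docstring example

# Force the pure-torch sLSTM backend, e.g. on a machine without the xlstm CUDA kernels.
model = load_model("NX-AI/TiRex", backend="torch", device="cpu")

# With backend=None (the default), load_model picks "torch" when TIREX_NO_CUDA is set
# to "true"/"1"/"t" or when the xlstm package is not importable, and "cuda" otherwise.
os.environ["TIREX_NO_CUDA"] = "1"
model = load_model("NX-AI/TiRex")  # resolves to the torch backend in this process
```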
File: src/tirex/models/patcher.py (new)
@@ -0,0 +1,84 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ from dataclasses import dataclass
+
+ import torch
+
+
+ class StandardScaler:
+     def __init__(self, eps: float = 1e-5, nan_loc: float = 0.0):
+         self.eps = eps
+         self.nan_loc = nan_loc
+
+     def scale(
+         self,
+         x: torch.Tensor,
+         loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
+     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+         if loc_scale is None:
+             loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
+             scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
+             scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
+         else:
+             loc, scale = loc_scale
+
+         return ((x - loc) / scale), (loc, scale)
+
+     def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+         loc, scale = loc_scale
+         return x * scale + loc
+
+
+ class Patcher:
+     def __init__(self, patch_size: int, patch_stride: int, left_pad: bool):
+         self.patch_size = patch_size
+         self.patch_stride = patch_stride
+         self.left_pad = left_pad
+         assert self.patch_size % self.patch_stride == 0
+
+     def __call__(self, x: torch.Tensor) -> torch.Tensor:
+         assert x.ndim == 2
+         length = x.shape[-1]
+
+         if length < self.patch_size or (length % self.patch_stride != 0):
+             if length < self.patch_size:
+                 padding_size = (
+                     *x.shape[:-1],
+                     self.patch_size - (length % self.patch_size),
+                 )
+             else:
+                 padding_size = (
+                     *x.shape[:-1],
+                     self.patch_stride - (length % self.patch_stride),
+                 )
+             padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
+             if self.left_pad:
+                 x = torch.concat((padding, x), dim=-1)
+             else:
+                 x = torch.concat((x, padding), dim=-1)
+
+         return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
+
+
+ @dataclass
+ class PatchedUniTokenizerState:
+     scale_state: float
+
+
+ class PatchedUniTokenizer:
+     def __init__(self, patch_size: int, patch_stride: int | None = None, scaler: StandardScaler | None = None):
+         self.patch_size = patch_size
+         self.patch_stride = patch_size if patch_stride is None else patch_stride
+         self.scaler = StandardScaler() if scaler is None else scaler
+         self.patcher = Patcher(self.patch_size, self.patch_stride, left_pad=True)
+
+     def context_input_transform(self, data: torch.Tensor):
+         assert data.ndim == 2
+         data, scale_state = self.scaler.scale(data)
+         return self.patcher(data), PatchedUniTokenizerState(scale_state)
+
+     def output_transform(self, data: torch.Tensor, tokenizer_state: PatchedUniTokenizerState):
+         data_shape = data.shape
+         data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state.scale_state).view(*data_shape)
+         return data
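A short sketch of how the pieces of the new tokenizer compose (standard scaling, NaN left-padding, unfolding into patches); the sizes are illustrative only:

```python
import torch

from tirex.models.patcher import PatchedUniTokenizer

tokenizer = PatchedUniTokenizer(patch_size=4)  # patch_stride defaults to patch_size
context = torch.arange(10, dtype=torch.float32).unsqueeze(0)  # (batch=1, length=10)

# Scale to zero location / unit scale, left-pad with NaN to a multiple of the
# stride (10 -> 12), then unfold into non-overlapping patches of size 4.
patches, state = tokenizer.context_input_transform(context)
print(patches.shape)  # torch.Size([1, 3, 4]); the first two values are NaN padding

# output_transform re-applies the stored (loc, scale), mapping values back
# to the original range of the context.
restored = tokenizer.output_transform(patches, state)
```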
File: src/tirex/models/slstm/block.py (new)
@@ -0,0 +1,60 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from tirex.models.slstm.layer import sLSTMBlockConfig, sLSTMLayer
+ from tirex.util import round_up_to_next_multiple_of
+
+
+ class sLSTMBlock(nn.Module):
+     def __init__(self, config: sLSTMBlockConfig, backend: str):
+         super().__init__()
+         self.config = config
+         self.norm_slstm = RMSNorm(config.embedding_dim)
+         self.slstm_layer = sLSTMLayer(config, backend)
+         self.norm_ffn = RMSNorm(config.embedding_dim)
+
+         up_proj_dim = round_up_to_next_multiple_of(config.embedding_dim * config.ffn_proj_factor, 64)
+         self.ffn = FeedForward(config.embedding_dim, up_proj_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_slstm = self.norm_slstm(x)
+
+         x_slstm = self.slstm_layer(x_slstm, slstm_state=None)
+         x = x + x_slstm
+
+         x_ffn = self.norm_ffn(x)
+         x_ffn = self.ffn(x_ffn)
+         x = x + x_ffn
+         return x
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, embedding_dim: int, up_proj_dim: int):
+         super().__init__()
+         self.proj_up_gate = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+         self.proj_up = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+         self.proj_down = nn.Linear(up_proj_dim, embedding_dim, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = F.silu(self.proj_up_gate(x)) * self.proj_up(x)
+         y = self.proj_down(x)
+         return y
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, num_features: int, eps: float = 1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(num_features))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self._rms_normalize(x.float()).to(x.dtype)
+         x = x * self.weight
+         return x
+
+     def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
+         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
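A shape-level sketch of the block under the torch backend, using the `sLSTMBlockConfig` defined in cell.py below. Note that the cell and gate parameters are allocated with `torch.empty` and normally come from a checkpoint, so the values here are meaningless; only the shapes are checked:

```python
import torch

from tirex.models.slstm.block import sLSTMBlock
from tirex.models.slstm.cell import sLSTMBlockConfig

cfg = sLSTMBlockConfig(embedding_dim=64, num_heads=4)
block = sLSTMBlock(cfg, backend="torch")

x = torch.randn(2, 16, 64)  # (batch, seq_len, embedding_dim)
y = block(x)                # pre-norm residual: x + sLSTM(norm(x)), then x + FFN(norm(x))
print(y.shape)              # torch.Size([2, 16, 64]); the block is shape-preserving
```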
File: src/tirex/models/slstm/cell.py (new)
@@ -0,0 +1,188 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import warnings
+ from dataclasses import asdict, dataclass
+ from typing import Literal
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from tirex.util import dataclass_from_dict
+
+
+ @dataclass
+ class sLSTMBlockConfig:
+     embedding_dim: int
+     num_heads: int
+     ffn_proj_factor: float = 2.6667
+     num_states: int = 4
+     num_gates: int = 4
+
+     @property
+     def head_dim(self):
+         return self.embedding_dim // self.num_heads
+
+
+ class sLSTMCell(nn.Module):
+     def __init__(self, config: sLSTMBlockConfig, backend: Literal["torch", "cuda"]):
+         super().__init__()
+         assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
+         self.config = config
+         self.backend = backend
+
+         self._recurrent_kernel_ = nn.Parameter(
+             torch.empty((config.num_heads, config.head_dim, config.num_gates * config.head_dim), dtype=None)
+         )
+
+         self._bias_ = nn.Parameter(torch.empty((config.num_heads * config.num_gates * config.head_dim), dtype=None))
+
+     def forward(self, input: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         input = self._get_input(input)
+         state = self._get_state(input, state)
+
+         if self.backend == "torch":
+             all_states = self._impl_torch(input, state)
+         elif self.backend == "cuda":
+             all_states = self._impl_cuda(input, state)
+
+         state = all_states[:, -1]
+         output = self._permute_output(all_states[0][1:])
+         return output.to(input.dtype), state.to(input.dtype)
+
+     def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+         input = input.to(dtype=torch.bfloat16)
+         state = state.to(dtype=torch.bfloat16)
+         recurrent_kernel = self._recurrent_kernel_.to(dtype=torch.bfloat16)
+         bias = self._bias_.to(dtype=torch.float32)
+
+         input = input.view(input.shape[0], input.shape[1], -1)
+         bias = (
+             bias.reshape(self.config.num_heads, self.config.num_gates, self.config.head_dim)
+             .permute(1, 0, 2)
+             .reshape(-1)
+         )
+
+         return slstm_forward(input, state, recurrent_kernel, bias)[0]
+
+     def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+         if input.device.type != "cuda":
+             warnings.warn(
+                 f"You are using TiRex with the sLSTM CUDA kernels, but the input is not on a CUDA device (device type is {input.device.type})! "
+                 "This is not supported and calls to the model will likely fail unless you move the model to a CUDA device. "
+                 "If you want to run TiRex on CPU you need to disable the sLSTM CUDA kernels, but be aware of the downsides (see FAQ)."
+             )
+
+         if not hasattr(self, "func"):
+             try:
+                 from xlstm.blocks.slstm.cell import sLSTMCellConfig as sLSTMCellConfigCuda, sLSTMCellFuncGenerator
+             except ModuleNotFoundError:
+                 raise ValueError(
+                     'xlstm package not found! To use the custom cuda backend, install the additional dependencies with: pip install -e ".[cuda]"'
+                 )
+             cuda_config = dataclass_from_dict(
+                 sLSTMCellConfigCuda, {**asdict(self.config), "hidden_size": self.config.embedding_dim}
+             )
+             self.func = sLSTMCellFuncGenerator(False, cuda_config)
+
+         input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
+
+         return self.func.apply(
+             False,
+             input.contiguous(),
+             state.contiguous(),
+             self._recurrent_kernel_.contiguous(),
+             self._bias_.contiguous(),
+         )
+
+     def _get_input(self, x: torch.Tensor) -> torch.Tensor:
+         assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
+             f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {x.size(-1)}."
+         )
+         return x.view(x.shape[0], x.shape[1], self.config.num_gates, self.config.num_heads, -1).permute(1, 0, 2, 3, 4)
+
+     def _get_state(self, input: torch.Tensor, state: torch.Tensor | None) -> torch.Tensor:
+         B = input.shape[1]
+         if state is None:
+             state = torch.zeros(
+                 (self.config.num_states, B, self.config.embedding_dim),
+                 dtype=input.dtype,
+                 device=input.device,
+             )
+
+         assert state.shape == (self.config.num_states, B, self.config.embedding_dim)
+         return state
+
+     def _permute_output(self, output: torch.Tensor) -> torch.Tensor:
+         output = output.view(output.shape[0], output.shape[1], self.config.num_heads, self.config.head_dim)
+         return output.permute(1, 2, 0, 3)
+
+
+ def slstm_forward(
+     x: torch.Tensor,  # [S, B, G*I]
+     states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+     R: torch.Tensor,  # [K, R*H, H] - K num_heads
+     b: torch.Tensor,  # [T*H]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     num_states = states.shape[0]
+     sequence_dim = x.shape[0]
+     # this only works for a fully-connected RNN, for a hin change this
+     num_gates_r = R.shape[2] // R.shape[1]
+     hidden_dim = R.shape[1] * R.shape[0]
+     batch_dim = x.shape[1]
+     num_heads = R.shape[0]
+
+     assert batch_dim == states.shape[1]
+     assert hidden_dim == states.shape[2]
+
+     states_all = torch.zeros(
+         [num_states, sequence_dim + 1, batch_dim, hidden_dim],
+         device=x.device,
+         dtype=x.dtype,
+     )
+     states_all[:, 0] = states
+     for i, Wx_t in enumerate(x.unbind(dim=0)):
+         Ry = (
+             states[0]
+             .reshape(batch_dim, num_heads, 1, -1)
+             .matmul(R.unsqueeze(0))
+             .reshape(batch_dim, num_heads, num_gates_r, -1)
+             .transpose(1, 2)
+             .reshape(batch_dim, -1)
+         )
+         sdtype = states.dtype
+         Wx_t, Ry, b, states = Wx_t.float(), Ry.float(), b.float(), states.float()
+         states, gates = slstm_forward_pointwise(Wx_t, Ry, b, states)
+         states = states.to(dtype=sdtype)
+         states_all[:, i + 1] = states
+
+     # shapes: states_all [4, S+1, B, H], last states [4, B, H]
+     return states_all, states
+
+
+ def slstm_forward_pointwise(
+     Wx: torch.Tensor,  # dim [B, 4*H]
+     Ry: torch.Tensor,  # dim [B, 4*H]
+     b: torch.Tensor,  # dim [1, 4*H]
+     states: torch.Tensor,  # dim [4, B, H]
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     raw = Wx + Ry + b
+     y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
+
+     iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+     # with torch.no_grad():  # THE difference to maxg aka max_gradient (here max / max_static)
+     logfplusm = m + F.logsigmoid(fraw)
+     if torch.all(n == 0.0):
+         mnew = iraw
+     else:
+         mnew = torch.max(iraw, logfplusm)
+     ogate = torch.sigmoid(oraw)
+     igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))
+     fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))
+     cnew = fgate * c + igate * torch.tanh(zraw)
+     nnew = fgate * n + igate
+     ynew = ogate * cnew / nnew
+
+     # shapes ([B,H], [B,H], [B,H]), ([B,H],[B,H],[B,H],[B,H])
+     return torch.stack((ynew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
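The `m` state acts as a running max-stabilizer: the input and forget gates are exponentiated relative to `mnew` and clamped to at most 1, so the recurrence cannot overflow. A small numeric sketch of `slstm_forward_pointwise` with deliberately extreme pre-activations:

```python
import torch

from tirex.models.slstm.cell import slstm_forward_pointwise

B, H = 2, 8
states = torch.zeros(4, B, H)      # (y, c, n, m), all zero at the start of a sequence
Wx = torch.randn(B, 4 * H) * 100   # deliberately huge gate pre-activations
Ry = torch.zeros(B, 4 * H)
bias = torch.zeros(4 * H)

new_states, gates = slstm_forward_pointwise(Wx, Ry, bias, states)
igate, fgate, _, _ = torch.unbind(gates, dim=0)
print(float(igate.max()), float(fgate.max()))  # both <= 1.0 thanks to the stabilizer
print(bool(torch.isfinite(new_states).all()))  # True: no overflow despite the *100
```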
File: src/tirex/models/slstm/layer.py (new)
@@ -0,0 +1,67 @@
+ # Copyright (c) NXAI GmbH.
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .cell import sLSTMBlockConfig, sLSTMCell
+
+
+ class sLSTMLayer(nn.Module):
+     def __init__(self, config: sLSTMBlockConfig, backend: str):
+         super().__init__()
+         self.config = config
+
+         in_features, num_heads = self.config.embedding_dim, self.config.num_heads
+         self.fgate = LinearHeadwiseExpand(in_features, num_heads)
+         self.igate = LinearHeadwiseExpand(in_features, num_heads)
+         self.zgate = LinearHeadwiseExpand(in_features, num_heads)
+         self.ogate = LinearHeadwiseExpand(in_features, num_heads)
+
+         self.slstm_cell = sLSTMCell(self.config, backend)
+         self.group_norm = MultiHeadLayerNorm(ndim=in_features)
+
+     def forward(self, x: torch.Tensor, slstm_state: torch.Tensor | None = None) -> torch.Tensor:
+         x_g = torch.cat((self.fgate(x), self.igate(x), self.zgate(x), self.ogate(x)), dim=-1)
+
+         y, slstm_state = self.slstm_cell(x_g, state=slstm_state)
+
+         return self.group_norm(y).transpose(1, 2).view(x.shape[0], x.shape[1], -1)
+
+
+ class LinearHeadwiseExpand(nn.Module):
+     def __init__(self, in_features, num_heads, expand_factor_up: float = 1):
+         super().__init__()
+         assert num_heads <= in_features, "num_heads must be <= in_features"
+         assert in_features % num_heads == 0, "in_features must be a multiple of num_heads"
+         self.num_heads = num_heads
+
+         out_features = round(expand_factor_up * in_features)
+         out_features_per_head = out_features // num_heads
+         self.weight = nn.Parameter(torch.empty(num_heads, out_features_per_head, in_features // num_heads))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shape = x.shape
+         x = x.view(*shape[:-1], self.num_heads, -1)
+         x = torch.einsum("...hd,hod->...ho", x, self.weight)
+         x = x.reshape(*shape[:-1], -1)
+         return x
+
+
+ class MultiHeadLayerNorm(nn.Module):
+     def __init__(self, ndim: int):
+         super().__init__()
+         self.weight = nn.Parameter(torch.zeros(ndim))
+
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
+         B, NH, S, DH = input.shape
+
+         gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
+         gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
+         residual_weight = 1.0 + self.weight
+         out = F.group_norm(gn_in_2, num_groups=NH, weight=residual_weight)
+         # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
+         out = out.view(B, S, NH, DH).transpose(1, 2)
+         return out
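`LinearHeadwiseExpand` is a block-diagonal linear map: each head gets its own small weight matrix. A quick equivalence sketch (the module allocates its weight with `torch.empty`, so the demo initializes it first):

```python
import torch

from tirex.models.slstm.layer import LinearHeadwiseExpand

heads, dim = 4, 16
proj = LinearHeadwiseExpand(in_features=dim, num_heads=heads)
with torch.no_grad():
    proj.weight.normal_()  # weights are uninitialized torch.empty storage by default

x = torch.randn(2, 3, dim)
y = proj(x)  # head-wise einsum, shape (2, 3, 16)

# Equivalent dense form: the per-head (out, in) blocks placed on a block diagonal.
dense = torch.block_diag(*proj.weight.unbind(0))  # (16, 16)
assert torch.allclose(y, x @ dense.T, atol=1e-5)
```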