tirex-mirror 2025.8.28__py3-none-any.whl → 2025.9.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tirex/api_adapter/forecast.py CHANGED
@@ -8,7 +8,6 @@ import torch
 
 from .standard_adapter import ContextType, get_batches
 
-
 DEF_TARGET_COLUMN = "target"
 DEF_META_COLUMNS = ("start", "item_id")
 
tirex/base.py CHANGED
@@ -3,13 +3,18 @@
 
 import os
 from abc import ABC, abstractmethod
-from typing import TypeVar
+from typing import Literal, TypeVar
 
+import torch
 from huggingface_hub import hf_hub_download
 
 T = TypeVar("T", bound="PretrainedModel")
 
 
+def skip_cuda():
+    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
+
+
 def parse_hf_repo_id(path):
     parts = path.split("/")
     return "/".join(parts[0:2])
@@ -23,19 +28,30 @@ class PretrainedModel(ABC):
         cls.REGISTRY[cls.register_name()] = cls
 
     @classmethod
-    def from_pretrained(cls: type[T], path, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> T:
+    def from_pretrained(
+        cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+    ) -> T:
         if hf_kwargs is None:
             hf_kwargs = {}
         if ckp_kwargs is None:
             ckp_kwargs = {}
+        if device is None:
+            device = "cuda:0" if backend == "cuda" else "cpu"
         if os.path.exists(path):
            print("Loading weights from local directory")
            checkpoint_path = path
         else:
            repo_id = parse_hf_repo_id(path)
            checkpoint_path = hf_hub_download(repo_id=repo_id, filename="model.ckpt", **hf_kwargs)
-        model = cls.load_from_checkpoint(checkpoint_path, map_location=device, **ckp_kwargs)
-        model.after_load_from_checkpoint()
+
+        # load lightning checkpoint
+        checkpoint = torch.load(checkpoint_path, map_location=device, **ckp_kwargs, weights_only=True)
+        model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
+        model.on_load_checkpoint(checkpoint)
+        model.load_state_dict(checkpoint["state_dict"])
+
+        if backend == "cuda":
+            model = model.to(device)
         return model
 
     @classmethod
@@ -43,17 +59,22 @@ class PretrainedModel(ABC):
     def register_name(cls) -> str:
         pass
 
-    def after_load_from_checkpoint(self):
+    def on_load_checkpoint(self):
         pass
 
 
-def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=None) -> PretrainedModel:
+def load_model(
+    path: str,
+    device: str | None = None,
+    backend: Literal["torch", "cuda"] | None = None,
+    hf_kwargs=None,
+    ckp_kwargs=None,
+) -> PretrainedModel:
     """Loads a TiRex model. This function attempts to load the specified model.
 
     Args:
         path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
         device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
-            If you want to use "cpu" you need to deactivate the sLSTM CUDA kernels (check repository FAQ!).
         hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
         ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
 
@@ -63,6 +84,11 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=Non
     Examples:
         model: ForecastModel = load_model("NX-AI/TiRex")
     """
+
+    if backend is None:
+        backend = "torch" if skip_cuda() else "cuda"
+    assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
+
     try:
         _, model_id = parse_hf_repo_id(path).split("/")
     except:
@@ -70,4 +96,5 @@ def load_model(path: str, device: str = "cuda:0", hf_kwargs=None, ckp_kwargs=Non
     model_cls = PretrainedModel.REGISTRY.get(model_id, None)
     if model_cls is None:
         raise ValueError(f"Invalid model id {model_id}")
-    return model_cls.from_pretrained(path, device=device, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
+
+    return model_cls.from_pretrained(path, device=device, backend=backend, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
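For orientation (not part of either package version): a minimal usage sketch of the new loading path. The TIREX_NO_CUDA flag, the backend/device defaults and the "NX-AI/TiRex" path come from the hunks above; the top-level re-export of load_model from the tirex package is an assumption.

import os

from tirex import load_model  # assumed re-export; load_model itself is defined in tirex/base.py

# CPU-only: skip_cuda() reads TIREX_NO_CUDA, so backend defaults to "torch" and device to "cpu"
os.environ["TIREX_NO_CUDA"] = "1"
model = load_model("NX-AI/TiRex")

# GPU: request the sLSTM CUDA kernels explicitly (needs the new "cuda" extra, see METADATA below)
model_gpu = load_model("NX-AI/TiRex", backend="cuda", device="cuda:0")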
tirex/models/patcher.py ADDED
@@ -0,0 +1,84 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+from dataclasses import dataclass
+
+import torch
+
+
+class StandardScaler:
+    def __init__(self, eps: float = 1e-5, nan_loc: float = 0.0):
+        self.eps = eps
+        self.nan_loc = nan_loc
+
+    def scale(
+        self,
+        x: torch.Tensor,
+        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        if loc_scale is None:
+            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
+            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
+            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
+        else:
+            loc, scale = loc_scale
+
+        return ((x - loc) / scale), (loc, scale)
+
+    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+        loc, scale = loc_scale
+        return x * scale + loc
+
+
+class Patcher:
+    def __init__(self, patch_size: int, patch_stride: int, left_pad: bool):
+        self.patch_size = patch_size
+        self.patch_stride = patch_stride
+        self.left_pad = left_pad
+        assert self.patch_size % self.patch_stride == 0
+
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.ndim == 2
+        length = x.shape[-1]
+
+        if length < self.patch_size or (length % self.patch_stride != 0):
+            if length < self.patch_size:
+                padding_size = (
+                    *x.shape[:-1],
+                    self.patch_size - (length % self.patch_size),
+                )
+            else:
+                padding_size = (
+                    *x.shape[:-1],
+                    self.patch_stride - (length % self.patch_stride),
+                )
+            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
+            if self.left_pad:
+                x = torch.concat((padding, x), dim=-1)
+            else:
+                x = torch.concat((x, padding), dim=-1)
+
+        return x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
+
+
+@dataclass
+class PatchedUniTokenizerState:
+    scale_state: float
+
+
+class PatchedUniTokenizer:
+    def __init__(self, patch_size: int, patch_stride: int | None = None, scaler: StandardScaler | None = None):
+        self.patch_size = patch_size
+        self.patch_stride = patch_size if patch_stride is None else patch_stride
+        self.scaler = StandardScaler() if scaler is None else scaler
+        self.patcher = Patcher(self.patch_size, self.patch_stride, left_pad=True)
+
+    def context_input_transform(self, data: torch.Tensor):
+        assert data.ndim == 2
+        data, scale_state = self.scaler.scale(data)
+        return self.patcher(data), PatchedUniTokenizerState(scale_state)
+
+    def output_transform(self, data: torch.Tensor, tokenizer_state: PatchedUniTokenizerState):
+        data_shape = data.shape
+        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state.scale_state).view(*data_shape)
+        return data
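As a quick illustration (not part of the package diff) of how the tokenizer above patches a series: a context of length 10 is standard-scaled, left-padded with NaN to a multiple of the stride, and unfolded into patches. The import path follows the new RECORD entry.

import torch

from tirex.models.patcher import PatchedUniTokenizer

tokenizer = PatchedUniTokenizer(patch_size=4)
context = torch.arange(10, dtype=torch.float32).unsqueeze(0)   # (batch=1, length=10)
tokens, state = tokenizer.context_input_transform(context)
print(tokens.shape)   # torch.Size([1, 3, 4]): left-padded to length 12, then three patches of size 4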
tirex/models/slstm/block.py ADDED
@@ -0,0 +1,60 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tirex.models.slstm.layer import sLSTMBlockConfig, sLSTMLayer
+from tirex.util import round_up_to_next_multiple_of
+
+
+class sLSTMBlock(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+        self.norm_slstm = RMSNorm(config.embedding_dim)
+        self.slstm_layer = sLSTMLayer(config, backend)
+        self.norm_ffn = RMSNorm(config.embedding_dim)
+
+        up_proj_dim = round_up_to_next_multiple_of(config.embedding_dim * config.ffn_proj_factor, 64)
+        self.ffn = FeedForward(config.embedding_dim, up_proj_dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_slstm = self.norm_slstm(x)
+
+        x_slstm = self.slstm_layer(x_slstm, slstm_state=None)
+        x = x + x_slstm
+
+        x_ffn = self.norm_ffn(x)
+        x_ffn = self.ffn(x_ffn)
+        x = x + x_ffn
+        return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, embedding_dim: int, up_proj_dim: int):
+        super().__init__()
+        self.proj_up_gate = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+        self.proj_up = nn.Linear(embedding_dim, up_proj_dim, bias=False)
+        self.proj_down = nn.Linear(up_proj_dim, embedding_dim, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.proj_up_gate(x)) * self.proj_up(x)
+        y = self.proj_down(x)
+        return y
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, num_features: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(num_features))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self._rms_normalize(x.float()).to(x.dtype)
+        x = x * self.weight
+        return x
+
+    def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
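A small worked example (illustrative numbers, not the shipped configuration) of the FFN sizing and the RMSNorm computation defined above:

import torch

# up_proj_dim = round_up_to_next_multiple_of(embedding_dim * ffn_proj_factor, 64)
embedding_dim, ffn_proj_factor = 512, 2.6667
up_proj_dim = int(((embedding_dim * ffn_proj_factor + 64 - 1) // 64) * 64)
print(up_proj_dim)   # 1408

# RMSNorm: divide by the root mean square over features, then apply the learned per-feature weight
x = torch.randn(2, embedding_dim)
x_normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)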
tirex/models/slstm/cell.py ADDED
@@ -0,0 +1,188 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import warnings
+from dataclasses import asdict, dataclass
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tirex.util import dataclass_from_dict
+
+
+@dataclass
+class sLSTMBlockConfig:
+    embedding_dim: int
+    num_heads: int
+    num_blocks: int
+    ffn_proj_factor: float = 2.6667
+
+    num_states: int = 4  # this is for the sLSTM, a standard LSTM has 2
+    num_gates: int = 4
+
+    @property
+    def head_dim(self):
+        return self.embedding_dim // self.num_heads
+
+
+class sLSTMCell(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+        self.backend = backend
+
+        self._recurrent_kernel_ = nn.Parameter(
+            torch.empty((config.num_heads, config.head_dim, config.num_gates * config.head_dim), dtype=None)
+        )
+
+        self._bias_ = nn.Parameter(torch.empty((config.num_heads * config.num_gates * config.head_dim), dtype=None))
+
+    def forward(self, input: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        input = self._get_input(input)
+        state = self._get_state(input, state)
+
+        if self.backend == "torch":
+            all_states = self._impl_torch(input, state)
+        elif self.backend == "cuda":
+            all_states = self._impl_cuda(input, state)
+
+        state = all_states[:, -1]
+        output = self._permute_output(all_states[0][1:])
+        return output.to(input.dtype), state.to(input.dtype)
+
+    def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+        input = input.to(dtype=torch.bfloat16)
+        state = state.to(dtype=torch.bfloat16)
+        recurrent_kernel = self._recurrent_kernel_.to(dtype=torch.bfloat16)
+        bias = self._bias_.to(dtype=torch.float32)
+
+        input = input.view(input.shape[0], input.shape[1], -1)
+        bias = (
+            bias.reshape(self.config.num_heads, self.config.num_gates, self.config.head_dim)
+            .permute(1, 0, 2)
+            .reshape(-1)
+        )
+
+        return slstm_forward(input, state, recurrent_kernel, bias)[0]
+
+    def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
+        if input.device.type != "cuda":
+            warnings.warn(
+                f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE DEVICE ON A CUDA DEVICE (device type is {input.device.type})!"
+                "This is not supported and calls to the model will likely lead to an error if you dont move your model to a CUDA device!"
+                "If you want to run TiRex on CPU you need to disable sLSTM CUDA kernels but be aware of the downsides (see FAQ)"
+            )
+
+        if not hasattr(self, "func"):
+            try:
+                from xlstm.blocks.slstm.cell import sLSTMCellConfig as sLSTMCellConfigCuda, sLSTMCellFuncGenerator
+            except ModuleNotFoundError:
+                raise ValueError(
+                    'xlstm package not found! To use the custom cuda backend, install the additional dependencies with: pip install -e ".[cuda]"'
+                )
+            cuda_config = dataclass_from_dict(
+                sLSTMCellConfigCuda, {**asdict(self.config), "hidden_size": self.config.embedding_dim}
+            )
+            self.func = sLSTMCellFuncGenerator(False, cuda_config)
+
+        input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
+
+        return self.func.apply(
+            False,
+            input.contiguous(),
+            state.contiguous(),
+            self._recurrent_kernel_.contiguous(),
+            self._bias_.contiguous(),
+        )
+
+    def _get_input(self, x: torch.Tensor) -> torch.Tensor:
+        assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
+            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
+        )
+        return x.view(x.shape[0], x.shape[1], self.config.num_gates, self.config.num_heads, -1).permute(1, 0, 2, 3, 4)
+
+    def _get_state(self, input: torch.Tensor, state: torch.Tensor | None) -> torch.Tensor:
+        B = input.shape[1]
+        if state is None:
+            state = torch.zeros(
+                (self.config.num_states, B, self.config.embedding_dim),
+                dtype=input.dtype,
+                device=input.device,
+            )
+
+        assert state.shape == (self.config.num_states, B, self.config.embedding_dim)
+        return state
+
+    def _permute_output(self, output: torch.Tensor) -> torch.Tensor:
+        output = output.view(output.shape[0], output.shape[1], self.config.num_heads, self.config.head_dim)
+        return output.permute(1, 2, 0, 3)
+
+
+def slstm_forward(
+    x: torch.Tensor,  # [S, B, G*I]
+    states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+    R: torch.Tensor,  # [K, R*H, H] - K num_heads
+    b: torch.Tensor,  # [T*H]
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_states = states.shape[0]
+    sequence_dim = x.shape[0]
+    # this only works for a fully-connected RNN, for a hin change this
+    num_gates_r = R.shape[2] // R.shape[1]
+    hidden_dim = R.shape[1] * R.shape[0]
+    batch_dim = x.shape[1]
+    num_heads = R.shape[0]
+
+    assert batch_dim == states.shape[1]
+    assert hidden_dim == states.shape[2]
+
+    states_all = torch.zeros(
+        [num_states, sequence_dim + 1, batch_dim, hidden_dim],
+        device=x.device,
+        dtype=x.dtype,
+    )
+    states_all[:, 0] = states
+    for i, Wx_t in enumerate(x.unbind(dim=0)):
+        Ry = (
+            states[0]
+            .reshape(batch_dim, num_heads, 1, -1)
+            .matmul(R.unsqueeze(0))
+            .reshape(batch_dim, num_heads, num_gates_r, -1)
+            .transpose(1, 2)
+            .reshape(batch_dim, -1)
+        )
+        sdtype = states.dtype
+        Wx_t, Ry, b, states = Wx_t.float(), Ry.float(), b.float(), states.float()
+        states, gates = slstm_forward_pointwise(Wx_t, Ry, b, states)
+        states = states.to(dtype=sdtype)
+        states_all[:, i + 1] = states
+
+    # shapes ([S, B, H], ([B,H], [B,H], [B,H])
+    return states_all, states
+
+
+def slstm_forward_pointwise(
+    Wx: torch.Tensor,  # dim [B, 4*H]
+    Ry: torch.Tensor,  # dim [B, 4*H]
+    b: torch.Tensor,  # dim [1, 4*H]
+    states: torch.Tensor,  # dim [4, B, H]
+) -> tuple[torch.Tensor, torch.Tensor]:
+    raw = Wx + Ry + b
+    y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
+
+    iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+    # with torch.no_grad():  # THE difference to maxg aka max_gradient (here max / max_static)
+    logfplusm = m + F.logsigmoid(fraw)
+    if torch.all(n == 0.0):
+        mnew = iraw
+    else:
+        mnew = torch.max(iraw, logfplusm)
+    ogate = torch.sigmoid(oraw)
+    igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))
+    fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))
+    cnew = fgate * c + igate * torch.tanh(zraw)
+    nnew = fgate * n + igate
+    ynew = ogate * cnew / nnew
+
+    # shapes ([B,H], [B,H], [B,H]), ([B,H],[B,H],[B,H],[B,H])
+    return torch.stack((ynew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
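A standalone shape check (illustrative sizes, not package code) for the pure-PyTorch recurrence slstm_forward added above; the import path follows the new RECORD entry.

import torch

from tirex.models.slstm.cell import slstm_forward

S, B, H, K = 3, 2, 8, 2                      # sequence, batch, hidden, heads
x = torch.randn(S, B, 4 * H)                 # gate pre-activations [S, B, 4*H]
states = torch.zeros(4, B, H)                # (y, c, n, m) states
R = torch.zeros(K, H // K, 4 * (H // K))     # per-head recurrent kernel
b = torch.zeros(4 * H)                       # gate bias

states_all, last_state = slstm_forward(x, states, R, b)
print(states_all.shape)                      # torch.Size([4, 4, 2, 8]) -> [num_states, S + 1, B, H]
print(last_state.shape)                      # torch.Size([4, 2, 8])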
tirex/models/slstm/layer.py ADDED
@@ -0,0 +1,67 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .cell import sLSTMBlockConfig, sLSTMCell
+
+
+class sLSTMLayer(nn.Module):
+    def __init__(self, config: sLSTMBlockConfig, backend: str):
+        super().__init__()
+        self.config = config
+
+        in_features, num_heads = self.config.embedding_dim, self.config.num_heads
+        self.fgate = LinearHeadwiseExpand(in_features, num_heads)
+        self.igate = LinearHeadwiseExpand(in_features, num_heads)
+        self.zgate = LinearHeadwiseExpand(in_features, num_heads)
+        self.ogate = LinearHeadwiseExpand(in_features, num_heads)
+
+        self.slstm_cell = sLSTMCell(self.config, backend)
+        self.group_norm = MultiHeadLayerNorm(ndim=in_features)
+
+    def forward(self, x: torch.Tensor, slstm_state: torch.Tensor | None = None) -> torch.Tensor:
+        x_g = torch.cat((self.fgate(x), self.igate(x), self.zgate(x), self.ogate(x)), dim=-1)
+
+        y, slstm_state = self.slstm_cell(x_g, state=slstm_state)
+
+        return self.group_norm(y).transpose(1, 2).view(x.shape[0], x.shape[1], -1)
+
+
+class LinearHeadwiseExpand(nn.Module):
+    def __init__(self, in_features, num_heads, expand_factor_up: float = 1):
+        super().__init__()
+        assert num_heads <= in_features, "num_heads must be <= in_features"
+        assert in_features % num_heads == 0, "in_features must be a multiple of num_heads"
+        self.num_heads = num_heads
+
+        out_features = round(expand_factor_up * in_features)
+        out_features_per_head = out_features // num_heads
+        self.weight = nn.Parameter(torch.empty(num_heads, out_features_per_head, in_features // num_heads))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shape = x.shape
+        x = x.view(*shape[:-1], self.num_heads, -1)
+        x = torch.einsum("...hd,hod->...ho", x, self.weight)
+        x = x.reshape(*shape[:-1], -1)
+        return x
+
+
+class MultiHeadLayerNorm(nn.Module):
+    def __init__(self, ndim: int):
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(ndim))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
+        B, NH, S, DH = input.shape
+
+        gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
+        gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
+        residual_weight = 1.0 + self.weight
+        out = F.group_norm(gn_in_2, num_groups=NH, weight=residual_weight)
+        # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
+        out = out.view(B, S, NH, DH).transpose(1, 2)
+        return out
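For reference (not part of the diff), the head-wise gate projection above reduces to a per-head einsum; with the default expand_factor_up of 1 the output keeps the input width. Sizes below are illustrative.

import torch

B, S, in_features, num_heads = 2, 5, 8, 2
head_dim = in_features // num_heads
weight = torch.randn(num_heads, head_dim, head_dim)            # (heads, out_per_head, in_per_head)

x = torch.randn(B, S, in_features)
x_heads = x.view(B, S, num_heads, head_dim)                    # split features into heads
y = torch.einsum("...hd,hod->...ho", x_heads, weight)          # per-head linear map
print(y.reshape(B, S, -1).shape)                               # torch.Size([2, 5, 8])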
tirex/models/tirex.py CHANGED
@@ -2,18 +2,17 @@
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
 import logging
-import warnings
-from contextlib import redirect_stdout
 from dataclasses import dataclass
 
-import lightning as L
 import torch
-from dacite import Config, from_dict
+import torch.nn as nn
+import torch.nn.functional as F
 
+from ..api_adapter.forecast import ForecastModel
 from ..base import PretrainedModel
-from .components import PatchedUniTokenizer, ResidualBlock, StreamToLogger
-from .mixed_stack import skip_cuda, xLSTMMixedLargeBlockStack, xLSTMMixedLargeConfig
-from .predict_utils import TensorQuantileUniPredictMixin
+from ..util import dataclass_from_dict
+from .patcher import PatchedUniTokenizer
+from .slstm.block import RMSNorm, sLSTMBlock, sLSTMBlockConfig
 
 
 LOGGER = logging.getLogger()
@@ -25,113 +24,70 @@ class TiRexZeroConfig:
     quantiles: list[float]
     block_kwargs: dict
     input_ff_dim: int
+    train_ctx_len: int
+    nan_mask_value: int = 0
 
 
-class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixin):
-    def __init__(self, model_config: dict, train_ctx_len=None):
+class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
+    def __init__(self, backend, model_config: TiRexZeroConfig, train_ctx_len=None):
         super().__init__()
-        self.model_config: TiRexZeroConfig = from_dict(TiRexZeroConfig, model_config, config=Config(strict=True))
-        assert self.model_config.input_patch_size == self.model_config.output_patch_size
-        self.train_ctx_len = train_ctx_len
+        self.config = TiRexZeroConfig(**model_config, train_ctx_len=train_ctx_len, nan_mask_value=0)
+        assert self.config.input_patch_size == self.config.output_patch_size
+        self.backend = backend
 
-        # Block Stack
-        self.nan_mask_value = 0
-        self.block_stack, resolved_config = self.init_block(self.model_config.block_kwargs)
-        self.model_config.block_kwargs = resolved_config
+        self.tokenizer = PatchedUniTokenizer(patch_size=self.config.input_patch_size)
 
-        # Input Layer
+        block_config = dataclass_from_dict(sLSTMBlockConfig, self.config.block_kwargs)
         self.input_patch_embedding = ResidualBlock(
-            in_dim=self.model_config.input_patch_size * 2,
-            h_dim=self.model_config.input_ff_dim,
-            out_dim=self.model_config.block_kwargs.embedding_dim,
+            in_dim=self.config.input_patch_size * 2,
+            h_dim=self.config.input_ff_dim,
+            out_dim=block_config.embedding_dim,
         )
-        self.tokenizer = PatchedUniTokenizer(
-            patch_size=self.model_config.input_patch_size,
+
+        self.blocks = nn.ModuleList(
+            [sLSTMBlock(block_config, backend=self.backend) for i in range(block_config.num_blocks)]
        )
 
-        # Output Layer
-        self.num_quantiles = len(self.model_config.quantiles)
-        quantiles = torch.tensor(self.model_config.quantiles)
-        self.register_buffer("quantiles", quantiles, persistent=False)
+        self.out_norm = RMSNorm(block_config.embedding_dim)
 
         self.output_patch_embedding = ResidualBlock(
-            in_dim=self.model_config.block_kwargs.embedding_dim,
-            h_dim=self.model_config.input_ff_dim,
-            out_dim=self.num_quantiles * self.model_config.output_patch_size,
+            in_dim=block_config.embedding_dim,
+            h_dim=self.config.input_ff_dim,
+            out_dim=len(self.config.quantiles) * self.config.output_patch_size,
         )
 
-        self.save_hyperparameters()
-
     @classmethod
     def register_name(cls):
         return "TiRex"
 
-    def init_block(self, block_kwargs):
-        config = from_dict(xLSTMMixedLargeConfig, block_kwargs)
-        log_redirect = StreamToLogger(LOGGER, logging.INFO)
-        with redirect_stdout(log_redirect):  # avoid excessive print statements of sLSTM compile
-            model = xLSTMMixedLargeBlockStack(config)
-        return model, config
-
-    @property
-    def quantiles(self):
-        return self.model.quantiles
-
-    def _forward_model_tokenized(
+    def _forecast_quantiles(
         self,
-        input_token,
-        input_mask=None,
-        rollouts=1,
-    ):
-        input_mask = (
-            input_mask.to(input_token.dtype)
-            if input_mask is not None
-            else torch.isnan(input_token).logical_not().to(input_token.dtype)
-        )
-        assert rollouts >= 1
-        bs, numb_ctx_token, token_dim = input_token.shape
-        if rollouts > 1:
-            input_token = torch.cat(
-                (
-                    input_token,
-                    torch.full(
-                        (bs, rollouts - 1, token_dim),
-                        fill_value=torch.nan,
-                        device=input_token.device,
-                        dtype=input_token.dtype,
-                    ),
-                ),
-                dim=1,
-            )
-            input_mask = torch.cat(
-                (
-                    input_mask,
-                    torch.full(
-                        (bs, rollouts - 1, token_dim),
-                        fill_value=False,
-                        device=input_mask.device,
-                        dtype=input_mask.dtype,
-                    ),
-                ),
-                dim=1,
-            )
-        input_token = torch.nan_to_num(input_token, nan=self.nan_mask_value)
-        input_embeds = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
-
-        # hidden_states = []
-        # for rollout in range(rollout):
-        x = self.block_stack(input_embeds)
-        if isinstance(x, tuple):
-            hidden_states = x[0]
+        context: torch.Tensor,
+        prediction_length: int | None = None,
+        quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+        output_device: str = "cpu",
+        auto_cast: bool = False,
+        **predict_kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        device = self.input_patch_embedding.hidden_layer.weight.device
+        context = context.to(device)
+
+        with torch.autocast(device_type=device.type, enabled=auto_cast):
+            predictions = self._forecast_tensor(
+                context=context, prediction_length=prediction_length, **predict_kwargs
+            ).detach()
+            predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
+
+        training_quantile_levels = self.config.quantiles
+
+        if set(quantile_levels).issubset(set(training_quantile_levels)):
+            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
         else:
-            hidden_states = x
+            quantiles = self._interpolate_quantiles(predictions, quantile_levels)
 
-        quantile_preds = self.output_patch_embedding(hidden_states)
-        quantile_preds = torch.unflatten(quantile_preds, -1, (self.num_quantiles, self.model_config.output_patch_size))
-        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
-        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
-
-        return quantile_preds, hidden_states
+        # median as mean
+        mean = predictions[:, :, training_quantile_levels.index(0.5)]
+        return quantiles, mean
 
     @torch.inference_mode()
     def _forecast_tensor(
@@ -146,13 +102,10 @@ class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixi
             prediction_length = self.tokenizer.patch_size
         remaining = -(prediction_length // -self.tokenizer.patch_size)
         if max_context is None:
-            max_context = self.train_ctx_len
-        min_context = max(self.train_ctx_len, max_context)
+            max_context = self.config.train_ctx_len
+        min_context = max(self.config.train_ctx_len, max_context)
 
-        context = context.to(
-            device=self.device,
-            dtype=torch.float32,
-        )
+        context = context.to(dtype=torch.float32)
         while remaining > 0:
             if context.shape[-1] > max_context:
                 context = context[..., -max_context:]
@@ -181,51 +134,92 @@ class TiRexZero(L.LightningModule, PretrainedModel, TensorQuantileUniPredictMixi
 
             context = torch.cat([context, torch.full_like(prediction[:, 0, :], fill_value=torch.nan)], dim=-1)
 
-        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(
-            dtype=torch.float32,
-        )
+        return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)
 
-    def on_load_checkpoint(self, checkpoint: dict) -> None:
-        state_dict = checkpoint["state_dict"]
-        load_vanilla_kernel = skip_cuda()
-        if load_vanilla_kernel:
-            warnings.warn(
-                "You use TiRex without sLSTM CUDA kernels! This might slow down the model considerably and might degrade forecasting results!"
-                "Set the environment variable TIREX_NO_CUDA to 0 to avoid this!"
+    def _forward_model_tokenized(
+        self,
+        input_token: torch.Tensor,
+        input_mask=None,
+        rollouts=1,
+    ):
+        input_mask = (
+            input_mask.to(input_token.dtype)
+            if input_mask is not None
+            else torch.isnan(input_token).logical_not().to(input_token.dtype)
+        )
+        assert rollouts >= 1
+        bs, numb_ctx_token, token_dim = input_token.shape
+        if rollouts > 1:
+            input_token_rollout_pad = torch.full(
+                (bs, rollouts - 1, token_dim),
+                fill_value=torch.nan,
+                device=input_token.device,
+                dtype=input_token.dtype,
             )
-        block_kwargs = self.model_config.block_kwargs
-        head_dim = block_kwargs.embedding_dim // block_kwargs.num_heads
-        num_gates = 4
-        new_state_dict = {}
-        for k, v in state_dict.items():
-            if "slstm_layer.slstm_cell._recurrent_kernel_" in k:
-                new_state_dict[k] = (
-                    v.reshape(
-                        block_kwargs.num_heads,
-                        head_dim,
-                        num_gates,
-                        head_dim,
-                    )
-                    .permute(0, 2, 3, 1)
-                    .reshape(
-                        block_kwargs.num_heads,
-                        num_gates * head_dim,
-                        head_dim,
-                    )
-                )
-                # new_state_dict[k] = v.permute(0, 2, 1)
-            elif "slstm_layer.slstm_cell._bias_" in k:
-                new_state_dict[k] = (
-                    v.reshape(block_kwargs.num_heads, num_gates, head_dim).permute(1, 0, 2).reshape(-1)
-                )
-            else:
-                new_state_dict[k] = v
-        checkpoint["state_dict"] = new_state_dict
-
-    def after_load_from_checkpoint(self):
-        if not skip_cuda() and self.device.type != "cuda":
-            warnings.warn(
-                f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE DEVICE ON A CUDA DEVICE (device type is {self.device.type})!"
-                "This is not supported and calls to the model will likely lead to an error if you dont move your model to a CUDA device!"
-                "If you want to run TiRex on CPU you need to disable sLSTM CUDA kernels but be aware of the downsides (see FAQ)"
+            input_token = torch.cat((input_token, input_token_rollout_pad), dim=1)
+            input_mask_rollout_pad = torch.full(
+                (bs, rollouts - 1, token_dim),
+                fill_value=False,
+                device=input_mask.device,
+                dtype=input_mask.dtype,
+            )
+            input_mask = torch.cat((input_mask, input_mask_rollout_pad), dim=1)
+
+        input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)
+
+        hidden_states = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
+
+        for block in self.blocks:
+            hidden_states = block(hidden_states)
+
+        hidden_states = self.out_norm(hidden_states)
+
+        quantile_preds = self.output_patch_embedding(hidden_states)
+        quantile_preds = torch.unflatten(
+            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
+        )
+        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
+        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
+
+        return quantile_preds, hidden_states
+
+    def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
+        training_quantile_levels = self.config.quantiles
+        if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
+            logging.warning(
+                f"Requested quantile levels ({quantile_levels}) fall outside the range of "
+                f"quantiles the model was trained on ({training_quantile_levels}). "
+                "Predictions for out-of-range quantiles will be clamped to the nearest "
+                "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
+                "This can significantly impact prediction accuracy, especially for extreme quantiles. "
            )
+
+        augmented_predictions = torch.cat(
+            [predictions[..., [0]], predictions, predictions[..., [-1]]],
+            dim=-1,
+        )
+        quantiles = torch.quantile(
+            augmented_predictions,
+            q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
+            dim=-1,
+        ).permute(1, 2, 0)
+        return quantiles
+
+    def on_load_checkpoint(self, checkpoint: dict) -> None:
+        # rename keys of state_dict, because the block_stack was moved directly into the tirex model
+        checkpoint["state_dict"] = {k.replace("block_stack.", ""): v for k, v in checkpoint["state_dict"].items()}
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_dim: int, h_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.hidden_layer = nn.Linear(in_dim, h_dim)
+        self.output_layer = nn.Linear(h_dim, out_dim)
+        self.residual_layer = nn.Linear(in_dim, out_dim)
+
+    def forward(self, x: torch.Tensor):
+        hid = F.relu(self.hidden_layer(x))
+        out = self.output_layer(hid)
+        res = self.residual_layer(x)
+        out = out + res
+        return out
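To make the interpolation branch above concrete (a standalone sketch with made-up data): out-of-range levels are effectively clamped by duplicating the lowest and highest trained quantile before calling torch.quantile.

import torch

predictions = torch.randn(1, 6, 9).sort(dim=-1).values   # (batch, horizon, 9 trained quantile levels)
quantile_levels = [0.05, 0.5, 0.95]

augmented = torch.cat([predictions[..., [0]], predictions, predictions[..., [-1]]], dim=-1)
quantiles = torch.quantile(augmented, q=torch.tensor(quantile_levels), dim=-1).permute(1, 2, 0)
print(quantiles.shape)   # torch.Size([1, 6, 3])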
tirex/util.py ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+from dataclasses import fields
+
+
+def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
+    return int(((x + multiple_of - 1) // multiple_of) * multiple_of)
+
+
+def dataclass_from_dict(cls, dict: dict):
+    class_fields = {f.name for f in fields(cls)}
+    return cls(**{k: v for k, v in dict.items() if k in class_fields})
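A short illustration (not from the package) of dataclass_from_dict, which is how tirex.py above narrows the checkpoint's block_kwargs down to the fields of sLSTMBlockConfig; DemoConfig is a hypothetical stand-in.

from dataclasses import dataclass

from tirex.util import dataclass_from_dict


@dataclass
class DemoConfig:          # hypothetical stand-in for sLSTMBlockConfig
    embedding_dim: int
    num_heads: int


cfg = dataclass_from_dict(DemoConfig, {"embedding_dim": 512, "num_heads": 8, "ignored_key": True})
print(cfg)                 # DemoConfig(embedding_dim=512, num_heads=8)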
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.8.28
+Version: 2025.9.9
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -65,17 +65,17 @@ License-File: LICENSE_MIRROR.txt
 License-File: NOTICE.txt
 Requires-Dist: torch
 Requires-Dist: torchvision
-Requires-Dist: xlstm
 Requires-Dist: einops
-Requires-Dist: ninja
 Requires-Dist: huggingface-hub
-Requires-Dist: lightning
 Requires-Dist: numpy
 Requires-Dist: pandas
-Requires-Dist: dacite
 Requires-Dist: tqdm
+Provides-Extra: cuda
+Requires-Dist: xlstm; extra == "cuda"
+Requires-Dist: ninja; extra == "cuda"
 Provides-Extra: notebooks
 Requires-Dist: ipykernel; extra == "notebooks"
+Requires-Dist: matplotlib; extra == "notebooks"
 Provides-Extra: gluonts
 Requires-Dist: gluonts; extra == "gluonts"
 Provides-Extra: hfdataset
@@ -83,7 +83,10 @@ Requires-Dist: datasets; extra == "hfdataset"
 Provides-Extra: test
 Requires-Dist: fev; extra == "test"
 Provides-Extra: all
+Requires-Dist: xlstm; extra == "all"
+Requires-Dist: ninja; extra == "all"
 Requires-Dist: ipykernel; extra == "all"
+Requires-Dist: matplotlib; extra == "all"
 Requires-Dist: gluonts; extra == "all"
 Requires-Dist: datasets; extra == "all"
 Requires-Dist: fev; extra == "all"
@@ -0,0 +1,21 @@
+tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
+tirex/base.py,sha256=ODUyhYFR33ZYffu7dxDwsb9m2IiZAnGHIXvA81crbjQ,3245
+tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
+tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
+tirex/api_adapter/gluon.py,sha256=faiYyn0kBBVQKbpWqrVoyylxZUrmr-qce66twpguVds,1827
+tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,1377
+tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
+tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
+tirex/models/tirex.py,sha256=dclEckb6CmvESeX_LwT2kaCNTB7deTFovIOQUIFF5J8,9117
+tirex/models/slstm/block.py,sha256=DCOxmLQUb7HRO6wXTZMK4ICUI5LFpo7NC5a28oM-Vsc,2104
+tirex/models/slstm/cell.py,sha256=4_pQcXOOT16aEpKIi4A-yEnj4qKK6pFyFADD2nGPzGc,7366
+tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
+tirex_mirror-2025.9.9.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
+tirex_mirror-2025.9.9.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
+tirex_mirror-2025.9.9.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
+tirex_mirror-2025.9.9.dist-info/METADATA,sha256=u9C_cIb8FtaHUep1XrFTeI7UAsVRtNJt2VSQo7420Vo,11200
+tirex_mirror-2025.9.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tirex_mirror-2025.9.9.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
+tirex_mirror-2025.9.9.dist-info/RECORD,,
tirex/models/components.py DELETED
@@ -1,147 +0,0 @@
-# Copyright (c) NXAI GmbH.
-# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
-
-from dataclasses import dataclass, field
-from typing import Any
-
-import torch
-
-SCALER_STATE = "scaler_state"
-
-
-class ResidualBlock(torch.nn.Module):
-    def __init__(
-        self,
-        in_dim: int,
-        h_dim: int,
-        out_dim: int,
-        dropout: float = 0,
-    ) -> None:
-        super().__init__()
-        self.dropout = torch.nn.Dropout(dropout)
-        self.hidden_layer = torch.nn.Linear(in_dim, h_dim)
-        self.output_layer = torch.nn.Linear(h_dim, out_dim)
-        self.residual_layer = torch.nn.Linear(in_dim, out_dim)
-        self.act = torch.nn.ReLU()
-
-    def forward(self, x: torch.Tensor):
-        hid = self.act(self.hidden_layer(x))
-        out = self.output_layer(hid)
-        res = self.residual_layer(x)
-        out = out + res
-        return out
-
-
-@dataclass
-class StandardScaler:
-    eps: float = 1e-5
-    nan_loc: float = 0.0
-
-    def scale(
-        self,
-        x: torch.Tensor,
-        loc_scale: tuple[torch.Tensor, torch.Tensor] | None = None,
-    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        if loc_scale is None:
-            loc = torch.nan_to_num(torch.nanmean(x, dim=-1, keepdim=True), nan=self.nan_loc)
-            scale = torch.nan_to_num(torch.nanmean((x - loc).square(), dim=-1, keepdim=True).sqrt(), nan=1.0)
-            scale = torch.where(scale == 0, torch.abs(loc) + self.eps, scale)
-        else:
-            loc, scale = loc_scale
-
-        return ((x - loc) / scale), (loc, scale)
-
-    def re_scale(self, x: torch.Tensor, loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
-        loc, scale = loc_scale
-        return x * scale + loc
-
-
-@dataclass
-class _Patcher:
-    patch_size: int
-    patch_stride: int
-    left_pad: bool
-
-    def __post_init__(self):
-        assert self.patch_size % self.patch_stride == 0
-
-    def __call__(self, x: torch.Tensor) -> torch.Tensor:
-        assert x.ndim == 2
-        length = x.shape[-1]
-
-        if length < self.patch_size or (length % self.patch_stride != 0):
-            if length < self.patch_size:
-                padding_size = (
-                    *x.shape[:-1],
-                    self.patch_size - (length % self.patch_size),
-                )
-            else:
-                padding_size = (
-                    *x.shape[:-1],
-                    self.patch_stride - (length % self.patch_stride),
-                )
-            padding = torch.full(size=padding_size, fill_value=torch.nan, dtype=x.dtype, device=x.device)
-            if self.left_pad:
-                x = torch.concat((padding, x), dim=-1)
-            else:
-                x = torch.concat((x, padding), dim=-1)
-
-        x = x.unfold(dimension=-1, size=self.patch_size, step=self.patch_stride)
-        return x
-
-
-@dataclass
-class PatchedUniTokenizer:
-    patch_size: int
-    scaler: Any = field(default_factory=StandardScaler)
-    patch_stride: int | None = None
-
-    def __post_init__(self):
-        if self.patch_stride is None:
-            self.patch_stride = self.patch_size
-        self.patcher = _Patcher(self.patch_size, self.patch_stride, left_pad=True)
-
-    def context_input_transform(self, data: torch.Tensor):
-        assert data.ndim == 2
-        data, scale_state = self.scaler.scale(data)
-        return self.patcher(data), {SCALER_STATE: scale_state}
-
-    def output_transform(self, data: torch.Tensor, tokenizer_state: dict):
-        data_shape = data.shape
-        data = self.scaler.re_scale(data.reshape(data_shape[0], -1), tokenizer_state[SCALER_STATE]).view(*data_shape)
-        return data
-
-
-class StreamToLogger:
-    """Fake file-like stream object that redirects writes to a logger
-    instance."""
-
-    def __init__(self, logger, log_level):
-        self.logger = logger
-        self.log_level = log_level
-        self.linebuf = ""  # Buffer for partial lines
-
-    def write(self, message):
-        # Filter out empty messages (often from just a newline)
-        if message.strip():
-            self.linebuf += message
-            # If the message contains a newline, process the full line
-            if "\n" in self.linebuf:
-                lines = self.linebuf.splitlines(keepends=True)
-                for line in lines:
-                    if line.endswith("\n"):
-                        # Log full lines without the trailing newline (logger adds its own)
-                        self.logger.log(self.log_level, line.rstrip("\n"))
-                    else:
-                        # Keep partial lines in buffer
-                        self.linebuf = line
-                        return
-                self.linebuf = ""  # All lines processed
-        # If no newline, keep buffering
-
-    def flush(self):
-        # Log any remaining buffered content when flush is called
-        if self.linebuf.strip():
-            self.logger.log(self.log_level, self.linebuf.rstrip("\n"))
-        self.linebuf = ""
tirex/models/mixed_stack.py DELETED
@@ -1,143 +0,0 @@
-# Copyright (c) NXAI GmbH.
-# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
-
-import os
-from dataclasses import dataclass, field
-
-import torch
-from torch import nn
-from xlstm.blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
-from xlstm.xlstm_large import xLSTMLargeConfig
-from xlstm.xlstm_large.components import RMSNorm
-from xlstm.xlstm_large.model import FeedForward, mLSTMBlock, mLSTMStateType
-
-
-def skip_cuda():
-    return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
-
-
-def init_cell(config: xLSTMLargeConfig, block_idx, num_blocks):
-    return sLSTMLayer(
-        sLSTMLayerConfig(
-            embedding_dim=config.embedding_dim,
-            num_heads=config.num_heads,
-            conv1d_kernel_size=0,  # 0 means no convolution included
-            group_norm_weight=True,
-            dropout=0,
-            # CellConfig
-            backend="vanilla" if skip_cuda() else "cuda",
-            bias_init="powerlaw_blockdependent",
-            recurrent_weight_init="zeros",
-            num_gates=4,
-            gradient_recurrent_cut=False,
-            gradient_recurrent_clipval=None,
-            forward_clipval=None,
-            batch_size=8,  # needed?
-            _block_idx=block_idx,
-            _num_blocks=num_blocks,
-        )
-    )
-
-
-sLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor]
-sLSTMStateType = dict[int, sLSTMLayerStateType]
-
-
-class sLSTMBlock(nn.Module):
-    def __init__(self, config: xLSTMLargeConfig, block_idx: int, num_blocks: int):
-        super().__init__()
-        self.config = config
-        self.norm_slstm = RMSNorm(
-            num_features=config.embedding_dim,
-            eps=config.norm_eps,
-            use_weight=True,
-            use_bias=config.use_bias,
-            force_float32_reductions=config.norm_reduction_force_float32,
-        )
-        self.slstm_layer = init_cell(config, block_idx, num_blocks)
-
-        self.norm_ffn = RMSNorm(
-            num_features=config.embedding_dim,
-            eps=config.norm_eps,
-            use_weight=True,
-            use_bias=config.use_bias,
-            force_float32_reductions=config.norm_reduction_force_float32,
-        )
-        self.ffn = FeedForward(config)
-
-    def forward(
-        self, x: torch.Tensor, state: sLSTMLayerStateType | None = None
-    ) -> tuple[torch.Tensor, sLSTMLayerStateType]:
-        x_slstm = self.norm_slstm(x)
-        if state is None:
-            conv_state, slstm_state = None, None
-        else:
-            conv_state, slstm_state = state
-        x_slstm, state = self.slstm_layer(x_slstm, conv_state, slstm_state, return_last_state=True)
-        x = x + x_slstm
-
-        x_ffn = self.norm_ffn(x)
-        x_ffn = self.ffn(x_ffn)
-        x = x + x_ffn
-
-        return x, (state["conv_state"], state["slstm_state"])
-
-
-@dataclass
-class xLSTMMixedLargeConfig(xLSTMLargeConfig):
-    slstm_at: list[int] = field(default_factory=list)
-    all_slstm: bool = True
-
-    @property
-    def block_types(self):
-        return ["s" if i in self.slstm_at or self.all_slstm else "m" for i in range(self.num_blocks)]
-
-
-class xLSTMMixedLargeBlockStack(nn.Module):
-    config_class = xLSTMMixedLargeConfig
-
-    def __init__(self, config: xLSTMMixedLargeConfig):
-        super().__init__()
-        self.config = config
-
-        self.blocks = nn.ModuleList(
-            [
-                sLSTMBlock(config, block_idx=i, num_blocks=config.num_blocks) if t == "s" else mLSTMBlock(config)
-                for i, t in enumerate(config.block_types)
-            ]
-        )
-
-        if self.config.add_out_norm:
-            self.out_norm = RMSNorm(
-                num_features=config.embedding_dim,
-                eps=config.norm_eps,
-                use_weight=True,
-                use_bias=config.use_bias,
-                force_float32_reductions=config.norm_reduction_force_float32,
-            )
-        else:
-            self.out_norm = nn.Identity()
-
-    def forward(
-        self, x: torch.Tensor, state: mLSTMStateType | sLSTMStateType | None = None
-    ) -> tuple[torch.Tensor, mLSTMStateType]:
-        if state is None:
-            state = {i: None for i in range(len(self.blocks))}
-
-        for i, block in enumerate(self.blocks):
-            block_state = state[i]
-            x, block_state_new = block(x, block_state)
-
-            if block_state is None:
-                state[i] = block_state_new
-            else:
-                pass
-                ## layer state is a tuple of three tensors: c, n, m
-                ## we update the state in place in order to avoid creating new tensors
-                # for state_idx in range(len(block_state)):
-                #     state[i][state_idx].copy_(block_state_new[state_idx])
-
-        x = self.out_norm(x)
-
-        return x, state
tirex/models/predict_utils.py DELETED
@@ -1,72 +0,0 @@
-# Copyright (c) NXAI GmbH.
-# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
-
-import logging
-from abc import abstractmethod
-
-import torch
-
-from ..api_adapter.forecast import ForecastModel
-
-LOGGER = logging.getLogger()
-
-
-class TensorQuantileUniPredictMixin(ForecastModel):
-    @abstractmethod
-    def _forecast_tensor(
-        self,
-        context: torch.Tensor,
-        prediction_length: int | None = None,
-        **predict_kwargs,
-    ) -> torch.Tensor:
-        pass
-
-    @property
-    @abstractmethod
-    def quantiles(self):
-        pass
-
-    def _forecast_quantiles(
-        self,
-        context: torch.Tensor,
-        prediction_length: int | None = None,
-        quantile_levels: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
-        output_device: str = "cpu",
-        auto_cast: bool = False,
-        **predict_kwargs,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        with torch.autocast(device_type=self.device.type, enabled=auto_cast):
-            predictions = self._forecast_tensor(
-                context=context, prediction_length=prediction_length, **predict_kwargs
-            ).detach()
-            predictions = predictions.to(torch.device(output_device)).swapaxes(1, 2)
-
-        training_quantile_levels = list(self.quantiles)
-
-        if set(quantile_levels).issubset(set(training_quantile_levels)):
-            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
-        else:
-            if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(
-                training_quantile_levels
-            ):
-                logging.warning(
-                    f"Requested quantile levels ({quantile_levels}) fall outside the range of "
-                    f"quantiles the model was trained on ({training_quantile_levels}). "
-                    "Predictions for out-of-range quantiles will be clamped to the nearest "
-                    "boundary of the trained quantiles (i.e., minimum or maximum trained level). "
-                    "This can significantly impact prediction accuracy, especially for extreme quantiles. "
-                )
-            # Interpolate quantiles
-            augmented_predictions = torch.cat(
-                [predictions[..., [0]], predictions, predictions[..., [-1]]],
-                dim=-1,
-            )
-            quantiles = torch.quantile(
-                augmented_predictions,
-                q=torch.tensor(quantile_levels, dtype=augmented_predictions.dtype),
-                dim=-1,
-            ).permute(1, 2, 0)
-        # median as mean
-        mean = predictions[:, :, training_quantile_levels.index(0.5)]
-        return quantiles, mean
@@ -1,19 +0,0 @@
-tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
-tirex/base.py,sha256=F18v9tTbLH0-nX-PC6kBAkYQHkS1T_7OQD6_aN6EjMw,2623
-tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/api_adapter/forecast.py,sha256=iOVP_L7fYlp1ZjyrQe2b8fwuEcxTYOszfZ5f9VDqKHU,8503
-tirex/api_adapter/gluon.py,sha256=faiYyn0kBBVQKbpWqrVoyylxZUrmr-qce66twpguVds,1827
-tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,1377
-tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
-tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/models/components.py,sha256=sluhMbV6KL3W1ESoC5Nyoxdge9WSNx98alc8NG85dv0,4991
-tirex/models/mixed_stack.py,sha256=ffpdhwCrPAbpp4_s1q8Z0Ei7iZ2TsqzVzOPe3BQPW9w,4790
-tirex/models/predict_utils.py,sha256=QUMZZ4_Sxa09UaHs1DG-MbfP8j_XwYt0x1zemdSEcFI,2749
-tirex/models/tirex.py,sha256=bFxtcpQB9-Hnayy_4bqif-o75DwO3-W0wJxelS8F_6c,9243
-tirex_mirror-2025.8.28.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
-tirex_mirror-2025.8.28.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
-tirex_mirror-2025.8.28.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
-tirex_mirror-2025.8.28.dist-info/METADATA,sha256=c9lhDLJfwEsYyxjEJ732XMn1FgCowXZ5yIyWK2NBy8o,11029
-tirex_mirror-2025.8.28.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tirex_mirror-2025.8.28.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
-tirex_mirror-2025.8.28.dist-info/RECORD,,