tirex-mirror 2025.10.2__py3-none-any.whl → 2025.10.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
tirex/base.py CHANGED
@@ -1,6 +1,7 @@
  # Copyright (c) NXAI GmbH.
  # This software may be used and distributed according to the terms of the NXAI Community License Agreement.

+ import logging
  import os
  from abc import ABC, abstractmethod
  from typing import Literal, TypeVar
@@ -8,6 +9,8 @@ from typing import Literal, TypeVar
  import torch
  from huggingface_hub import hf_hub_download

+ from tirex.models.slstm.cell import sLSTMCellTorch
+
  T = TypeVar("T", bound="PretrainedModel")


@@ -38,7 +41,7 @@ class PretrainedModel(ABC):

      @classmethod
      def from_pretrained(
-         cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+         cls: type[T], path: str, backend: str, device: str | None = None, compile=False, hf_kwargs=None, ckp_kwargs=None
      ) -> T:
          if hf_kwargs is None:
              hf_kwargs = {}
@@ -58,9 +61,10 @@ class PretrainedModel(ABC):
          model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
          model.on_load_checkpoint(checkpoint)
          model.load_state_dict(checkpoint["state_dict"])
+         model = model.to(device)

-         if backend == "cuda":
-             model = model.to(device)
+         if compile and backend == "torch":
+             sLSTMCellTorch.slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward, mode="max-autotune")
          return model

      @classmethod
@@ -76,6 +80,7 @@ def load_model(
      path: str,
      device: str | None = None,
      backend: Literal["torch", "cuda"] | None = None,
+     compile: bool = False,
      hf_kwargs=None,
      ckp_kwargs=None,
  ) -> PretrainedModel:
@@ -85,6 +90,7 @@ def load_model(
          path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
          device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
          backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
+         compile (bool, optional): torch.compile the sLSTM cells; only works with the torch backend.
          hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
          ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.

@@ -106,4 +112,6 @@ def load_model(
      if model_cls is None:
          raise ValueError(f"Invalid model id {model_id}")

-     return model_cls.from_pretrained(path, device=device, backend=backend, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs)
+     return model_cls.from_pretrained(
+         path, device=device, backend=backend, compile=compile, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs
+     )
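The net effect of the base.py changes: the model is now always moved to `device` regardless of backend, and a new `compile` flag, threaded from `load_model` through `from_pretrained`, wraps `sLSTMCellTorch.slstm_forward` in `torch.compile(..., mode="max-autotune")` when the torch backend is used. A minimal usage sketch, assuming the public `load_model`/`forecast` API of the upstream tirex package; the input series below is made up for illustration:

import torch
from tirex import load_model

# torch backend plus the new compile flag: the sLSTM cell forward is
# torch.compile'd once at load time, so the first forecast call pays
# the compilation cost and later calls reuse the compiled kernel.
model = load_model("NX-AI/TiRex", backend="torch", compile=True, device="cpu")

context = torch.randn(2, 512)  # 2 illustrative series, 512 steps each
quantiles, mean = model.forecast(context, prediction_length=64)
print(quantiles.shape, mean.shape)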
tirex/models/slstm/cell.py CHANGED
@@ -43,13 +43,11 @@ class sLSTMCell(nn.Module):
          state = self._get_state(input, state)

          if self.backend == "torch":
-             all_states = self._impl_torch(input, state)
+             output, state = self._impl_torch(input, state)
          elif self.backend == "cuda":
-             all_states = self._impl_cuda(input, state)
+             output, state = self._impl_cuda(input, state)

-         state = all_states[:, -1]
-         output = self._permute_output(all_states[0][1:])
-         return output.to(input.dtype), state.to(input.dtype)
+         return self._permute_output(output).to(input.dtype), state.to(input.dtype)

      def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
          input = input.to(dtype=torch.bfloat16)
@@ -64,7 +62,7 @@ class sLSTMCell(nn.Module):
              .reshape(-1)
          )

-         return slstm_forward(input, state, recurrent_kernel, bias)[0]
+         return sLSTMCellTorch.slstm_forward(input, state, recurrent_kernel, bias)

      def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
          if input.device.type != "cuda":
@@ -88,7 +86,7 @@ class sLSTMCell(nn.Module):

          input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)

-         return self.func.apply(
+         all_states = self.func.apply(
              False,
              input.contiguous(),
              state.contiguous(),
@@ -96,6 +94,10 @@ class sLSTMCell(nn.Module):
              self._bias_.contiguous(),
          )

+         state = all_states[:, -1]
+         output = all_states[0][1:]
+         return output, state
+
      def _get_input(self, x: torch.Tensor) -> torch.Tensor:
          assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
              f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
@@ -119,73 +121,60 @@ class sLSTMCell(nn.Module):
          return output.permute(1, 2, 0, 3)


- def slstm_forward(
-     x: torch.Tensor,  # [S, B, G*I]
-     states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
-     R: torch.Tensor,  # [K, R*H, H] - K num_heads
-     b: torch.Tensor,  # [T*H]
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-     num_states = states.shape[0]
-     sequence_dim = x.shape[0]
-     # this only works for a fully-connected RNN, for a hin change this
-     num_gates_r = R.shape[2] // R.shape[1]
-     hidden_dim = R.shape[1] * R.shape[0]
-     batch_dim = x.shape[1]
-     num_heads = R.shape[0]
-
-     assert batch_dim == states.shape[1]
-     assert hidden_dim == states.shape[2]
-
-     states_all = torch.zeros(
-         [num_states, sequence_dim + 1, batch_dim, hidden_dim],
-         device=x.device,
-         dtype=x.dtype,
-     )
-     states_all[:, 0] = states
-     for i, Wx_t in enumerate(x.unbind(dim=0)):
-         Ry = (
-             states[0]
-             .reshape(batch_dim, num_heads, 1, -1)
-             .matmul(R.unsqueeze(0))
-             .reshape(batch_dim, num_heads, num_gates_r, -1)
-             .transpose(1, 2)
-             .reshape(batch_dim, -1)
-         )
-         sdtype = states.dtype
-         Wx_t, Ry, b, states = Wx_t.float(), Ry.float(), b.float(), states.float()
-         states, gates = slstm_forward_pointwise(Wx_t, Ry, b, states)
-         states = states.to(dtype=sdtype)
-         states_all[:, i + 1] = states
-
-     # shapes ([S, B, H], ([B,H], [B,H], [B,H])
-     return states_all, states
-
-
- def slstm_forward_pointwise(
-     Wx: torch.Tensor,  # dim [B, 4*H]
-     Ry: torch.Tensor,  # dim [B, 4*H]
-     b: torch.Tensor,  # dim [1, 4*H]
-     states: torch.Tensor,  # dim [4, B, H]
- ) -> tuple[torch.Tensor, torch.Tensor]:
-     raw = Wx + Ry + b
-
-     iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
-     y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
-
-     # with torch.no_grad():  # THE difference to maxg aka max_gradient (here max / max_static)
-     # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
-     logfplusm = m + F.logsigmoid(fraw)  # eq 15
-     if torch.all(n == 0.0):
-         mnew = iraw
-     else:
-         mnew = torch.max(iraw, logfplusm)  # eq 15
-     ogate = torch.sigmoid(oraw)  # eq 14
-     igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
-     fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
-     zgate = torch.tanh(zraw)  # eq 11
-     cnew = fgate * c + igate * zgate  # eq 8
-     nnew = fgate * n + igate  # eq 9
-     hnew = ogate * cnew / nnew  # eq 10
-
-     # y (4, B, H), state (4, B, H)
-     return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
+ class sLSTMCellTorch:
+     @staticmethod
+     def slstm_forward(
+         x: torch.Tensor,  # [S, B, G*I]
+         states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+         R: torch.Tensor,  # [K, R*H, H] - K num_heads
+         b: torch.Tensor,  # [T*H]
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         num_gates = 4
+         num_heads = R.shape[0]
+         S, B, _ = x.shape
+         H = R.shape[1] * num_heads
+         assert states.shape == (num_gates, B, H)
+
+         states = states.to(R.dtype).unbind(dim=0)
+         output = []
+         for i in range(S):
+             Ry = (
+                 states[0]
+                 .reshape(B, num_heads, 1, -1)
+                 .matmul(R.unsqueeze(0))
+                 .reshape(B, num_heads, num_gates, -1)
+                 .transpose(1, 2)
+                 .reshape(B, -1)
+             )
+             states = sLSTMCellTorch.slstm_forward_pointwise(
+                 x[i].float(), Ry.float(), b.float(), [s.float() for s in states]
+             )
+             states = [s.to(dtype=R.dtype) for s in states]
+             output.append(states[0])
+
+         return torch.stack(output), torch.stack(states)  # (S, B, H), 4 x (B, H)
+
+     @staticmethod
+     def slstm_forward_pointwise(
+         Wx: torch.Tensor,  # dim [B, 4*H]
+         Ry: torch.Tensor,  # dim [B, 4*H]
+         b: torch.Tensor,  # dim [1, 4*H]
+         states: list[torch.Tensor],  # dim 4 x [B, H]
+     ) -> list[torch.Tensor]:
+         y, c, n, m = states
+
+         raw = Wx + Ry + b
+         iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+
+         # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
+         logfplusm = m + F.logsigmoid(fraw)  # eq 15
+         mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))  # eq 15
+         ogate = torch.sigmoid(oraw)  # eq 14
+         igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
+         fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17
+         zgate = torch.tanh(zraw)  # eq 11
+         cnew = fgate * c + igate * zgate  # eq 8
+         nnew = fgate * n + igate  # eq 9
+         hnew = ogate * cnew / nnew  # eq 10
+
+         return [hnew, cnew, nnew, mnew]  # 4 x (B, H)
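For reference, the stabilized gate computation that the new `slstm_forward_pointwise` implements (eqs. 8-17 of the xLSTM paper) can be checked in isolation. A self-contained sketch; B, H, and the random inputs are made up for illustration:

import torch
import torch.nn.functional as F

B, H = 2, 8  # illustrative batch and hidden sizes
Wx, Ry, b = torch.randn(B, 4 * H), torch.randn(B, 4 * H), torch.randn(4 * H)
n = torch.zeros(B, H)  # fresh normalizer state
m = torch.zeros(B, H)  # fresh stabilizer state

raw = Wx + Ry + b
iraw, fraw, zraw, oraw = torch.unbind(raw.view(B, 4, -1), dim=1)

logfplusm = m + F.logsigmoid(fraw)  # eq 15
mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))
igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))  # eq 17

# Subtracting the running max mnew keeps both exponentials bounded by 1,
# so the recurrence stays finite even in low-precision dtypes.
assert igate.max() <= 1.0 and fgate.max() <= 1.0

The branch-free torch.where on a 0-dim condition (replacing the old if/else on torch.all) is what makes the function traceable by torch.compile without graph breaks on a data-dependent Python branch.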
tirex/models/tirex.py CHANGED
@@ -179,8 +179,25 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
          quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
          # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]

+         quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+
+         quantile_preds = torch.unflatten(
+             quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
+         )
+         quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
+         # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
          return quantile_preds, hidden_states

+     def _forward_model(self, input: torch.Tensor):
+         hidden_states = self.input_patch_embedding(input)
+
+         for block in self.blocks:
+             hidden_states = block(hidden_states)
+
+         hidden_states = self.out_norm(hidden_states)
+
+         return self.output_patch_embedding(hidden_states)
+
      def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
          training_quantile_levels = self.config.quantiles
          if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
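The tirex.py change factors the backbone pass (input patch embedding, sLSTM blocks, output norm, output head) into a `_forward_model` helper whose output is then unflattened into quantile and patch dimensions. A stand-in sketch of that shape flow; all dimensions, and the Linear layers standing in for the sLSTM blocks, are hypothetical:

import torch
import torch.nn as nn

n_quantiles, patch_out, d_model = 9, 32, 64  # hypothetical config values
input_patch_embedding = nn.Linear(2 * patch_out, d_model)  # token + mask features
blocks = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(2)])  # stand-ins for sLSTM blocks
out_norm = nn.LayerNorm(d_model)
output_patch_embedding = nn.Linear(d_model, n_quantiles * patch_out)

x = torch.randn(4, 10, 2 * patch_out)  # [batch, num_token, token + mask features]
h = input_patch_embedding(x)
for block in blocks:
    h = block(h)
preds = output_patch_embedding(out_norm(h))

# Unflatten the last dim into (quantiles, patch) and swap the quantile
# and token dims, as the forward pass in the diff does.
preds = torch.unflatten(preds, -1, (n_quantiles, patch_out)).transpose(1, 2)
print(preds.shape)  # torch.Size([4, 9, 10, 32])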
tirex_mirror-2025.10.2.dist-info/METADATA → tirex_mirror-2025.10.3.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tirex-mirror
- Version: 2025.10.2
+ Version: 2025.10.3
  Summary: Unofficial mirror of NX-AI/tirex for packaging
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
  License: NXAI COMMUNITY LICENSE AGREEMENT
tirex_mirror-2025.10.2.dist-info/RECORD → tirex_mirror-2025.10.3.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
- tirex/base.py,sha256=fwyUTGL103kK5jgK5MoLSIHQcZb4lrox_D9fNbY1W1k,3507
+ tirex/base.py,sha256=P1RXKcDekG_v9fRgmrSVFRX9koaHQHL4fHCkDqLusgA,3862
  tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
  tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
  tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
@@ -8,14 +8,14 @@ tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,
  tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
  tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
  tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
- tirex/models/tirex.py,sha256=Kglea86t_f3nXXHSjFgssxxrd1Qbwfr1eB_5gKfWYxM,9098
+ tirex/models/tirex.py,sha256=JKNuCzTI6B9_yCbcmTf2UFjAQXulLNEmloqtAhKJKjQ,9830
  tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
- tirex/models/slstm/cell.py,sha256=ippaAPKI83j3_1l3pu9ks-iBGO641Elm1W4HsHgVu-c,7601
+ tirex/models/slstm/cell.py,sha256=JfCs1aUy9IHuz9RwExhUwiUtbg8WmbEg4upcO7hA5Rg,7229
  tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
- tirex_mirror-2025.10.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
- tirex_mirror-2025.10.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
- tirex_mirror-2025.10.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
- tirex_mirror-2025.10.2.dist-info/METADATA,sha256=Aq9VAU0pojVsrxwbfFLCmwmk1Gfl_Z_49G4yB-Z9eLY,11443
- tirex_mirror-2025.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- tirex_mirror-2025.10.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
- tirex_mirror-2025.10.2.dist-info/RECORD,,
+ tirex_mirror-2025.10.3.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
+ tirex_mirror-2025.10.3.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
+ tirex_mirror-2025.10.3.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
+ tirex_mirror-2025.10.3.dist-info/METADATA,sha256=ncNo7JrdyV4x4IMir2RUrWVmfJROSBQnQasP6gefa9g,11443
+ tirex_mirror-2025.10.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ tirex_mirror-2025.10.3.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
+ tirex_mirror-2025.10.3.dist-info/RECORD,,