tirex-mirror 2025.10.1__py3-none-any.whl → 2025.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tirex/base.py CHANGED
@@ -84,6 +84,7 @@ def load_model(
84
84
  Args:
85
85
  path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
86
86
  device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
87
+ backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
87
88
  hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
88
89
  ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
89
90
 
@@ -21,14 +21,8 @@ class sLSTMBlock(nn.Module):
21
21
  self.ffn = FeedForward(config.embedding_dim, up_proj_dim)
22
22
 
23
23
  def forward(self, x: torch.Tensor) -> torch.Tensor:
24
- x_slstm = self.norm_slstm(x)
25
-
26
- x_slstm = self.slstm_layer(x_slstm, slstm_state=None)
27
- x = x + x_slstm
28
-
29
- x_ffn = self.norm_ffn(x)
30
- x_ffn = self.ffn(x_ffn)
31
- x = x + x_ffn
24
+ x = x + self.slstm_layer(self.norm_slstm(x), slstm_state=None)
25
+ x = x + self.ffn(self.norm_ffn(x))
32
26
  return x
33
27
 
34
28
 
@@ -41,8 +35,8 @@ class FeedForward(nn.Module):
41
35
 
42
36
  def forward(self, x: torch.Tensor) -> torch.Tensor:
43
37
  x = F.silu(self.proj_up_gate(x)) * self.proj_up(x)
44
- y = self.proj_down(x)
45
- return y
38
+ x = self.proj_down(x)
39
+ return x
46
40
 
47
41
 
48
42
  class RMSNorm(nn.Module):
@@ -168,21 +168,24 @@ def slstm_forward_pointwise(
168
168
  states: torch.Tensor, # dim [4, B, H]
169
169
  ) -> tuple[torch.Tensor, torch.Tensor]:
170
170
  raw = Wx + Ry + b
171
- y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
172
171
 
173
172
  iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
173
+ y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
174
+
174
175
  # with torch.no_grad(): # THE difference to maxg aka max_gradient (here max / max_static)
175
- logfplusm = m + F.logsigmoid(fraw)
176
+ # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
177
+ logfplusm = m + F.logsigmoid(fraw) # eq 15
176
178
  if torch.all(n == 0.0):
177
179
  mnew = iraw
178
180
  else:
179
- mnew = torch.max(iraw, logfplusm)
180
- ogate = torch.sigmoid(oraw)
181
- igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))
182
- fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))
183
- cnew = fgate * c + igate * torch.tanh(zraw)
184
- nnew = fgate * n + igate
185
- ynew = ogate * cnew / nnew
186
-
187
- # shapes ([B,H], [B,H], [B,H]), ([B,H],[B,H],[B,H],[B,H])
188
- return torch.stack((ynew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
181
+ mnew = torch.max(iraw, logfplusm) # eq 15
182
+ ogate = torch.sigmoid(oraw) # eq 14
183
+ igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw)) # eq 16
184
+ fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw)) # eq 17
185
+ zgate = torch.tanh(zraw) # eq 11
186
+ cnew = fgate * c + igate * zgate # eq 8
187
+ nnew = fgate * n + igate # eq 9
188
+ hnew = ogate * cnew / nnew # eq 10
189
+
190
+ # y (4, B, H), state (4, B, H)
191
+ return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.10.1
3
+ Version: 2025.10.2
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -67,16 +67,17 @@ Requires-Dist: torch
67
67
  Requires-Dist: einops
68
68
  Requires-Dist: huggingface-hub
69
69
  Requires-Dist: numpy
70
- Requires-Dist: pandas
71
- Requires-Dist: tqdm
72
70
  Provides-Extra: cuda
73
71
  Requires-Dist: xlstm; extra == "cuda"
74
72
  Requires-Dist: ninja; extra == "cuda"
75
73
  Provides-Extra: notebooks
76
74
  Requires-Dist: ipykernel; extra == "notebooks"
77
75
  Requires-Dist: matplotlib; extra == "notebooks"
76
+ Requires-Dist: pandas; extra == "notebooks"
77
+ Requires-Dist: python-dotenv; extra == "notebooks"
78
78
  Provides-Extra: gluonts
79
79
  Requires-Dist: gluonts; extra == "gluonts"
80
+ Requires-Dist: pandas; extra == "gluonts"
80
81
  Provides-Extra: hfdataset
81
82
  Requires-Dist: datasets; extra == "hfdataset"
82
83
  Provides-Extra: test
@@ -87,6 +88,8 @@ Requires-Dist: xlstm; extra == "all"
87
88
  Requires-Dist: ninja; extra == "all"
88
89
  Requires-Dist: ipykernel; extra == "all"
89
90
  Requires-Dist: matplotlib; extra == "all"
91
+ Requires-Dist: pandas; extra == "all"
92
+ Requires-Dist: python-dotenv; extra == "all"
90
93
  Requires-Dist: gluonts; extra == "all"
91
94
  Requires-Dist: datasets; extra == "all"
92
95
  Requires-Dist: pytest; extra == "all"
@@ -1,5 +1,5 @@
1
1
  tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
2
- tirex/base.py,sha256=u_fcwaIKEzq9aAt3UWqH8QvaqXG7qEykLNaP_opY26M,3366
2
+ tirex/base.py,sha256=fwyUTGL103kK5jgK5MoLSIHQcZb4lrox_D9fNbY1W1k,3507
3
3
  tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
4
4
  tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
5
5
  tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
@@ -9,13 +9,13 @@ tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQv
9
9
  tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
10
10
  tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
11
11
  tirex/models/tirex.py,sha256=Kglea86t_f3nXXHSjFgssxxrd1Qbwfr1eB_5gKfWYxM,9098
12
- tirex/models/slstm/block.py,sha256=DCOxmLQUb7HRO6wXTZMK4ICUI5LFpo7NC5a28oM-Vsc,2104
13
- tirex/models/slstm/cell.py,sha256=XWsn8I7HrUoMrUrfRCpl6Q88xbBz67bKEkdZ8gXE3hY,7444
12
+ tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
13
+ tirex/models/slstm/cell.py,sha256=ippaAPKI83j3_1l3pu9ks-iBGO641Elm1W4HsHgVu-c,7601
14
14
  tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
15
- tirex_mirror-2025.10.1.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
16
- tirex_mirror-2025.10.1.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
17
- tirex_mirror-2025.10.1.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
18
- tirex_mirror-2025.10.1.dist-info/METADATA,sha256=HHaDkSlIQIHE4aIhG1EG3S9vK_tfuWT_HuLxy-6st6s,11265
19
- tirex_mirror-2025.10.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
- tirex_mirror-2025.10.1.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
21
- tirex_mirror-2025.10.1.dist-info/RECORD,,
15
+ tirex_mirror-2025.10.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
16
+ tirex_mirror-2025.10.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
17
+ tirex_mirror-2025.10.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
18
+ tirex_mirror-2025.10.2.dist-info/METADATA,sha256=Aq9VAU0pojVsrxwbfFLCmwmk1Gfl_Z_49G4yB-Z9eLY,11443
19
+ tirex_mirror-2025.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ tirex_mirror-2025.10.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
21
+ tirex_mirror-2025.10.2.dist-info/RECORD,,