tirex-mirror 2025.10.1__py3-none-any.whl → 2025.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tirex/base.py +1 -0
- tirex/models/slstm/block.py +4 -10
- tirex/models/slstm/cell.py +15 -12
- {tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/METADATA +6 -3
- {tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/RECORD +10 -10
- {tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/WHEEL +0 -0
- {tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/licenses/LICENSE +0 -0
- {tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/licenses/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/licenses/NOTICE.txt +0 -0
- {tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/top_level.txt +0 -0
tirex/base.py
CHANGED
|
@@ -84,6 +84,7 @@ def load_model(
|
|
|
84
84
|
Args:
|
|
85
85
|
path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
|
|
86
86
|
device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
|
|
87
|
+
backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
|
|
87
88
|
hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
|
|
88
89
|
ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
|
|
89
90
|
|
tirex/models/slstm/block.py
CHANGED
|
@@ -21,14 +21,8 @@ class sLSTMBlock(nn.Module):
|
|
|
21
21
|
self.ffn = FeedForward(config.embedding_dim, up_proj_dim)
|
|
22
22
|
|
|
23
23
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
x_slstm = self.slstm_layer(x_slstm, slstm_state=None)
|
|
27
|
-
x = x + x_slstm
|
|
28
|
-
|
|
29
|
-
x_ffn = self.norm_ffn(x)
|
|
30
|
-
x_ffn = self.ffn(x_ffn)
|
|
31
|
-
x = x + x_ffn
|
|
24
|
+
x = x + self.slstm_layer(self.norm_slstm(x), slstm_state=None)
|
|
25
|
+
x = x + self.ffn(self.norm_ffn(x))
|
|
32
26
|
return x
|
|
33
27
|
|
|
34
28
|
|
|
@@ -41,8 +35,8 @@ class FeedForward(nn.Module):
|
|
|
41
35
|
|
|
42
36
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
43
37
|
x = F.silu(self.proj_up_gate(x)) * self.proj_up(x)
|
|
44
|
-
|
|
45
|
-
return
|
|
38
|
+
x = self.proj_down(x)
|
|
39
|
+
return x
|
|
46
40
|
|
|
47
41
|
|
|
48
42
|
class RMSNorm(nn.Module):
|
tirex/models/slstm/cell.py
CHANGED
|
@@ -168,21 +168,24 @@ def slstm_forward_pointwise(
|
|
|
168
168
|
states: torch.Tensor, # dim [4, B, H]
|
|
169
169
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
170
170
|
raw = Wx + Ry + b
|
|
171
|
-
y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
|
|
172
171
|
|
|
173
172
|
iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
|
|
173
|
+
y, c, n, m = torch.unbind(states.view(4, states.shape[1], -1), dim=0)
|
|
174
|
+
|
|
174
175
|
# with torch.no_grad(): # THE difference to maxg aka max_gradient (here max / max_static)
|
|
175
|
-
|
|
176
|
+
# Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
|
|
177
|
+
logfplusm = m + F.logsigmoid(fraw) # eq 15
|
|
176
178
|
if torch.all(n == 0.0):
|
|
177
179
|
mnew = iraw
|
|
178
180
|
else:
|
|
179
|
-
mnew = torch.max(iraw, logfplusm)
|
|
180
|
-
ogate = torch.sigmoid(oraw)
|
|
181
|
-
igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))
|
|
182
|
-
fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw))
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
181
|
+
mnew = torch.max(iraw, logfplusm) # eq 15
|
|
182
|
+
ogate = torch.sigmoid(oraw) # eq 14
|
|
183
|
+
igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw)) # eq 16
|
|
184
|
+
fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw)) # eq 17
|
|
185
|
+
zgate = torch.tanh(zraw) # eq 11
|
|
186
|
+
cnew = fgate * c + igate * zgate # eq 8
|
|
187
|
+
nnew = fgate * n + igate # eq 9
|
|
188
|
+
hnew = ogate * cnew / nnew # eq 10
|
|
189
|
+
|
|
190
|
+
# y (4, B, H), state (4, B, H)
|
|
191
|
+
return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tirex-mirror
|
|
3
|
-
Version: 2025.10.1
|
|
3
|
+
Version: 2025.10.2
|
|
4
4
|
Summary: Unofficial mirror of NX-AI/tirex for packaging
|
|
5
5
|
Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
|
|
6
6
|
License: NXAI COMMUNITY LICENSE AGREEMENT
|
|
@@ -67,16 +67,17 @@ Requires-Dist: torch
|
|
|
67
67
|
Requires-Dist: einops
|
|
68
68
|
Requires-Dist: huggingface-hub
|
|
69
69
|
Requires-Dist: numpy
|
|
70
|
-
Requires-Dist: pandas
|
|
71
|
-
Requires-Dist: tqdm
|
|
72
70
|
Provides-Extra: cuda
|
|
73
71
|
Requires-Dist: xlstm; extra == "cuda"
|
|
74
72
|
Requires-Dist: ninja; extra == "cuda"
|
|
75
73
|
Provides-Extra: notebooks
|
|
76
74
|
Requires-Dist: ipykernel; extra == "notebooks"
|
|
77
75
|
Requires-Dist: matplotlib; extra == "notebooks"
|
|
76
|
+
Requires-Dist: pandas; extra == "notebooks"
|
|
77
|
+
Requires-Dist: python-dotenv; extra == "notebooks"
|
|
78
78
|
Provides-Extra: gluonts
|
|
79
79
|
Requires-Dist: gluonts; extra == "gluonts"
|
|
80
|
+
Requires-Dist: pandas; extra == "gluonts"
|
|
80
81
|
Provides-Extra: hfdataset
|
|
81
82
|
Requires-Dist: datasets; extra == "hfdataset"
|
|
82
83
|
Provides-Extra: test
|
|
@@ -87,6 +88,8 @@ Requires-Dist: xlstm; extra == "all"
|
|
|
87
88
|
Requires-Dist: ninja; extra == "all"
|
|
88
89
|
Requires-Dist: ipykernel; extra == "all"
|
|
89
90
|
Requires-Dist: matplotlib; extra == "all"
|
|
91
|
+
Requires-Dist: pandas; extra == "all"
|
|
92
|
+
Requires-Dist: python-dotenv; extra == "all"
|
|
90
93
|
Requires-Dist: gluonts; extra == "all"
|
|
91
94
|
Requires-Dist: datasets; extra == "all"
|
|
92
95
|
Requires-Dist: pytest; extra == "all"
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
|
|
2
|
-
tirex/base.py,sha256=
|
|
2
|
+
tirex/base.py,sha256=fwyUTGL103kK5jgK5MoLSIHQcZb4lrox_D9fNbY1W1k,3507
|
|
3
3
|
tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
|
|
4
4
|
tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
|
|
5
5
|
tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
|
|
@@ -9,13 +9,13 @@ tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQv
|
|
|
9
9
|
tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
|
|
10
10
|
tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
|
|
11
11
|
tirex/models/tirex.py,sha256=Kglea86t_f3nXXHSjFgssxxrd1Qbwfr1eB_5gKfWYxM,9098
|
|
12
|
-
tirex/models/slstm/block.py,sha256=
|
|
13
|
-
tirex/models/slstm/cell.py,sha256=
|
|
12
|
+
tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
|
|
13
|
+
tirex/models/slstm/cell.py,sha256=ippaAPKI83j3_1l3pu9ks-iBGO641Elm1W4HsHgVu-c,7601
|
|
14
14
|
tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
|
|
15
|
-
tirex_mirror-2025.10.
|
|
16
|
-
tirex_mirror-2025.10.
|
|
17
|
-
tirex_mirror-2025.10.
|
|
18
|
-
tirex_mirror-2025.10.
|
|
19
|
-
tirex_mirror-2025.10.
|
|
20
|
-
tirex_mirror-2025.10.
|
|
21
|
-
tirex_mirror-2025.10.
|
|
15
|
+
tirex_mirror-2025.10.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
|
|
16
|
+
tirex_mirror-2025.10.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
|
|
17
|
+
tirex_mirror-2025.10.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
|
|
18
|
+
tirex_mirror-2025.10.2.dist-info/METADATA,sha256=Aq9VAU0pojVsrxwbfFLCmwmk1Gfl_Z_49G4yB-Z9eLY,11443
|
|
19
|
+
tirex_mirror-2025.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
20
|
+
tirex_mirror-2025.10.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
|
|
21
|
+
tirex_mirror-2025.10.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.10.1.dist-info → tirex_mirror-2025.10.2.dist-info}/licenses/LICENSE_MIRROR.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|