tirex-mirror 2025.10.2__tar.gz → 2025.10.3__tar.gz
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- {tirex_mirror-2025.10.2/src/tirex_mirror.egg-info → tirex_mirror-2025.10.3}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/pyproject.toml +1 -1
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/base.py +12 -4
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/models/slstm/cell.py +66 -77
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/models/tirex.py +17 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/LICENSE +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/MANIFEST.in +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/NOTICE.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/README.md +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/setup.cfg +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/__init__.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/api_adapter/__init__.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/api_adapter/forecast.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/api_adapter/gluon.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/api_adapter/hf_data.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/api_adapter/standard_adapter.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/models/__init__.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/models/patcher.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/models/slstm/block.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/models/slstm/layer.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex/util.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex_mirror.egg-info/SOURCES.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex_mirror.egg-info/requires.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/src/tirex_mirror.egg-info/top_level.txt +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/tests/test_chronos_zs.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/tests/test_forecast.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/tests/test_forecast_adapter.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/tests/test_slstm_torch_vs_cuda.py +0 -0
- {tirex_mirror-2025.10.2 → tirex_mirror-2025.10.3}/tests/test_standard_adapter.py +0 -0
The substantive changes are in src/tirex/base.py (+12 -4), src/tirex/models/slstm/cell.py (+66 -77), and src/tirex/models/tirex.py (+17 -0). Removed lines whose text was not captured in this view are marked with an ellipsis (…).

--- tirex_mirror-2025.10.2/src/tirex/base.py
+++ tirex_mirror-2025.10.3/src/tirex/base.py
@@ -1,6 +1,7 @@
 # Copyright (c) NXAI GmbH.
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
+import logging
 import os
 from abc import ABC, abstractmethod
 from typing import Literal, TypeVar
@@ -8,6 +9,8 @@ from typing import Literal, TypeVar
 import torch
 from huggingface_hub import hf_hub_download
 
+from tirex.models.slstm.cell import sLSTMCellTorch
+
 T = TypeVar("T", bound="PretrainedModel")
 
 
@@ -38,7 +41,7 @@ class PretrainedModel(ABC):
 
     @classmethod
     def from_pretrained(
-        cls: type[T], path: str, backend: str, device: str | None = None, hf_kwargs=None, ckp_kwargs=None
+        cls: type[T], path: str, backend: str, device: str | None = None, compile=False, hf_kwargs=None, ckp_kwargs=None
     ) -> T:
         if hf_kwargs is None:
             hf_kwargs = {}
@@ -58,9 +61,10 @@ class PretrainedModel(ABC):
         model: T = cls(backend=backend, **checkpoint["hyper_parameters"])
         model.on_load_checkpoint(checkpoint)
         model.load_state_dict(checkpoint["state_dict"])
+        model = model.to(device)
 
-        if backend == "…
-        …
+        if compile and backend == "torch":
+            sLSTMCellTorch.slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward, mode="max-autotune")
         return model
 
     @classmethod
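The compile hook in from_pretrained rebinds a class attribute once, so that every later call that goes through sLSTMCellTorch.slstm_forward hits the compiled function. A minimal standalone sketch of that pattern (Cell and forward are illustrative names, not part of the package):

```python
import torch


class Cell:
    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        # stand-in body for the real sLSTM recurrence
        return torch.tanh(x) * torch.sigmoid(x)


# Rebind the attribute once; callers that reach Cell.forward via the class now
# get the torch.compile-wrapped version. Compilation happens lazily on the
# first call and requires a working torch.compile/inductor setup.
Cell.forward = torch.compile(Cell.forward, mode="max-autotune")

print(Cell.forward(torch.randn(4)))
```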
@@ -76,6 +80,7 @@ def load_model(
     path: str,
     device: str | None = None,
     backend: Literal["torch", "cuda"] | None = None,
+    compile: bool = False,
     hf_kwargs=None,
     ckp_kwargs=None,
 ) -> PretrainedModel:
@@ -85,6 +90,7 @@ def load_model(
         path (str): Hugging Face path to the model (e.g. NX-AI/TiRex)
         device (str, optional): The device on which to load the model (e.g., "cuda:0", "cpu").
         backend (torch | cuda): What backend to use, torch or the custom CUDA kernels. Defaults to cuda when xlstm is installed, else torch.
+        compile (bool, optional): toch.compile the sLSTM cells, only works with the torch backend
         hf_kwargs (dict, optional): Keyword arguments to pass to the Hugging Face Hub download method.
         ckp_kwargs (dict, optional): Keyword arguments to pass when loading the checkpoint.
 
@@ -106,4 +112,6 @@ def load_model(
     if model_cls is None:
         raise ValueError(f"Invalid model id {model_id}")
 
-    return model_cls.from_pretrained(…
+    return model_cls.from_pretrained(
+        path, device=device, backend=backend, compile=compile, hf_kwargs=hf_kwargs, ckp_kwargs=ckp_kwargs
+    )
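A hedged usage sketch for the new flag. The model id comes from the docstring above; the package-root import of load_model is an assumption not confirmed by this diff:

```python
# Assumption: load_model is importable from the tirex package root.
from tirex import load_model

# compile=True only has an effect with the torch backend: at load time it wraps
# sLSTMCellTorch.slstm_forward in torch.compile(..., mode="max-autotune").
model = load_model("NX-AI/TiRex", backend="torch", compile=True, device="cpu")
```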
--- tirex_mirror-2025.10.2/src/tirex/models/slstm/cell.py
+++ tirex_mirror-2025.10.3/src/tirex/models/slstm/cell.py
@@ -43,13 +43,11 @@ class sLSTMCell(nn.Module):
         state = self._get_state(input, state)
 
         if self.backend == "torch":
-            …
+            output, state = self._impl_torch(input, state)
         elif self.backend == "cuda":
-            …
+            output, state = self._impl_cuda(input, state)
 
-        …
-        output = self._permute_output(all_states[0][1:])
-        return output.to(input.dtype), state.to(input.dtype)
+        return self._permute_output(output).to(input.dtype), state.to(input.dtype)
 
     def _impl_torch(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
         input = input.to(dtype=torch.bfloat16)
@@ -64,7 +62,7 @@ class sLSTMCell(nn.Module):
             .reshape(-1)
         )
 
-        return slstm_forward(input, state, recurrent_kernel, bias)
+        return sLSTMCellTorch.slstm_forward(input, state, recurrent_kernel, bias)
 
     def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
         if input.device.type != "cuda":
@@ -88,7 +86,7 @@ class sLSTMCell(nn.Module):
 
         input = input.permute(0, 1, 3, 2, 4).reshape(input.shape[0], input.shape[1], -1)
 
-        …
+        all_states = self.func.apply(
             False,
             input.contiguous(),
             state.contiguous(),
@@ -96,6 +94,10 @@ class sLSTMCell(nn.Module):
             self._bias_.contiguous(),
         )
 
+        state = all_states[:, -1]
+        output = all_states[0][1:]
+        return output, state
+
     def _get_input(self, x: torch.Tensor) -> torch.Tensor:
         assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
             f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
@@ -119,73 +121,60 @@ class sLSTMCell(nn.Module):
         return output.permute(1, 2, 0, 3)
 
 
-…
-    states…
-…
-        .…
-        .…
-…
-    )
-…
-        mnew = iraw
-    else:
-        mnew = torch.max(iraw, logfplusm) # eq 15
-    ogate = torch.sigmoid(oraw) # eq 14
-    igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw)) # eq 16
-    fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw)) # eq 17
-    zgate = torch.tanh(zraw) # eq 11
-    cnew = fgate * c + igate * zgate # eq 8
-    nnew = fgate * n + igate # eq 9
-    hnew = ogate * cnew / nnew # eq 10
-
-    # y (4, B, H), state (4, B, H)
-    return torch.stack((hnew, cnew, nnew, mnew), dim=0), torch.stack((igate, fgate, zraw, ogate), dim=0)
+class sLSTMCellTorch:
+    @staticmethod
+    def slstm_forward(
+        x: torch.Tensor,  # [S, B, G*I]
+        states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
+        R: torch.Tensor,  # [K, R*H, H] - K num_heads
+        b: torch.Tensor,  # [T*H]
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        num_gates = 4
+        num_heads = R.shape[0]
+        S, B, _ = x.shape
+        H = R.shape[1] * num_heads
+        assert states.shape == (num_gates, B, H)
+
+        states = states.to(R.dtype).unbind(dim=0)
+        output = []
+        for i in range(S):
+            Ry = (
+                states[0]
+                .reshape(B, num_heads, 1, -1)
+                .matmul(R.unsqueeze(0))
+                .reshape(B, num_heads, num_gates, -1)
+                .transpose(1, 2)
+                .reshape(B, -1)
+            )
+            states = sLSTMCellTorch.slstm_forward_pointwise(
+                x[i].float(), Ry.float(), b.float(), [s.float() for s in states]
+            )
+            states = [s.to(dtype=R.dtype) for s in states]
+            output.append(states[0])
+
+        return torch.stack(output), torch.stack(states)  # (S, B, H), 4 x (B, H)
+
+    @staticmethod
+    def slstm_forward_pointwise(
+        Wx: torch.Tensor,  # dim [B, 4*H]
+        Ry: torch.Tensor,  # dim [B, 4*H]
+        b: torch.Tensor,  # dim [1, 4*H]
+        states: torch.Tensor,  # dim 4 x [B, H]
+    ) -> list[torch.Tensor]:
+        y, c, n, m = states
+
+        raw = Wx + Ry + b
+        iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
+
+        # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
+        logfplusm = m + F.logsigmoid(fraw) # eq 15
+        mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm)) # eq 15
+        ogate = torch.sigmoid(oraw) # eq 14
+        igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw)) # eq 16
+        fgate = torch.minimum(torch.exp(logfplusm - mnew), torch.ones_like(iraw)) # eq 17
+        zgate = torch.tanh(zraw) # eq 11
+        cnew = fgate * c + igate * zgate # eq 8
+        nnew = fgate * n + igate # eq 9
+        hnew = ogate * cnew / nnew # eq 10
+
+        return [hnew, cnew, nnew, mnew]  # 4 x (B, H)
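sLSTMCellTorch.slstm_forward now depends only on its tensor arguments, which is what lets base.py swap it for a torch.compile-wrapped version in one place. A shape sketch consistent with the reshapes above; the head count, head dimension, and sequence length are illustrative, not taken from the model config:

```python
import torch

from tirex.models.slstm.cell import sLSTMCellTorch  # import path as used in base.py

S, B = 3, 2                  # sequence length, batch size (illustrative)
num_heads, head_dim = 2, 8   # illustrative head layout
H = num_heads * head_dim     # hidden size

x = torch.randn(S, B, 4 * H)                        # [S, B, num_gates * H] gate pre-activations
states = torch.zeros(4, B, H)                       # [4, B, H]; only states[0] drives the recurrence
R = torch.randn(num_heads, head_dim, 4 * head_dim)  # per-head recurrent kernel
b = torch.zeros(4 * H)                              # gate bias

output, new_states = sLSTMCellTorch.slstm_forward(x, states, R, b)
print(output.shape, new_states.shape)  # (S, B, H) and (4, B, H)
```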
--- tirex_mirror-2025.10.2/src/tirex/models/tirex.py
+++ tirex_mirror-2025.10.3/src/tirex/models/tirex.py
@@ -179,8 +179,25 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
         quantile_preds = torch.transpose(quantile_preds, 1, 2) # switch quantile and num_token_dimension
         # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
 
+        quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+
+        quantile_preds = torch.unflatten(
+            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
+        )
+        quantile_preds = torch.transpose(quantile_preds, 1, 2) # switch quantile and num_token_dimension
+        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
         return quantile_preds, hidden_states
 
+    def _forward_model(self, input: torch.Tensor):
+        hidden_states = self.input_patch_embedding(input)
+
+        for block in self.blocks:
+            hidden_states = block(hidden_states)
+
+        hidden_states = self.out_norm(hidden_states)
+
+        return self.output_patch_embedding(hidden_states)
+
     def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
         training_quantile_levels = self.config.quantiles
         if min(quantile_levels) < min(training_quantile_levels) or max(quantile_levels) > max(training_quantile_levels):
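For readers tracing the unflatten/transpose step in the new forward path, a standalone shape walk-through; the dimension sizes are illustrative, not taken from the model config:

```python
import torch

B, T, Q, P = 2, 5, 9, 32         # batch, num_token, num_quantiles, output_patch_size
flat = torch.randn(B, T, Q * P)  # one flat prediction vector per token

preds = torch.unflatten(flat, -1, (Q, P))  # [B, T, Q, P]
preds = torch.transpose(preds, 1, 2)       # [B, Q, T, P]: quantile and token dims switched
assert preds.shape == (B, Q, T, P)
```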