tirex-mirror 2025.10.16__tar.gz → 2025.10.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tirex_mirror-2025.10.16/src/tirex_mirror.egg-info → tirex_mirror-2025.10.17}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/pyproject.toml +1 -1
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/slstm/cell.py +3 -3
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/slstm/layer.py +5 -3
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/tirex.py +38 -35
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/LICENSE +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/MANIFEST.in +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/NOTICE.txt +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/README.md +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/setup.cfg +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/__init__.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/api_adapter/__init__.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/api_adapter/forecast.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/api_adapter/gluon.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/api_adapter/hf_data.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/api_adapter/standard_adapter.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/base.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/__init__.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/patcher.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/slstm/block.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/util.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex_mirror.egg-info/SOURCES.txt +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex_mirror.egg-info/requires.txt +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex_mirror.egg-info/top_level.txt +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/tests/test_chronos_zs.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/tests/test_forecast.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/tests/test_forecast_adapter.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/tests/test_jupyterlab.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/tests/test_slstm_torch_vs_cuda.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/tests/test_standard_adapter.py +0 -0
- {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/tests/test_util_freq.py +0 -0
{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/slstm/cell.py
CHANGED

@@ -100,7 +100,7 @@ class sLSTMCell(nn.Module):
 
     def _get_input(self, x: torch.Tensor) -> torch.Tensor:
         assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
-            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {
+            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {x.size(-1)}."
         )
         return x.view(x.shape[0], x.shape[1], self.config.num_gates, self.config.num_heads, -1).permute(1, 0, 2, 3, 4)
 
@@ -128,7 +128,7 @@ class sLSTMCellTorch:
         states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
         R: torch.Tensor,  # [K, R*H, H] - K num_heads
         b: torch.Tensor,  # [T*H]
-    ) -> tuple[torch.Tensor, torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         num_gates = 4
         num_heads = R.shape[0]
         S, B, _ = x.shape
@@ -167,7 +167,7 @@ class sLSTMCellTorch:
         iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
 
         # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
-        logfplusm = m + F.logsigmoid(fraw)  # eq 15
+        logfplusm = m + F.logsigmoid(torch.clamp(fraw, max=15))  # eq 15 # Clamp to avoid subnomals
         mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))  # eq 15
         ogate = torch.sigmoid(oraw)  # eq 14
         igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
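The changed line clamps the forget-gate pre-activation before F.logsigmoid. A minimal standalone sketch (not package code) of the effect: for large positive inputs, logsigmoid returns values of vanishingly small magnitude, and the clamp keeps the result at roughly -3e-7 instead of letting it shrink further.

import torch
import torch.nn.functional as F

# For x = 40, logsigmoid(x) ~ -exp(-x) ~ -4e-18; clamping at 15 floors the magnitude at ~ -3e-7.
fraw = torch.tensor([0.0, 15.0, 40.0], dtype=torch.float32)
print(F.logsigmoid(fraw))                       # ~[-0.693, -3.1e-07, -4.2e-18]
print(F.logsigmoid(torch.clamp(fraw, max=15)))  # ~[-0.693, -3.1e-07, -3.1e-07]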
{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/slstm/layer.py
CHANGED

@@ -20,7 +20,7 @@ class sLSTMLayer(nn.Module):
         self.ogate = LinearHeadwiseExpand(in_features, num_heads)
 
         self.slstm_cell = sLSTMCell(self.config, backend)
-        self.group_norm = MultiHeadLayerNorm(ndim=in_features)
+        self.group_norm = MultiHeadLayerNorm(ndim=in_features, num_heads=num_heads)
 
     def forward(self, x: torch.Tensor, slstm_state: torch.Tensor | None = None) -> torch.Tensor:
         x_g = torch.cat((self.fgate(x), self.igate(x), self.zgate(x), self.ogate(x)), dim=-1)
@@ -50,18 +50,20 @@ class LinearHeadwiseExpand(nn.Module):
 
 
 class MultiHeadLayerNorm(nn.Module):
-    def __init__(self, ndim: int):
+    def __init__(self, ndim: int, num_heads: int):
         super().__init__()
         self.weight = nn.Parameter(torch.zeros(ndim))
+        self.num_heads = num_heads
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
         B, NH, S, DH = input.shape
 
+        assert NH == self.num_heads
         gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
         gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
         residual_weight = 1.0 + self.weight
-        out = F.group_norm(gn_in_2, num_groups=
+        out = F.group_norm(gn_in_2, num_groups=self.num_heads, weight=residual_weight)
         # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
         out = out.view(B, S, NH, DH).transpose(1, 2)
         return out
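With num_groups set to the number of heads, the group norm over the flattened (B * S, NH * DH) tensor normalizes each head's DH-dimensional slice on its own. A minimal sketch (not package code, shapes chosen arbitrarily) checking that this matches a per-head layer norm without affine parameters:

import torch
import torch.nn.functional as F

B, NH, S, DH = 2, 4, 5, 8
x = torch.randn(B, NH, S, DH)

# Group-norm path: flatten (batch, seq) and treat each head as one group.
gn_in = x.transpose(1, 2).reshape(B * S, NH * DH)
out_gn = F.group_norm(gn_in, num_groups=NH).view(B, S, NH, DH).transpose(1, 2)

# Reference: plain layer norm over the last (per-head) dimension.
out_ln = F.layer_norm(x, normalized_shape=(DH,))

print(torch.allclose(out_gn, out_ln, atol=1e-5))  # True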
{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.17}/src/tirex/models/tirex.py
CHANGED

@@ -79,12 +79,18 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
         training_quantile_levels = self.config.quantiles
 
         if set(quantile_levels).issubset(set(training_quantile_levels)):
-
+            quantile_indices = torch.tensor(
+                [training_quantile_levels.index(q) for q in quantile_levels],
+                dtype=torch.long,
+                device=predictions.device,
+            )
+            quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)
         else:
             quantiles = self._interpolate_quantiles(predictions, quantile_levels)
 
         # median as mean
-
+        median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long, device=predictions.device)
+        mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)
         return quantiles, mean
 
     @torch.inference_mode()
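When the requested levels are a subset of the training levels, the new code gathers them directly along the last axis. A minimal sketch (assumed shapes, not package code) of that selection:

import torch

training_quantile_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
quantile_levels = [0.1, 0.5, 0.9]

predictions = torch.randn(2, 48, len(training_quantile_levels))  # [batch, horizon, num_quantiles]

idx = torch.tensor([training_quantile_levels.index(q) for q in quantile_levels], dtype=torch.long)
quantiles = torch.index_select(predictions, dim=-1, index=idx)  # [batch, horizon, 3]

median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long)
mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)  # [batch, horizon]

print(quantiles.shape, mean.shape)  # torch.Size([2, 48, 3]) torch.Size([2, 48])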
@@ -105,24 +111,8 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         context = context.to(dtype=torch.float32)
         while remaining > 0:
-            if context.shape[-1] > max_context:
-                context = context[..., -max_context:]
-            if context.shape[-1] < min_context:
-                pad = torch.full(
-                    (context.shape[0], min_context - context.shape[-1]),
-                    fill_value=torch.nan,
-                    device=context.device,
-                    dtype=context.dtype,
-                )
-                context = torch.concat((pad, context), dim=1)
-            tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
             fut_rollouts = min(remaining, max_accelerated_rollout_steps)
-
-            prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
-            prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
-            # [bs, num_quantiles, num_predicted_token, output_patch_size]
-            prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
-            prediction = prediction.flatten(start_dim=2)
+            prediction, fut_rollouts = self._forecast_single_step(context, max_context, min_context, fut_rollouts)
 
             predictions.append(prediction)
             remaining -= fut_rollouts
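The loop body is now a single call to the new _forecast_single_step helper; the surrounding pattern accumulates per-call predictions and trims the concatenation to the requested length. A minimal sketch of that pattern (not package code; step_fn is a hypothetical stand-in, and how the context advances between calls is elided):

import torch

def greedy_rollout(step_fn, prediction_length, patch_size, max_steps):
    predictions = []
    remaining = -(-prediction_length // patch_size)  # ceil division: number of patches needed
    while remaining > 0:
        steps = min(remaining, max_steps)
        predictions.append(step_fn(steps))  # step_fn returns [batch, quantiles, steps * patch_size]
        remaining -= steps
    return torch.cat(predictions, dim=-1)[..., :prediction_length]

out = greedy_rollout(lambda s: torch.zeros(1, 9, s * 32), prediction_length=100, patch_size=32, max_steps=2)
print(out.shape)  # torch.Size([1, 9, 100])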
@@ -134,6 +124,33 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)
 
+    def _forecast_single_step(
+        self,
+        context: torch.Tensor,
+        max_context: int,
+        min_context: int,
+        new_patch_count: int = 1,
+    ) -> tuple[torch.Tensor, int]:
+        if context.shape[-1] > max_context:
+            context = context[..., -max_context:]
+        if context.shape[-1] < min_context:
+            pad = torch.full(
+                (context.shape[0], min_context - context.shape[-1]),
+                fill_value=torch.nan,
+                device=context.device,
+                dtype=context.dtype,
+            )
+            context = torch.concat((pad, context), dim=1)
+
+        tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
+        prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=new_patch_count)
+        prediction = prediction[:, :, -new_patch_count:, :].to(tokenized_tensor)  # predicted token
+        # Shape: [bs, num_quantiles, num_predicted_token, output_patch_size]
+        prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
+        prediction = prediction.flatten(start_dim=2)
+
+        return prediction, new_patch_count
+
     def _forward_model_tokenized(
         self,
         input_token: torch.Tensor,
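_forecast_single_step first fits the context window: it keeps only the most recent max_context values and left-pads with NaN up to min_context (the NaNs are later replaced via torch.nan_to_num with config.nan_mask_value). A minimal standalone sketch of that conditioning (not package code):

import torch

def fit_context(context: torch.Tensor, max_context: int, min_context: int) -> torch.Tensor:
    # Keep only the most recent max_context values.
    if context.shape[-1] > max_context:
        context = context[..., -max_context:]
    # Left-pad with NaN so the model always sees at least min_context values.
    if context.shape[-1] < min_context:
        pad = torch.full((context.shape[0], min_context - context.shape[-1]),
                         fill_value=torch.nan, device=context.device, dtype=context.dtype)
        context = torch.concat((pad, context), dim=1)
    return context

print(fit_context(torch.arange(6.0).reshape(1, 6), max_context=4, min_context=8))
# tensor([[nan, nan, nan, nan, 2., 3., 4., 5.]])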
@@ -165,21 +182,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)
 
-        hidden_states = self.
-
-        for block in self.blocks:
-            hidden_states = block(hidden_states)
-
-        hidden_states = self.out_norm(hidden_states)
-
-        quantile_preds = self.output_patch_embedding(hidden_states)
-        quantile_preds = torch.unflatten(
-            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
-        )
-        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
-        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
-
-        quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+        quantile_preds, hidden_states = self._forward_model(torch.cat((input_token, input_mask), dim=2))
 
         quantile_preds = torch.unflatten(
             quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
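After _forward_model, the flat output is reshaped into per-quantile patches and the quantile and token axes are swapped. A minimal sketch (assumed sizes, not package code) of that reshaping:

import torch

batch_size, num_token = 2, 6
num_quantiles, output_patch_size = 9, 32

quantile_preds = torch.randn(batch_size, num_token, num_quantiles * output_patch_size)
quantile_preds = torch.unflatten(quantile_preds, -1, (num_quantiles, output_patch_size))
quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token dimensions

print(quantile_preds.shape)  # torch.Size([2, 9, 6, 32]) = [batch, num_quantiles, num_token, output_patch_size]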
@@ -196,7 +199,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         hidden_states = self.out_norm(hidden_states)
 
-        return self.output_patch_embedding(hidden_states)
+        return self.output_patch_embedding(hidden_states), hidden_states
 
     def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
         training_quantile_levels = self.config.quantiles
All remaining files listed above with +0 -0 were renamed from tirex_mirror-2025.10.16 to tirex_mirror-2025.10.17 without content changes.