tirex-mirror 2025.10.16__tar.gz → 2025.10.18__tar.gz

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (34)
  1. {tirex_mirror-2025.10.16/src/tirex_mirror.egg-info → tirex_mirror-2025.10.18}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/pyproject.toml +1 -1
  3. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/slstm/cell.py +3 -3
  4. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/slstm/layer.py +5 -3
  5. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/tirex.py +38 -35
  6. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
  7. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/LICENSE +0 -0
  8. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/LICENSE_MIRROR.txt +0 -0
  9. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/MANIFEST.in +0 -0
  10. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/NOTICE.txt +0 -0
  11. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/README.md +0 -0
  12. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/setup.cfg +0 -0
  13. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/__init__.py +0 -0
  14. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/api_adapter/__init__.py +0 -0
  15. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/api_adapter/forecast.py +0 -0
  16. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/api_adapter/gluon.py +0 -0
  17. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/api_adapter/hf_data.py +0 -0
  18. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/api_adapter/standard_adapter.py +0 -0
  19. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/base.py +0 -0
  20. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/__init__.py +0 -0
  21. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/patcher.py +0 -0
  22. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/slstm/block.py +0 -0
  23. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/util.py +0 -0
  24. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex_mirror.egg-info/SOURCES.txt +0 -0
  25. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  26. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex_mirror.egg-info/requires.txt +0 -0
  27. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  28. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/tests/test_chronos_zs.py +0 -0
  29. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/tests/test_forecast.py +0 -0
  30. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/tests/test_forecast_adapter.py +0 -0
  31. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/tests/test_jupyterlab.py +0 -0
  32. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/tests/test_slstm_torch_vs_cuda.py +0 -0
  33. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/tests/test_standard_adapter.py +0 -0
  34. {tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/tests/test_util_freq.py +0 -0

{tirex_mirror-2025.10.16/src/tirex_mirror.egg-info → tirex_mirror-2025.10.18}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.10.16
+Version: 2025.10.18
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT

{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tirex-mirror"
-version = "2025.10.16"
+version = "2025.10.18"
 description = "Unofficial mirror of NX-AI/tirex for packaging"
 readme = "README.md"
 requires-python = ">=3.11"

{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/slstm/cell.py
@@ -100,7 +100,7 @@ class sLSTMCell(nn.Module):
 
     def _get_input(self, x: torch.Tensor) -> torch.Tensor:
         assert x.shape[-1] == self.config.embedding_dim * self.config.num_gates, (
-            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {input.size(-1)}."
+            f"Input size mismatch: Expected input size {self.config.embedding_dim * self.config.num_gates}, but got {x.size(-1)}."
         )
         return x.view(x.shape[0], x.shape[1], self.config.num_gates, self.config.num_heads, -1).permute(1, 0, 2, 3, 4)
 
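
Note on the fix above: the old message formatted input.size(-1), but the parameter is named x, so input resolved to the Python builtin. Because an assert's message expression is only evaluated when the assertion fails, a genuine size mismatch would have surfaced as an AttributeError rather than the intended message. What the old f-string would have hit:

    >>> input.size(-1)
    Traceback (most recent call last):
      ...
    AttributeError: 'builtin_function_or_method' object has no attribute 'size'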

@@ -128,7 +128,7 @@ class sLSTMCellTorch:
         states: torch.Tensor,  # [4, B, H] only the first is used for recurrence!
         R: torch.Tensor,  # [K, R*H, H] - K num_heads
         b: torch.Tensor,  # [T*H]
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         num_gates = 4
         num_heads = R.shape[0]
         S, B, _ = x.shape

@@ -167,7 +167,7 @@
         iraw, fraw, zraw, oraw = torch.unbind(raw.view(raw.shape[0], 4, -1), dim=1)
 
         # Equations reference the xlstm paper on page 4: https://arxiv.org/pdf/2405.04517
-        logfplusm = m + F.logsigmoid(fraw)  # eq 15
+        logfplusm = m + F.logsigmoid(torch.clamp(fraw, max=15))  # eq 15  # Clamp to avoid subnormals
         mnew = torch.where(torch.all(n == 0.0), iraw, torch.max(iraw, logfplusm))  # eq 15
         ogate = torch.sigmoid(oraw)  # eq 14
         igate = torch.minimum(torch.exp(iraw - mnew), torch.ones_like(iraw))  # eq 16
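
Note on the clamp: for large forget-gate pre-activations, F.logsigmoid(fraw) ≈ -exp(-fraw), which drifts into subnormal territory and eventually flushes to negative zero in float32, whereas clamping at 15 keeps the result a comfortably normal float (exp(-15) ≈ 3e-7). A minimal standalone check, with an input value made up for illustration:

    import torch
    import torch.nn.functional as F

    fraw = torch.tensor([200.0])              # an extreme pre-activation
    print(F.logsigmoid(fraw))                 # tensor([-0.]) -- underflowed to negative zero
    print(F.logsigmoid(fraw.clamp(max=15)))   # tensor([-3.0590e-07]) -- a normal float32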

{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/slstm/layer.py
@@ -20,7 +20,7 @@ class sLSTMLayer(nn.Module):
         self.ogate = LinearHeadwiseExpand(in_features, num_heads)
 
         self.slstm_cell = sLSTMCell(self.config, backend)
-        self.group_norm = MultiHeadLayerNorm(ndim=in_features)
+        self.group_norm = MultiHeadLayerNorm(ndim=in_features, num_heads=num_heads)
 
     def forward(self, x: torch.Tensor, slstm_state: torch.Tensor | None = None) -> torch.Tensor:
         x_g = torch.cat((self.fgate(x), self.igate(x), self.zgate(x), self.ogate(x)), dim=-1)

@@ -50,18 +50,20 @@ class LinearHeadwiseExpand(nn.Module):
 
 
 class MultiHeadLayerNorm(nn.Module):
-    def __init__(self, ndim: int):
+    def __init__(self, ndim: int, num_heads: int):
         super().__init__()
         self.weight = nn.Parameter(torch.zeros(ndim))
+        self.num_heads = num_heads
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         assert input.dim() == 4, "Input must be 4D tensor (B, NH, S, DH)"
         B, NH, S, DH = input.shape
 
+        assert NH == self.num_heads
         gn_in_1 = input.transpose(1, 2)  # (B, S, NH, DH)
         gn_in_2 = gn_in_1.reshape(B * S, NH * DH)  # (B * S, NH * DH)
         residual_weight = 1.0 + self.weight
-        out = F.group_norm(gn_in_2, num_groups=NH, weight=residual_weight)
+        out = F.group_norm(gn_in_2, num_groups=self.num_heads, weight=residual_weight)
         # (B * S), (NH * DH) -> (B, S, NH, DH) -> (B, NH, S, DH)
         out = out.view(B, S, NH, DH).transpose(1, 2)
         return out
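
Note: F.group_norm with num_groups equal to the head count normalizes each head's DH features independently, i.e. it acts as a per-head LayerNorm; carrying num_heads on the module (instead of inferring it from the input) makes the grouping explicit and lets the added assert catch a mislaid head dimension. A minimal equivalence sketch, with arbitrary shapes and no affine weight:

    import torch
    import torch.nn.functional as F

    B, NH, S, DH = 2, 4, 5, 8
    x = torch.randn(B, NH, S, DH)

    # group_norm over (B*S, NH*DH) with NH groups normalizes each head's DH
    # features per sample -- the same statistics as a per-head LayerNorm.
    flat = x.transpose(1, 2).reshape(B * S, NH * DH)
    gn = F.group_norm(flat, num_groups=NH).view(B, S, NH, DH).transpose(1, 2)
    ln = F.layer_norm(x, normalized_shape=(DH,))
    print(torch.allclose(gn, ln, atol=1e-5))  # True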

{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18}/src/tirex/models/tirex.py
@@ -79,12 +79,18 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
         training_quantile_levels = self.config.quantiles
 
         if set(quantile_levels).issubset(set(training_quantile_levels)):
-            quantiles = predictions[..., [training_quantile_levels.index(q) for q in quantile_levels]]
+            quantile_indices = torch.tensor(
+                [training_quantile_levels.index(q) for q in quantile_levels],
+                dtype=torch.long,
+                device=predictions.device,
+            )
+            quantiles = torch.index_select(predictions, dim=-1, index=quantile_indices)
         else:
             quantiles = self._interpolate_quantiles(predictions, quantile_levels)
 
         # median as mean
-        mean = predictions[:, :, training_quantile_levels.index(0.5)]
+        median_idx = torch.tensor([training_quantile_levels.index(0.5)], dtype=torch.long, device=predictions.device)
+        mean = torch.index_select(predictions, dim=-1, index=median_idx).squeeze(-1)
         return quantiles, mean
 
     @torch.inference_mode()
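
Note: the index_select form is value-equivalent to the old list-based advanced indexing; building the index once as a torch.long tensor on the predictions' device avoids converting a Python list to a tensor inside the call and is generally friendlier to graph capture (the motivation is not stated in the diff). A small check with made-up shapes:

    import torch

    predictions = torch.randn(2, 7, 9)  # hypothetical [batch, steps, num_quantiles]
    idx = torch.tensor([0, 4, 8], dtype=torch.long, device=predictions.device)

    a = predictions[..., [0, 4, 8]]                         # old style: Python-list indexing
    b = torch.index_select(predictions, dim=-1, index=idx)  # new style: on-device index
    print(torch.equal(a, b))  # True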

@@ -105,24 +111,8 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         context = context.to(dtype=torch.float32)
         while remaining > 0:
-            if context.shape[-1] > max_context:
-                context = context[..., -max_context:]
-            if context.shape[-1] < min_context:
-                pad = torch.full(
-                    (context.shape[0], min_context - context.shape[-1]),
-                    fill_value=torch.nan,
-                    device=context.device,
-                    dtype=context.dtype,
-                )
-                context = torch.concat((pad, context), dim=1)
-            tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
             fut_rollouts = min(remaining, max_accelerated_rollout_steps)
-            with torch.no_grad():
-                prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=fut_rollouts)
-            prediction = prediction[:, :, -fut_rollouts:, :].to(tokenized_tensor)  # predicted token
-            # [bs, num_quantiles, num_predicted_token, output_patch_size]
-            prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
-            prediction = prediction.flatten(start_dim=2)
+            prediction, fut_rollouts = self._forecast_single_step(context, max_context, min_context, fut_rollouts)
 
             predictions.append(prediction)
             remaining -= fut_rollouts

@@ -134,6 +124,33 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         return torch.cat(predictions, dim=-1)[..., :prediction_length].to(dtype=torch.float32)
 
+    def _forecast_single_step(
+        self,
+        context: torch.Tensor,
+        max_context: int,
+        min_context: int,
+        new_patch_count: int = 1,
+    ) -> tuple[torch.Tensor, int]:
+        if context.shape[-1] > max_context:
+            context = context[..., -max_context:]
+        if context.shape[-1] < min_context:
+            pad = torch.full(
+                (context.shape[0], min_context - context.shape[-1]),
+                fill_value=torch.nan,
+                device=context.device,
+                dtype=context.dtype,
+            )
+            context = torch.concat((pad, context), dim=1)
+
+        tokenized_tensor, tokenizer_state = self.tokenizer.context_input_transform(context)
+        prediction, _ = self._forward_model_tokenized(input_token=tokenized_tensor, rollouts=new_patch_count)
+        prediction = prediction[:, :, -new_patch_count:, :].to(tokenized_tensor)  # predicted token
+        # Shape: [bs, num_quantiles, num_predicted_token, output_patch_size]
+        prediction = self.tokenizer.output_transform(prediction, tokenizer_state)
+        prediction = prediction.flatten(start_dim=2)
+
+        return prediction, new_patch_count
+
     def _forward_model_tokenized(
         self,
         input_token: torch.Tensor,
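
Note: _forecast_single_step is the old loop body extracted essentially verbatim, minus the torch.no_grad wrapper, which is presumably redundant under the @torch.inference_mode() decorator visible in the earlier hunk. Its windowing step truncates the context to max_context and left-pads with NaN up to min_context; the NaNs are masked downstream via nan_to_num with config.nan_mask_value. A standalone sketch of just that windowing logic, with a hypothetical fit_context helper and made-up sizes:

    import torch

    def fit_context(context: torch.Tensor, max_context: int, min_context: int) -> torch.Tensor:
        # Keep only the most recent max_context values ...
        if context.shape[-1] > max_context:
            context = context[..., -max_context:]
        # ... and left-pad with NaN up to min_context.
        if context.shape[-1] < min_context:
            pad = torch.full(
                (context.shape[0], min_context - context.shape[-1]),
                fill_value=torch.nan,
                dtype=context.dtype,
            )
            context = torch.concat((pad, context), dim=1)
        return context

    print(fit_context(torch.ones(1, 3), max_context=8, min_context=5))
    # tensor([[nan, nan, 1., 1., 1.]])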

@@ -165,21 +182,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)
 
-        hidden_states = self.input_patch_embedding(torch.cat((input_token, input_mask), dim=2))
-
-        for block in self.blocks:
-            hidden_states = block(hidden_states)
-
-        hidden_states = self.out_norm(hidden_states)
-
-        quantile_preds = self.output_patch_embedding(hidden_states)
-        quantile_preds = torch.unflatten(
-            quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)
-        )
-        quantile_preds = torch.transpose(quantile_preds, 1, 2)  # switch quantile and num_token_dimension
-        # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
-
-        quantile_preds = self._forward_model(torch.cat((input_token, input_mask), dim=2))
+        quantile_preds, hidden_states = self._forward_model(torch.cat((input_token, input_mask), dim=2))
 
         quantile_preds = torch.unflatten(
             quantile_preds, -1, (len(self.config.quantiles), self.config.output_patch_size)

@@ -196,7 +199,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
 
         hidden_states = self.out_norm(hidden_states)
 
-        return self.output_patch_embedding(hidden_states)
+        return self.output_patch_embedding(hidden_states), hidden_states
 
     def _interpolate_quantiles(self, predictions: torch.Tensor, quantile_levels: list[float]):
         training_quantile_levels = self.config.quantiles
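
Note on the two hunks above: before this change, _forward_model_tokenized ran the backbone inline and then called self._forward_model(...) on the same input anyway, discarding the inline result, so the backbone ran twice per call. _forward_model now returns (quantile_preds, hidden_states), so a single pass provides both. The unflatten/transpose that follows reshapes the flat head output into [batch_size, num_quantiles, num_token, output_patch_size]; a quick shape walkthrough with assumed dimensions:

    import torch

    bs, num_token = 2, 6                       # assumed sizes
    num_quantiles, output_patch_size = 9, 32   # assumed config values

    flat = torch.randn(bs, num_token, num_quantiles * output_patch_size)
    q = torch.unflatten(flat, -1, (num_quantiles, output_patch_size))  # [bs, num_token, Q, P]
    q = torch.transpose(q, 1, 2)  # [bs, Q, num_token, P] -- quantile axis moved forward
    print(q.shape)  # torch.Size([2, 9, 6, 32])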

{tirex_mirror-2025.10.16 → tirex_mirror-2025.10.18/src/tirex_mirror.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.10.16
+Version: 2025.10.18
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT