tirex-mirror 2025.10.18__tar.gz → 2025.10.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {tirex_mirror-2025.10.18/src/tirex_mirror.egg-info → tirex_mirror-2025.10.21}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/pyproject.toml +1 -1
  3. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/base.py +3 -1
  4. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/models/slstm/cell.py +3 -1
  5. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21/src/tirex_mirror.egg-info}/PKG-INFO +1 -1
  6. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex_mirror.egg-info/SOURCES.txt +1 -0
  7. tirex_mirror-2025.10.21/tests/test_compile.py +39 -0
  8. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/LICENSE +0 -0
  9. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/LICENSE_MIRROR.txt +0 -0
  10. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/MANIFEST.in +0 -0
  11. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/NOTICE.txt +0 -0
  12. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/README.md +0 -0
  13. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/setup.cfg +0 -0
  14. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/__init__.py +0 -0
  15. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/api_adapter/__init__.py +0 -0
  16. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/api_adapter/forecast.py +0 -0
  17. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/api_adapter/gluon.py +0 -0
  18. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/api_adapter/hf_data.py +0 -0
  19. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/api_adapter/standard_adapter.py +0 -0
  20. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/models/__init__.py +0 -0
  21. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/models/patcher.py +0 -0
  22. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/models/slstm/block.py +0 -0
  23. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/models/slstm/layer.py +0 -0
  24. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/models/tirex.py +0 -0
  25. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex/util.py +0 -0
  26. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  27. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex_mirror.egg-info/requires.txt +0 -0
  28. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  29. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/tests/test_chronos_zs.py +0 -0
  30. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/tests/test_forecast.py +0 -0
  31. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/tests/test_forecast_adapter.py +0 -0
  32. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/tests/test_jupyterlab.py +0 -0
  33. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/tests/test_slstm_torch_vs_cuda.py +0 -0
  34. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/tests/test_standard_adapter.py +0 -0
  35. {tirex_mirror-2025.10.18 → tirex_mirror-2025.10.21}/tests/test_util_freq.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.10.18
3
+ Version: 2025.10.21
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "tirex-mirror"
3
- version = "2025.10.18"
3
+ version = "2025.10.21"
4
4
  description = "Unofficial mirror of NX-AI/tirex for packaging"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -76,7 +76,9 @@ class PretrainedModel(ABC):
76
76
  model = model.to(device)
77
77
 
78
78
  if compile and backend == "torch":
79
- sLSTMCellTorch.slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward, mode="max-autotune")
79
+ compiled_slstm_forward = torch.compile(sLSTMCellTorch.slstm_forward)
80
+ for block in model.blocks:
81
+ block.slstm_layer.slstm_cell._impl_forward_torch = compiled_slstm_forward
80
82
  return model
81
83
 
82
84
  @classmethod
@@ -38,6 +38,8 @@ class sLSTMCell(nn.Module):
38
38
 
39
39
  self._bias_ = nn.Parameter(torch.empty((config.num_heads * config.num_gates * config.head_dim), dtype=None))
40
40
 
41
+ self._impl_forward_torch = sLSTMCellTorch.slstm_forward
42
+
41
43
  def forward(self, input: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
42
44
  input = self._get_input(input)
43
45
  state = self._get_state(input, state)
@@ -62,7 +64,7 @@ class sLSTMCell(nn.Module):
62
64
  .reshape(-1)
63
65
  )
64
66
 
65
- return sLSTMCellTorch.slstm_forward(input, state, recurrent_kernel, bias)
67
+ return self._impl_forward_torch(input, state, recurrent_kernel, bias)
66
68
 
67
69
  def _impl_cuda(self, input: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
68
70
  if input.device.type != "cuda":
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.10.18
3
+ Version: 2025.10.21
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -24,6 +24,7 @@ src/tirex_mirror.egg-info/dependency_links.txt
24
24
  src/tirex_mirror.egg-info/requires.txt
25
25
  src/tirex_mirror.egg-info/top_level.txt
26
26
  tests/test_chronos_zs.py
27
+ tests/test_compile.py
27
28
  tests/test_forecast.py
28
29
  tests/test_forecast_adapter.py
29
30
  tests/test_jupyterlab.py
@@ -0,0 +1,39 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ import time
5
+
6
+ import torch
7
+
8
+ from tirex import load_model
9
+
10
+
11
+ def measure_model_execution_time(model):
12
+ context = torch.randn((1, 128))
13
+
14
+ _, __ = model.forecast(context, prediction_length=32) # warmup
15
+
16
+ start = time.time()
17
+ _, __ = model.forecast(context, prediction_length=32)
18
+ end = time.time()
19
+
20
+ return end - start
21
+
22
+
23
+ def test_compileable():
24
+ model = load_model("NX-AI/TiRex", backend="torch", compile=True)
25
+
26
+ context = torch.randn((1, 128))
27
+ _, mean = model.forecast(context, prediction_length=32)
28
+
29
+ assert mean.shape == (1, 32)
30
+
31
+
32
+ def test_compiled_faster():
33
+ model_uncompiled = load_model("NX-AI/TiRex", backend="torch", compile=False)
34
+ model_compiled = load_model("NX-AI/TiRex", backend="torch", compile=True)
35
+
36
+ time_uncompiled = measure_model_execution_time(model_uncompiled)
37
+ time_compiled = measure_model_execution_time(model_compiled)
38
+
39
+ assert time_compiled < time_uncompiled, "Compiled model has to be faster than uncompiled one!"