tirex-mirror 2025.9.9__tar.gz → 2025.9.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. {tirex_mirror-2025.9.9/src/tirex_mirror.egg-info → tirex_mirror-2025.9.10}/PKG-INFO +3 -2
  2. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/pyproject.toml +5 -5
  3. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/base.py +10 -2
  4. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/models/slstm/cell.py +4 -4
  5. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/models/tirex.py +2 -4
  6. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10/src/tirex_mirror.egg-info}/PKG-INFO +3 -2
  7. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/SOURCES.txt +1 -0
  8. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/requires.txt +2 -1
  9. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/tests/test_chronos_zs.py +4 -1
  10. tirex_mirror-2025.9.10/tests/test_slstm_torch_vs_cuda.py +82 -0
  11. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/LICENSE +0 -0
  12. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/LICENSE_MIRROR.txt +0 -0
  13. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/MANIFEST.in +0 -0
  14. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/NOTICE.txt +0 -0
  15. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/README.md +0 -0
  16. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/setup.cfg +0 -0
  17. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/__init__.py +0 -0
  18. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/__init__.py +0 -0
  19. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/forecast.py +0 -0
  20. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/gluon.py +0 -0
  21. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/hf_data.py +0 -0
  22. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/api_adapter/standard_adapter.py +0 -0
  23. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/models/__init__.py +0 -0
  24. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/models/patcher.py +0 -0
  25. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/models/slstm/block.py +0 -0
  26. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/models/slstm/layer.py +0 -0
  27. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex/util.py +0 -0
  28. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  29. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  30. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/tests/test_forecast.py +0 -0
  31. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/tests/test_forecast_adapter.py +0 -0
  32. {tirex_mirror-2025.9.9 → tirex_mirror-2025.9.10}/tests/test_standard_adapter.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.9.9
3
+ Version: 2025.9.10
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -64,7 +64,6 @@ License-File: LICENSE
64
64
  License-File: LICENSE_MIRROR.txt
65
65
  License-File: NOTICE.txt
66
66
  Requires-Dist: torch
67
- Requires-Dist: torchvision
68
67
  Requires-Dist: einops
69
68
  Requires-Dist: huggingface-hub
70
69
  Requires-Dist: numpy
@@ -82,6 +81,7 @@ Provides-Extra: hfdataset
82
81
  Requires-Dist: datasets; extra == "hfdataset"
83
82
  Provides-Extra: test
84
83
  Requires-Dist: fev; extra == "test"
84
+ Requires-Dist: pytest; extra == "test"
85
85
  Provides-Extra: all
86
86
  Requires-Dist: xlstm; extra == "all"
87
87
  Requires-Dist: ninja; extra == "all"
@@ -89,6 +89,7 @@ Requires-Dist: ipykernel; extra == "all"
89
89
  Requires-Dist: matplotlib; extra == "all"
90
90
  Requires-Dist: gluonts; extra == "all"
91
91
  Requires-Dist: datasets; extra == "all"
92
+ Requires-Dist: pytest; extra == "all"
92
93
  Requires-Dist: fev; extra == "all"
93
94
  Dynamic: license-file
94
95
 
@@ -1,18 +1,18 @@
1
1
  [project]
2
2
  name = "tirex-mirror"
3
- version = "2025.09.09"
3
+ version = "2025.09.10"
4
4
  description = "Unofficial mirror of NX-AI/tirex for packaging"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
7
7
  classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent",]
8
8
  keywords = [ "TiRex", "xLSTM", "Time Series", "Zero-shot", "Deep Learning",]
9
- dependencies = [ "torch", "torchvision", "einops", "huggingface-hub", "numpy", "pandas", "tqdm",]
9
+ dependencies = [ "torch", "einops", "huggingface-hub", "numpy", "pandas", "tqdm",]
10
10
  [[project.authors]]
11
11
  name = "Arpad Rozsas"
12
12
  email = "rozsasarpi@gmail.com"
13
13
 
14
14
  [build-system]
15
- requires = [ "setuptools>=42", "wheel",]
15
+ requires = [ "setuptools>=77.0.3", "wheel",]
16
16
  build-backend = "setuptools.build_meta"
17
17
 
18
18
  [project.license]
@@ -27,8 +27,8 @@ cuda = [ "xlstm", "ninja",]
27
27
  notebooks = [ "ipykernel", "matplotlib",]
28
28
  gluonts = [ "gluonts",]
29
29
  hfdataset = [ "datasets",]
30
- test = [ "fev",]
31
- all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "gluonts", "datasets", "fev",]
30
+ test = [ "fev", "pytest",]
31
+ all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "gluonts", "datasets", "pytest", "fev",]
32
32
 
33
33
  [tool.docformatter]
34
34
  diff = false
@@ -15,6 +15,15 @@ def skip_cuda():
15
15
  return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
16
16
 
17
17
 
18
+ def xlstm_available():
19
+ try:
20
+ from xlstm.blocks.slstm.cell import sLSTMCellConfig, sLSTMCellFuncGenerator
21
+
22
+ return True
23
+ except ModuleNotFoundError:
24
+ return False
25
+
26
+
18
27
  def parse_hf_repo_id(path):
19
28
  parts = path.split("/")
20
29
  return "/".join(parts[0:2])
@@ -86,8 +95,7 @@ def load_model(
86
95
  """
87
96
 
88
97
  if backend is None:
89
- backend = "torch" if skip_cuda() else "cuda"
90
- assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
98
+ backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
91
99
 
92
100
  try:
93
101
  _, model_id = parse_hf_repo_id(path).split("/")
@@ -3,6 +3,7 @@
3
3
 
4
4
  import warnings
5
5
  from dataclasses import asdict, dataclass
6
+ from typing import Literal
6
7
 
7
8
  import torch
8
9
  import torch.nn as nn
@@ -15,10 +16,8 @@ from tirex.util import dataclass_from_dict
15
16
  class sLSTMBlockConfig:
16
17
  embedding_dim: int
17
18
  num_heads: int
18
- num_blocks: int
19
19
  ffn_proj_factor: float = 2.6667
20
-
21
- num_states: int = 4 # this is for the sLSTM, a standard LSTM has 2
20
+ num_states: int = 4
22
21
  num_gates: int = 4
23
22
 
24
23
  @property
@@ -27,8 +26,9 @@ class sLSTMBlockConfig:
27
26
 
28
27
 
29
28
  class sLSTMCell(nn.Module):
30
- def __init__(self, config: sLSTMBlockConfig, backend: str):
29
+ def __init__(self, config: sLSTMBlockConfig, backend: Literal["torch", "cuda"]):
31
30
  super().__init__()
31
+ assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
32
32
  self.config = config
33
33
  self.backend = backend
34
34
 
@@ -33,10 +33,10 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
33
33
  super().__init__()
34
34
  self.config = TiRexZeroConfig(**model_config, train_ctx_len=train_ctx_len, nan_mask_value=0)
35
35
  assert self.config.input_patch_size == self.config.output_patch_size
36
- self.backend = backend
37
36
 
38
37
  self.tokenizer = PatchedUniTokenizer(patch_size=self.config.input_patch_size)
39
38
 
39
+ num_blocks = self.config.block_kwargs["num_blocks"]
40
40
  block_config = dataclass_from_dict(sLSTMBlockConfig, self.config.block_kwargs)
41
41
  self.input_patch_embedding = ResidualBlock(
42
42
  in_dim=self.config.input_patch_size * 2,
@@ -44,9 +44,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
44
44
  out_dim=block_config.embedding_dim,
45
45
  )
46
46
 
47
- self.blocks = nn.ModuleList(
48
- [sLSTMBlock(block_config, backend=self.backend) for i in range(block_config.num_blocks)]
49
- )
47
+ self.blocks = nn.ModuleList([sLSTMBlock(block_config, backend) for i in range(num_blocks)])
50
48
 
51
49
  self.out_norm = RMSNorm(block_config.embedding_dim)
52
50
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.9.9
3
+ Version: 2025.9.10
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -64,7 +64,6 @@ License-File: LICENSE
64
64
  License-File: LICENSE_MIRROR.txt
65
65
  License-File: NOTICE.txt
66
66
  Requires-Dist: torch
67
- Requires-Dist: torchvision
68
67
  Requires-Dist: einops
69
68
  Requires-Dist: huggingface-hub
70
69
  Requires-Dist: numpy
@@ -82,6 +81,7 @@ Provides-Extra: hfdataset
82
81
  Requires-Dist: datasets; extra == "hfdataset"
83
82
  Provides-Extra: test
84
83
  Requires-Dist: fev; extra == "test"
84
+ Requires-Dist: pytest; extra == "test"
85
85
  Provides-Extra: all
86
86
  Requires-Dist: xlstm; extra == "all"
87
87
  Requires-Dist: ninja; extra == "all"
@@ -89,6 +89,7 @@ Requires-Dist: ipykernel; extra == "all"
89
89
  Requires-Dist: matplotlib; extra == "all"
90
90
  Requires-Dist: gluonts; extra == "all"
91
91
  Requires-Dist: datasets; extra == "all"
92
+ Requires-Dist: pytest; extra == "all"
92
93
  Requires-Dist: fev; extra == "all"
93
94
  Dynamic: license-file
94
95
 
@@ -26,4 +26,5 @@ src/tirex_mirror.egg-info/top_level.txt
26
26
  tests/test_chronos_zs.py
27
27
  tests/test_forecast.py
28
28
  tests/test_forecast_adapter.py
29
+ tests/test_slstm_torch_vs_cuda.py
29
30
  tests/test_standard_adapter.py
@@ -1,5 +1,4 @@
1
1
  torch
2
- torchvision
3
2
  einops
4
3
  huggingface-hub
5
4
  numpy
@@ -13,6 +12,7 @@ ipykernel
13
12
  matplotlib
14
13
  gluonts
15
14
  datasets
15
+ pytest
16
16
  fev
17
17
 
18
18
  [cuda]
@@ -31,3 +31,4 @@ matplotlib
31
31
 
32
32
  [test]
33
33
  fev
34
+ pytest
@@ -57,7 +57,10 @@ def test_chronos_single(tirex_model, benchmark):
57
57
  assert evaluation_summary["MASE"] < 0.99, "MASE on the electricity task needs to be less than 0.99"
58
58
 
59
59
 
60
- @pytest.mark.skipif(os.getenv("CI"), reason="Skip full chromos benchmarking in the CI")
60
+ @pytest.mark.skipif(
61
+ os.getenv("CI") is not None and os.getenv("CI_RUN_BENCHMARKS") is None,
62
+ reason="Skip Chronos benchmarks in CI",
63
+ )
61
64
  def test_chronos_all(tirex_model, benchmark):
62
65
  tasks_wql = []
63
66
  tasks_mase = []
@@ -0,0 +1,82 @@
1
+ import copy
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from tirex.models.slstm.cell import sLSTMBlockConfig, sLSTMCell
7
+
8
+ pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="This test needs CUDA.")
9
+
10
+
11
+ @pytest.mark.parametrize("with_in_state", [True, False])
12
+ def test_with_in_state(with_in_state):
13
+ run_slstm_torch_vs_cuda(with_in_state=with_in_state)
14
+
15
+
16
+ @pytest.mark.parametrize("sequence_length", [1, 2, 4])
17
+ def test_sequence_length(sequence_length):
18
+ run_slstm_torch_vs_cuda(sequence_length=sequence_length)
19
+
20
+
21
+ @pytest.mark.parametrize("batch_size", [1, 2, 4])
22
+ def test_batch_size(batch_size):
23
+ run_slstm_torch_vs_cuda(batch_size=batch_size)
24
+
25
+
26
+ @pytest.mark.parametrize("num_heads", [4, 1])
27
+ def test_num_heads(num_heads):
28
+ run_slstm_torch_vs_cuda(num_heads=num_heads, with_in_state=True, atol=5e-5)
29
+
30
+
31
+ @pytest.mark.parametrize("hidden_size", [64, 8])
32
+ def test_hidden_size(hidden_size):
33
+ run_slstm_torch_vs_cuda(hidden_size=hidden_size, with_in_state=True)
34
+
35
+
36
+ def test_complex():
37
+ run_slstm_torch_vs_cuda(
38
+ hidden_size=128, batch_size=2, sequence_length=8, num_heads=4, with_in_state=True, atol=1e-5
39
+ )
40
+
41
+
42
+ def test_long_sequence():
43
+ run_slstm_torch_vs_cuda(sequence_length=128, atol=1e-5)
44
+
45
+
46
+ def set_seed(seed):
47
+ torch.use_deterministic_algorithms(True)
48
+ torch.manual_seed(seed)
49
+ torch.cuda.manual_seed_all(seed)
50
+
51
+
52
+ def run_slstm_torch_vs_cuda(
53
+ batch_size=1, sequence_length=1, with_in_state=False, num_heads=4, hidden_size=64, rtol=1.3e-6, atol=1e-6
54
+ ):
55
+ device_cuda = "cuda"
56
+ config = sLSTMBlockConfig(embedding_dim=hidden_size, num_heads=num_heads)
57
+
58
+ set_seed(42)
59
+ recurrent_kernel_weight = torch.randn(
60
+ (config.num_heads, config.head_dim, config.num_gates * config.head_dim), dtype=torch.bfloat16
61
+ )
62
+ bias_weight = torch.randn((config.num_heads * config.num_gates * config.head_dim), dtype=torch.bfloat16)
63
+
64
+ cell_torch = sLSTMCell(copy.deepcopy(config), backend="torch")
65
+ cell_torch._recurrent_kernel_.data = recurrent_kernel_weight
66
+ cell_torch._bias_.data = bias_weight
67
+
68
+ cell_cuda = sLSTMCell(copy.deepcopy(config), backend="cuda").to(device_cuda)
69
+ cell_cuda._recurrent_kernel_.data = recurrent_kernel_weight.to(device_cuda)
70
+ cell_cuda._bias_.data = bias_weight.to(device_cuda)
71
+
72
+ set_seed(42)
73
+ current_input = torch.randn((batch_size, sequence_length, 4 * config.embedding_dim))
74
+ state = torch.randn((4, batch_size, hidden_size)) if with_in_state else None
75
+
76
+ output_torch, state_torch = cell_torch.forward(current_input, state)
77
+ output_cuda, state_cuda = cell_cuda.forward(
78
+ current_input.to(device_cuda), state.to(device_cuda) if state is not None else state
79
+ )
80
+
81
+ torch.testing.assert_close(output_torch, output_cuda.cpu(), rtol=rtol, atol=atol)
82
+ torch.testing.assert_close(state_torch, state_cuda.cpu(), rtol=rtol, atol=atol)