tirex-mirror 2025.11.13__tar.gz → 2025.11.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/PKG-INFO +1 -1
  2. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/pyproject.toml +1 -1
  3. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/base.py +19 -4
  4. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/models/slstm/cell.py +2 -2
  5. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex_mirror.egg-info/PKG-INFO +1 -1
  6. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex_mirror.egg-info/SOURCES.txt +1 -0
  7. tirex_mirror-2025.11.15/tests/test_load_model.py +67 -0
  8. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/LICENSE +0 -0
  9. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/LICENSE_MIRROR.txt +0 -0
  10. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/MANIFEST.in +0 -0
  11. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/NOTICE.txt +0 -0
  12. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/README.md +0 -0
  13. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/setup.cfg +0 -0
  14. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/__init__.py +0 -0
  15. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/api_adapter/__init__.py +0 -0
  16. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/api_adapter/forecast.py +0 -0
  17. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/api_adapter/gluon.py +0 -0
  18. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/api_adapter/hf_data.py +0 -0
  19. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/api_adapter/standard_adapter.py +0 -0
  20. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/models/__init__.py +0 -0
  21. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/models/patcher.py +0 -0
  22. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/models/slstm/block.py +0 -0
  23. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/models/slstm/layer.py +0 -0
  24. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/models/tirex.py +0 -0
  25. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex/util.py +0 -0
  26. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  27. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex_mirror.egg-info/requires.txt +0 -0
  28. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  29. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_chronos_zs.py +0 -0
  30. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_compile.py +0 -0
  31. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_forecast.py +0 -0
  32. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_forecast_adapter.py +0 -0
  33. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_patcher.py +0 -0
  34. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_slstm_torch_vs_cuda.py +0 -0
  35. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_standard_adapter.py +0 -0
  36. {tirex_mirror-2025.11.13 → tirex_mirror-2025.11.15}/tests/test_util_freq.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.11.13
3
+ Version: 2025.11.15
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "tirex-mirror"
3
- version = "2025.11.13"
3
+ version = "2025.11.15"
4
4
  description = "Unofficial mirror of NX-AI/tirex for packaging"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
@@ -3,6 +3,7 @@
3
3
 
4
4
  import logging
5
5
  import os
6
+ import warnings
6
7
  from abc import ABC, abstractmethod
7
8
  from typing import Literal, TypeVar
8
9
 
@@ -60,7 +61,7 @@ class PretrainedModel(ABC):
60
61
  if ckp_kwargs is None:
61
62
  ckp_kwargs = {}
62
63
  if device is None:
63
- device = "cuda:0" if backend == "cuda" else "cpu"
64
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
64
65
  if os.path.exists(path):
65
66
  print("Loading weights from local directory")
66
67
  checkpoint_path = path
@@ -93,7 +94,7 @@ class PretrainedModel(ABC):
93
94
  def load_model(
94
95
  path: str,
95
96
  device: str | None = None,
96
- backend: Literal["torch", "cuda"] | None = None,
97
+ backend: Literal["torch", "cuda"] = "torch",
97
98
  compile: bool = False,
98
99
  hf_kwargs=None,
99
100
  ckp_kwargs=None,
@@ -115,8 +116,22 @@ def load_model(
115
116
  model: ForecastModel = load_model("NX-AI/TiRex")
116
117
  """
117
118
 
118
- if backend is None:
119
- backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
119
+ if backend == "cuda" and not xlstm_available() and torch.cuda.is_available():
120
+ backend = "torch"
121
+ warnings.warn(
122
+ "Switching to 'torch' backend with device='gpu'. In order to use a CUDA backend, please make sure that xlstm package is installed.",
123
+ UserWarning,
124
+ stacklevel=2,
125
+ )
126
+
127
+ if device is not None and device.startswith("cuda") and not torch.cuda.is_available():
128
+ raise ValueError(
129
+ "CUDA is not available! This could be because:\n"
130
+ " - No GPU is present on this machine\n"
131
+ " - GPU drivers are not installed or not functioning\n"
132
+ " - PyTorch is installed without CUDA support (CPU-only version)\n"
133
+ "To resolve: use device='cpu' for CPU inference, or install CUDA-enabled PyTorch if a GPU is available."
134
+ )
120
135
 
121
136
  try:
122
137
  _, model_string = parse_hf_repo_id(path).split("/")
@@ -70,8 +70,8 @@ class sLSTMCell(nn.Module):
70
70
  if input.device.type != "cuda":
71
71
  warnings.warn(
72
72
  f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE DEVICE ON A CUDA DEVICE (device type is {input.device.type})!"
73
- "This is not supported and calls to the model will likely lead to an error if you dont move your model to a CUDA device!"
74
- "If you want to run TiRex on CPU you need to disable sLSTM CUDA kernels but be aware of the downsides (see FAQ)"
73
+ "This is not supported and calls to the model will likely lead to an error if you don't move your model to a CUDA device!"
74
+ "If you want to run TiRex on CPU, please select backend='torch' and device='cpu'"
75
75
  )
76
76
 
77
77
  if not hasattr(self, "func"):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.11.13
3
+ Version: 2025.11.15
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -27,6 +27,7 @@ tests/test_chronos_zs.py
27
27
  tests/test_compile.py
28
28
  tests/test_forecast.py
29
29
  tests/test_forecast_adapter.py
30
+ tests/test_load_model.py
30
31
  tests/test_patcher.py
31
32
  tests/test_slstm_torch_vs_cuda.py
32
33
  tests/test_standard_adapter.py
@@ -0,0 +1,67 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ import warnings
5
+
6
+ import pytest
7
+ import torch
8
+
9
+ from tirex import load_model
10
+
11
+
12
+ def test_load_model_with_default_parameters():
13
+ with warnings.catch_warnings(record=True) as w:
14
+ warnings.simplefilter("always")
15
+ model = load_model("NX-AI/TiRex")
16
+
17
+ assert model is not None
18
+ assert model.blocks[0].slstm_layer.slstm_cell.backend == "torch"
19
+
20
+ context = torch.randn(1, 64)
21
+ _, _ = model.forecast(context, prediction_length=32)
22
+ assert len(w) == 0 # no warnings check
23
+
24
+
25
+ def test_load_model_with_cpu_device_no_warning():
26
+ with warnings.catch_warnings(record=True) as w:
27
+ warnings.simplefilter("always")
28
+ model = load_model("NX-AI/TiRex", device="cpu")
29
+
30
+ assert model is not None
31
+ assert model.blocks[0].slstm_layer.slstm_cell.backend == "torch"
32
+ context = torch.randn(1, 64)
33
+ _, _ = model.forecast(context, prediction_length=32)
34
+ assert len(w) == 0 # no warnings check
35
+
36
+
37
+ def test_load_model_cuda_device_without_cuda_raises_error():
38
+ if torch.cuda.is_available():
39
+ pytest.skip("CUDA is available, skipping no-CUDA test")
40
+
41
+ with pytest.raises(ValueError) as exc_info:
42
+ model = load_model("NX-AI/TiRex", device="cuda:0")
43
+
44
+ error_msg = str(exc_info.value)
45
+ assert "CUDA is not available" in error_msg
46
+ assert "No GPU is present" in error_msg
47
+ assert "PyTorch is installed without CUDA support (CPU-only version)" in error_msg
48
+ assert "To resolve: use device='cpu'" in error_msg
49
+
50
+
51
+ def test_load_model_with_mps_device():
52
+ if not torch.backends.mps.is_available():
53
+ pytest.skip("MPS is not available, skipping MPS test")
54
+
55
+ with warnings.catch_warnings(record=True) as w:
56
+ warnings.simplefilter("always")
57
+ model = load_model("NX-AI/TiRex", device="mps")
58
+
59
+ assert model is not None
60
+ assert model.blocks[0].slstm_layer.slstm_cell.backend == "torch"
61
+
62
+ device = next(model.parameters()).device
63
+ assert device.type == "mps"
64
+
65
+ context = torch.randn(1, 64).to("mps")
66
+ _, _ = model.forecast(context, prediction_length=32)
67
+ assert len(w) == 0 # no warnings check