tirex-mirror 2025.11.14__tar.gz → 2025.11.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/PKG-INFO +1 -1
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/pyproject.toml +1 -1
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/base.py +19 -4
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/models/slstm/cell.py +2 -2
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex_mirror.egg-info/PKG-INFO +1 -1
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex_mirror.egg-info/SOURCES.txt +1 -0
- tirex_mirror-2025.11.18/tests/test_load_model.py +67 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/LICENSE +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/MANIFEST.in +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/NOTICE.txt +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/README.md +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/setup.cfg +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/__init__.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/api_adapter/__init__.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/api_adapter/forecast.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/api_adapter/gluon.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/api_adapter/hf_data.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/api_adapter/standard_adapter.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/models/__init__.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/models/patcher.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/models/slstm/block.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/models/slstm/layer.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/models/tirex.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/util.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex_mirror.egg-info/requires.txt +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex_mirror.egg-info/top_level.txt +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_chronos_zs.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_compile.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_forecast.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_forecast_adapter.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_patcher.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_slstm_torch_vs_cuda.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_standard_adapter.py +0 -0
- {tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/tests/test_util_freq.py +0 -0
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
import logging
|
|
5
5
|
import os
|
|
6
|
+
import warnings
|
|
6
7
|
from abc import ABC, abstractmethod
|
|
7
8
|
from typing import Literal, TypeVar
|
|
8
9
|
|
|
@@ -60,7 +61,7 @@ class PretrainedModel(ABC):
|
|
|
60
61
|
if ckp_kwargs is None:
|
|
61
62
|
ckp_kwargs = {}
|
|
62
63
|
if device is None:
|
|
63
|
-
device = "cuda:0" if
|
|
64
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
64
65
|
if os.path.exists(path):
|
|
65
66
|
print("Loading weights from local directory")
|
|
66
67
|
checkpoint_path = path
|
|
@@ -93,7 +94,7 @@ class PretrainedModel(ABC):
|
|
|
93
94
|
def load_model(
|
|
94
95
|
path: str,
|
|
95
96
|
device: str | None = None,
|
|
96
|
-
backend: Literal["torch", "cuda"]
|
|
97
|
+
backend: Literal["torch", "cuda"] = "torch",
|
|
97
98
|
compile: bool = False,
|
|
98
99
|
hf_kwargs=None,
|
|
99
100
|
ckp_kwargs=None,
|
|
@@ -115,8 +116,22 @@ def load_model(
|
|
|
115
116
|
model: ForecastModel = load_model("NX-AI/TiRex")
|
|
116
117
|
"""
|
|
117
118
|
|
|
118
|
-
if backend
|
|
119
|
-
backend = "torch"
|
|
119
|
+
if backend == "cuda" and not xlstm_available() and torch.cuda.is_available():
|
|
120
|
+
backend = "torch"
|
|
121
|
+
warnings.warn(
|
|
122
|
+
"Switching to 'torch' backend with device='gpu'. In oder to use a CUDA backend, please make sure that xlstm package is installed.",
|
|
123
|
+
UserWarning,
|
|
124
|
+
stacklevel=2,
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
if device is not None and device.startswith("cuda") and not torch.cuda.is_available():
|
|
128
|
+
raise ValueError(
|
|
129
|
+
"CUDA is not available! This could be because:\n"
|
|
130
|
+
" - No GPU is present on this machine\n"
|
|
131
|
+
" - GPU drivers are not installed or not functioning\n"
|
|
132
|
+
" - PyTorch is installed without CUDA support (CPU-only version)\n"
|
|
133
|
+
"To resolve: use device='cpu' for CPU inference, or install CUDA-enabled PyTorch if a GPU is available."
|
|
134
|
+
)
|
|
120
135
|
|
|
121
136
|
try:
|
|
122
137
|
_, model_string = parse_hf_repo_id(path).split("/")
|
|
@@ -70,8 +70,8 @@ class sLSTMCell(nn.Module):
|
|
|
70
70
|
if input.device.type != "cuda":
|
|
71
71
|
warnings.warn(
|
|
72
72
|
f"You use TiRex with sLSTM CUDA kernels BUT DO NOT LOAD THE DEVICE ON A CUDA DEVICE (device type is {input.device.type})!"
|
|
73
|
-
"This is not supported and calls to the model will likely lead to an error if you
|
|
74
|
-
"If you want to run TiRex on CPU
|
|
73
|
+
"This is not supported and calls to the model will likely lead to an error if you don't move your model to a CUDA device!"
|
|
74
|
+
"If you want to run TiRex on CPU, please select backend='torch' and device='cpu'"
|
|
75
75
|
)
|
|
76
76
|
|
|
77
77
|
if not hasattr(self, "func"):
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
+
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from tirex import load_model
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def test_load_model_with_default_parameters():
|
|
13
|
+
with warnings.catch_warnings(record=True) as w:
|
|
14
|
+
warnings.simplefilter("always")
|
|
15
|
+
model = load_model("NX-AI/TiRex")
|
|
16
|
+
|
|
17
|
+
assert model is not None
|
|
18
|
+
assert model.blocks[0].slstm_layer.slstm_cell.backend == "torch"
|
|
19
|
+
|
|
20
|
+
context = torch.randn(1, 64)
|
|
21
|
+
_, _ = model.forecast(context, prediction_length=32)
|
|
22
|
+
assert len(w) == 0 # no warnings check
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_load_model_with_cpu_device_no_warning():
|
|
26
|
+
with warnings.catch_warnings(record=True) as w:
|
|
27
|
+
warnings.simplefilter("always")
|
|
28
|
+
model = load_model("NX-AI/TiRex", device="cpu")
|
|
29
|
+
|
|
30
|
+
assert model is not None
|
|
31
|
+
assert model.blocks[0].slstm_layer.slstm_cell.backend == "torch"
|
|
32
|
+
context = torch.randn(1, 64)
|
|
33
|
+
_, _ = model.forecast(context, prediction_length=32)
|
|
34
|
+
assert len(w) == 0 # no warnings check
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def test_load_model_cuda_device_without_cuda_raises_error():
|
|
38
|
+
if torch.cuda.is_available():
|
|
39
|
+
pytest.skip("CUDA is available, skipping no-CUDA test")
|
|
40
|
+
|
|
41
|
+
with pytest.raises(ValueError) as exc_info:
|
|
42
|
+
model = load_model("NX-AI/TiRex", device="cuda:0")
|
|
43
|
+
|
|
44
|
+
error_msg = str(exc_info.value)
|
|
45
|
+
assert "CUDA is not available" in error_msg
|
|
46
|
+
assert "No GPU is present" in error_msg
|
|
47
|
+
assert "PyTorch is installed without CUDA support (CPU-only version)" in error_msg
|
|
48
|
+
assert "To resolve: use device='cpu'" in error_msg
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_load_model_with_mps_device():
|
|
52
|
+
if not torch.backends.mps.is_available():
|
|
53
|
+
pytest.skip("MPS is not available, skipping MPS test")
|
|
54
|
+
|
|
55
|
+
with warnings.catch_warnings(record=True) as w:
|
|
56
|
+
warnings.simplefilter("always")
|
|
57
|
+
model = load_model("NX-AI/TiRex", device="mps")
|
|
58
|
+
|
|
59
|
+
assert model is not None
|
|
60
|
+
assert model.blocks[0].slstm_layer.slstm_cell.backend == "torch"
|
|
61
|
+
|
|
62
|
+
device = next(model.parameters()).device
|
|
63
|
+
assert device.type == "mps"
|
|
64
|
+
|
|
65
|
+
context = torch.randn(1, 64).to("mps")
|
|
66
|
+
_, _ = model.forecast(context, prediction_length=32)
|
|
67
|
+
assert len(w) == 0 # no warnings check
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex/api_adapter/standard_adapter.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.11.14 → tirex_mirror-2025.11.18}/src/tirex_mirror.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|