tirex-mirror 2025.9.9__py3-none-any.whl → 2025.9.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tirex/base.py CHANGED
@@ -15,6 +15,15 @@ def skip_cuda():
15
15
  return os.getenv("TIREX_NO_CUDA", "False").lower() in ("true", "1", "t")
16
16
 
17
17
 
18
+ def xlstm_available():
19
+ try:
20
+ from xlstm.blocks.slstm.cell import sLSTMCellConfig, sLSTMCellFuncGenerator
21
+
22
+ return True
23
+ except ModuleNotFoundError:
24
+ return False
25
+
26
+
18
27
  def parse_hf_repo_id(path):
19
28
  parts = path.split("/")
20
29
  return "/".join(parts[0:2])
@@ -86,8 +95,7 @@ def load_model(
86
95
  """
87
96
 
88
97
  if backend is None:
89
- backend = "torch" if skip_cuda() else "cuda"
90
- assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
98
+ backend = "torch" if skip_cuda() or not xlstm_available() else "cuda"
91
99
 
92
100
  try:
93
101
  _, model_id = parse_hf_repo_id(path).split("/")
@@ -3,6 +3,7 @@
3
3
 
4
4
  import warnings
5
5
  from dataclasses import asdict, dataclass
6
+ from typing import Literal
6
7
 
7
8
  import torch
8
9
  import torch.nn as nn
@@ -15,10 +16,8 @@ from tirex.util import dataclass_from_dict
15
16
  class sLSTMBlockConfig:
16
17
  embedding_dim: int
17
18
  num_heads: int
18
- num_blocks: int
19
19
  ffn_proj_factor: float = 2.6667
20
-
21
- num_states: int = 4 # this is for the sLSTM, a standard LSTM has 2
20
+ num_states: int = 4
22
21
  num_gates: int = 4
23
22
 
24
23
  @property
@@ -27,8 +26,9 @@ class sLSTMBlockConfig:
27
26
 
28
27
 
29
28
  class sLSTMCell(nn.Module):
30
- def __init__(self, config: sLSTMBlockConfig, backend: str):
29
+ def __init__(self, config: sLSTMBlockConfig, backend: Literal["torch", "cuda"]):
31
30
  super().__init__()
31
+ assert backend in ["torch", "cuda"], f"Backend can either be torch or cuda, not {backend}!"
32
32
  self.config = config
33
33
  self.backend = backend
34
34
 
tirex/models/tirex.py CHANGED
@@ -33,10 +33,10 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
33
33
  super().__init__()
34
34
  self.config = TiRexZeroConfig(**model_config, train_ctx_len=train_ctx_len, nan_mask_value=0)
35
35
  assert self.config.input_patch_size == self.config.output_patch_size
36
- self.backend = backend
37
36
 
38
37
  self.tokenizer = PatchedUniTokenizer(patch_size=self.config.input_patch_size)
39
38
 
39
+ num_blocks = self.config.block_kwargs["num_blocks"]
40
40
  block_config = dataclass_from_dict(sLSTMBlockConfig, self.config.block_kwargs)
41
41
  self.input_patch_embedding = ResidualBlock(
42
42
  in_dim=self.config.input_patch_size * 2,
@@ -44,9 +44,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
44
44
  out_dim=block_config.embedding_dim,
45
45
  )
46
46
 
47
- self.blocks = nn.ModuleList(
48
- [sLSTMBlock(block_config, backend=self.backend) for i in range(block_config.num_blocks)]
49
- )
47
+ self.blocks = nn.ModuleList([sLSTMBlock(block_config, backend) for i in range(num_blocks)])
50
48
 
51
49
  self.out_norm = RMSNorm(block_config.embedding_dim)
52
50
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.9.9
3
+ Version: 2025.9.24
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -64,7 +64,6 @@ License-File: LICENSE
64
64
  License-File: LICENSE_MIRROR.txt
65
65
  License-File: NOTICE.txt
66
66
  Requires-Dist: torch
67
- Requires-Dist: torchvision
68
67
  Requires-Dist: einops
69
68
  Requires-Dist: huggingface-hub
70
69
  Requires-Dist: numpy
@@ -81,7 +80,8 @@ Requires-Dist: gluonts; extra == "gluonts"
81
80
  Provides-Extra: hfdataset
82
81
  Requires-Dist: datasets; extra == "hfdataset"
83
82
  Provides-Extra: test
84
- Requires-Dist: fev; extra == "test"
83
+ Requires-Dist: fev>=0.6.0; extra == "test"
84
+ Requires-Dist: pytest; extra == "test"
85
85
  Provides-Extra: all
86
86
  Requires-Dist: xlstm; extra == "all"
87
87
  Requires-Dist: ninja; extra == "all"
@@ -89,7 +89,8 @@ Requires-Dist: ipykernel; extra == "all"
89
89
  Requires-Dist: matplotlib; extra == "all"
90
90
  Requires-Dist: gluonts; extra == "all"
91
91
  Requires-Dist: datasets; extra == "all"
92
- Requires-Dist: fev; extra == "all"
92
+ Requires-Dist: pytest; extra == "all"
93
+ Requires-Dist: fev>=0.6.0; extra == "all"
93
94
  Dynamic: license-file
94
95
 
95
96
  # tirex-mirror
@@ -1,5 +1,5 @@
1
1
  tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
2
- tirex/base.py,sha256=ODUyhYFR33ZYffu7dxDwsb9m2IiZAnGHIXvA81crbjQ,3245
2
+ tirex/base.py,sha256=u_fcwaIKEzq9aAt3UWqH8QvaqXG7qEykLNaP_opY26M,3366
3
3
  tirex/util.py,sha256=7DFVBXwGQA4niT9VhYbt8iKMBINJVW4LfwwpggFS0Us,469
4
4
  tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
5
5
  tirex/api_adapter/forecast.py,sha256=snv0sT1_1WzjkhP1YV-I7CMQmSChl93qFc3b6fwUAS0,8502
@@ -8,14 +8,14 @@ tirex/api_adapter/hf_data.py,sha256=T1eaxqC3OO9yOzIvw4sr55x6iA2AHKJTZd36rROM4fQ,
8
8
  tirex/api_adapter/standard_adapter.py,sha256=bI3XGYlWQu5EDyhDZyYqOJMbwi5h1aovPQvfHuWETJk,2618
9
9
  tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
10
10
  tirex/models/patcher.py,sha256=EOXFkHsPkq0nuxRNLAbnrgJtcYq0IMC3YIg_16WArg4,3213
11
- tirex/models/tirex.py,sha256=dclEckb6CmvESeX_LwT2kaCNTB7deTFovIOQUIFF5J8,9117
11
+ tirex/models/tirex.py,sha256=Kglea86t_f3nXXHSjFgssxxrd1Qbwfr1eB_5gKfWYxM,9098
12
12
  tirex/models/slstm/block.py,sha256=DCOxmLQUb7HRO6wXTZMK4ICUI5LFpo7NC5a28oM-Vsc,2104
13
- tirex/models/slstm/cell.py,sha256=4_pQcXOOT16aEpKIi4A-yEnj4qKK6pFyFADD2nGPzGc,7366
13
+ tirex/models/slstm/cell.py,sha256=XWsn8I7HrUoMrUrfRCpl6Q88xbBz67bKEkdZ8gXE3hY,7444
14
14
  tirex/models/slstm/layer.py,sha256=93CAYuG-HmUpF7mBAQ-z1S1u2__W10EW5jPToR57qqM,2747
15
- tirex_mirror-2025.9.9.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
16
- tirex_mirror-2025.9.9.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
17
- tirex_mirror-2025.9.9.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
18
- tirex_mirror-2025.9.9.dist-info/METADATA,sha256=u9C_cIb8FtaHUep1XrFTeI7UAsVRtNJt2VSQo7420Vo,11200
19
- tirex_mirror-2025.9.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
- tirex_mirror-2025.9.9.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
21
- tirex_mirror-2025.9.9.dist-info/RECORD,,
15
+ tirex_mirror-2025.9.24.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
16
+ tirex_mirror-2025.9.24.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
17
+ tirex_mirror-2025.9.24.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
18
+ tirex_mirror-2025.9.24.dist-info/METADATA,sha256=It43qLOQvebhPBgCcF5gx0GK70vhdj6aoJptqB3oeNQ,11265
19
+ tirex_mirror-2025.9.24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ tirex_mirror-2025.9.24.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
21
+ tirex_mirror-2025.9.24.dist-info/RECORD,,