tirex-mirror 2025.11.28.tar.gz → 2025.12.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/PKG-INFO +4 -2
  2. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/pyproject.toml +4 -4
  3. tirex_mirror-2025.12.2/src/tirex/models/classification/__init__.py +8 -0
  4. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/classification/embedding.py +7 -4
  5. tirex_mirror-2025.12.2/src/tirex/models/classification/heads/base_classifier.py +93 -0
  6. tirex_mirror-2025.12.2/src/tirex/models/classification/heads/gbm_classifier.py +210 -0
  7. {tirex_mirror-2025.11.28/src/tirex/models/classification → tirex_mirror-2025.12.2/src/tirex/models/classification/heads}/linear_classifier.py +17 -15
  8. {tirex_mirror-2025.11.28/src/tirex/models/classification → tirex_mirror-2025.12.2/src/tirex/models/classification/heads}/rf_classifier.py +17 -42
  9. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/classification/trainer.py +3 -0
  10. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/classification/utils.py +3 -0
  11. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex_mirror.egg-info/PKG-INFO +4 -2
  12. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex_mirror.egg-info/SOURCES.txt +5 -2
  13. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex_mirror.egg-info/requires.txt +2 -0
  14. tirex_mirror-2025.12.2/tests/test_gbm_classifier.py +238 -0
  15. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_linear_classifier.py +101 -0
  16. tirex_mirror-2025.12.2/tests/test_rf_classifier.py +220 -0
  17. tirex_mirror-2025.11.28/src/tirex/models/classification/__init__.py +0 -8
  18. tirex_mirror-2025.11.28/tests/test_rf_classifier.py +0 -123
  19. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/LICENSE +0 -0
  20. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/LICENSE_MIRROR.txt +0 -0
  21. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/MANIFEST.in +0 -0
  22. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/NOTICE.txt +0 -0
  23. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/README.md +0 -0
  24. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/setup.cfg +0 -0
  25. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/__init__.py +0 -0
  26. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/api_adapter/__init__.py +0 -0
  27. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/api_adapter/forecast.py +0 -0
  28. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/api_adapter/gluon.py +0 -0
  29. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/api_adapter/hf_data.py +0 -0
  30. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/api_adapter/standard_adapter.py +0 -0
  31. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/base.py +0 -0
  32. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/__init__.py +0 -0
  33. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/patcher.py +0 -0
  34. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/slstm/block.py +0 -0
  35. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/slstm/cell.py +0 -0
  36. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/slstm/layer.py +0 -0
  37. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/models/tirex.py +0 -0
  38. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex/util.py +0 -0
  39. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  40. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  41. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_chronos_zs.py +0 -0
  42. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_compile.py +0 -0
  43. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_embedding.py +0 -0
  44. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_forecast.py +0 -0
  45. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_forecast_adapter.py +0 -0
  46. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_load_model.py +0 -0
  47. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_patcher.py +0 -0
  48. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_slstm_torch_vs_cuda.py +0 -0
  49. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_standard_adapter.py +0 -0
  50. {tirex_mirror-2025.11.28 → tirex_mirror-2025.12.2}/tests/test_util_freq.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.11.28
3
+ Version: 2025.12.2
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -55,7 +55,7 @@ License: NXAI COMMUNITY LICENSE AGREEMENT
55
55
 
56
56
  Project-URL: Repository, https://github.com/rozsasarpi/tirex-mirror
57
57
  Project-URL: Issues, https://github.com/rozsasarpi/tirex-mirror/issues
58
- Keywords: TiRex,xLSTM,Time Series,Zero-shot,Deep Learning
58
+ Keywords: TiRex,xLSTM,Time Series,Zero-shot,Deep Learning,Classification,Timeseries-Classification
59
59
  Classifier: Programming Language :: Python :: 3
60
60
  Classifier: Operating System :: OS Independent
61
61
  Requires-Python: >=3.11
@@ -86,6 +86,7 @@ Requires-Dist: fev>=0.6.0; extra == "test"
86
86
  Requires-Dist: pytest; extra == "test"
87
87
  Provides-Extra: classification
88
88
  Requires-Dist: scikit-learn; extra == "classification"
89
+ Requires-Dist: lightgbm[scikit-learn]; extra == "classification"
89
90
  Provides-Extra: all
90
91
  Requires-Dist: xlstm; extra == "all"
91
92
  Requires-Dist: ninja; extra == "all"
@@ -98,6 +99,7 @@ Requires-Dist: datasets; extra == "all"
98
99
  Requires-Dist: pytest; extra == "all"
99
100
  Requires-Dist: fev>=0.6.0; extra == "all"
100
101
  Requires-Dist: scikit-learn; extra == "all"
102
+ Requires-Dist: lightgbm[scikit-learn]; extra == "all"
101
103
  Dynamic: license-file
102
104
 
103
105
  # tirex-mirror
@@ -1,11 +1,11 @@
1
1
  [project]
2
2
  name = "tirex-mirror"
3
- version = "2025.11.28"
3
+ version = "2025.12.02"
4
4
  description = "Unofficial mirror of NX-AI/tirex for packaging"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
7
7
  classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent",]
8
- keywords = [ "TiRex", "xLSTM", "Time Series", "Zero-shot", "Deep Learning",]
8
+ keywords = [ "TiRex", "xLSTM", "Time Series", "Zero-shot", "Deep Learning", "Classification", "Timeseries-Classification",]
9
9
  dependencies = [ "torch", "huggingface-hub", "numpy",]
10
10
  [[project.authors]]
11
11
  name = "Arpad Rozsas"
@@ -29,8 +29,8 @@ plotting = [ "matplotlib",]
29
29
  gluonts = [ "gluonts", "pandas",]
30
30
  hfdataset = [ "datasets",]
31
31
  test = [ "fev>=0.6.0", "pytest",]
32
- classification = [ "scikit-learn",]
33
- all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "pandas", "python-dotenv", "gluonts", "datasets", "pytest", "fev>=0.6.0", "scikit-learn",]
32
+ classification = [ "scikit-learn", "lightgbm[scikit-learn]",]
33
+ all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "pandas", "python-dotenv", "gluonts", "datasets", "pytest", "fev>=0.6.0", "scikit-learn", "lightgbm[scikit-learn]",]
34
34
 
35
35
  [tool.docformatter]
36
36
  diff = false
@@ -0,0 +1,8 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ from .heads.gbm_classifier import TirexGBMClassifier
5
+ from .heads.linear_classifier import TirexClassifierTorch
6
+ from .heads.rf_classifier import TirexRFClassifier
7
+
8
+ __all__ = ["TirexClassifierTorch", "TirexRFClassifier", "TirexGBMClassifier"]
@@ -1,6 +1,6 @@
1
- import inspect
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
2
3
 
3
- import numpy as np
4
4
  import torch
5
5
  import torch.nn as nn
6
6
  import torch.nn.functional as F
@@ -11,7 +11,9 @@ from .utils import nanmax, nanmin, nanstd
11
11
 
12
12
 
13
13
  class TiRexEmbedding(nn.Module):
14
- def __init__(self, device: str | None = None, data_augmentation: bool = False, batch_size: int = 512) -> None:
14
+ def __init__(
15
+ self, device: str | None = None, data_augmentation: bool = False, batch_size: int = 512, compile: bool = False
16
+ ) -> None:
15
17
  super().__init__()
16
18
  self.data_augmentation = data_augmentation
17
19
  self.number_of_patches = 8
@@ -20,8 +22,9 @@ class TiRexEmbedding(nn.Module):
20
22
  if device is None:
21
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
24
  self.device = device
25
+ self._compile = compile
23
26
 
24
- self.model = load_model(path="NX-AI/TiRex", device=self.device)
27
+ self.model = load_model(path="NX-AI/TiRex", device=self.device, compile=self._compile)
25
28
 
26
29
  def _gen_emb_batched(self, data: torch.Tensor) -> torch.Tensor:
27
30
  batches = list(torch.split(data, self.batch_size))
@@ -0,0 +1,93 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ from abc import ABC, abstractmethod
5
+
6
+ import torch
7
+
8
+ from ..embedding import TiRexEmbedding
9
+
10
+
11
+ class BaseTirexClassifier(ABC):
12
+ """Abstract base class for TiRex classification models.
13
+
14
+ This base class provides common functionality for all TiRex classifiers,
15
+ including embedding model initialization and a consistent interface.
16
+
17
+ """
18
+
19
+ def __init__(
20
+ self, data_augmentation: bool = False, device: str | None = None, compile: bool = False, batch_size: int = 512
21
+ ) -> None:
22
+ """Initializes a TiRex classification model.
23
+
24
+ Args:
25
+ data_augmentation : bool
26
+ Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
27
+ device : str | None
28
+ Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
29
+ compile: bool
30
+ Whether to compile the frozen embedding model. Default: False
31
+ batch_size : int
32
+ Batch size for embedding calculations. Default: 512
33
+ """
34
+
35
+ # Set device
36
+ if device is None:
37
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
38
+ self.device = device
39
+ self._compile = compile
40
+
41
+ self.batch_size = batch_size
42
+ self.data_augmentation = data_augmentation
43
+ self.emb_model = TiRexEmbedding(
44
+ device=self.device,
45
+ data_augmentation=self.data_augmentation,
46
+ batch_size=self.batch_size,
47
+ compile=self._compile,
48
+ )
49
+
50
+ @abstractmethod
51
+ def fit(self, train_data: tuple[torch.Tensor, torch.Tensor]) -> None:
52
+ pass
53
+
54
+ @torch.inference_mode()
55
+ def predict(self, x: torch.Tensor) -> torch.Tensor:
56
+ """Predict class labels for input time series data.
57
+
58
+ Args:
59
+ x: Input time series data as torch.Tensor with shape
60
+ (batch_size, num_variates, seq_len).
61
+ Returns:
62
+ torch.Tensor: Predicted class labels with shape (batch_size,).
63
+ """
64
+ self.emb_model.eval()
65
+ x = x.to(self.device)
66
+ embeddings = self.emb_model(x).cpu().numpy()
67
+ return torch.from_numpy(self.head.predict(embeddings)).long()
68
+
69
+ @torch.inference_mode()
70
+ def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
71
+ """Predict class probabilities for input time series data.
72
+
73
+ Args:
74
+ x: Input time series data as torch.Tensor with shape
75
+ (batch_size, num_variates, seq_len).
76
+ Returns:
77
+ torch.Tensor: Class probabilities with shape (batch_size, num_classes).
78
+ """
79
+ self.emb_model.eval()
80
+ x = x.to(self.device)
81
+ embeddings = self.emb_model(x).cpu().numpy()
82
+ return torch.from_numpy(self.head.predict_proba(embeddings))
83
+
84
+ @abstractmethod
85
+ def save_model(self, path: str) -> None:
86
+ """Saving model abstract method"""
87
+ pass
88
+
89
+ @classmethod
90
+ @abstractmethod
91
+ def load_model(cls, path: str):
92
+ """Loading model abstract method"""
93
+ pass
@@ -0,0 +1,210 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ import joblib
5
+ import torch
6
+ from lightgbm import LGBMClassifier, early_stopping
7
+
8
+ from ..utils import train_val_split
9
+ from .base_classifier import BaseTirexClassifier
10
+
11
+
12
+ class TirexGBMClassifier(BaseTirexClassifier):
13
+ """
14
+ A Gradient Boosting classifier that uses time series embeddings as features.
15
+
16
+ This classifier combines a pre-trained embedding model for feature extraction with a
17
+ Gradient Boosting classifier.
18
+
19
+ Example:
20
+ >>> from tirex.models.classification import TirexGBMClassifier
21
+ >>>
22
+ >>> # Create model with custom LightGBM parameters
23
+ >>> model = TirexGBMClassifier(
24
+ ... data_augmentation=True,
25
+ ... n_estimators=50,
26
+ ... random_state=42
27
+ ... )
28
+ >>>
29
+ >>> # Prepare data (can use NumPy arrays or PyTorch tensors)
30
+ >>> X_train = torch.randn(100, 1, 128) # 100 samples, 1 number of variates, 128 sequence length
31
+ >>> y_train = torch.randint(0, 3, (100,)) # 3 classes
32
+ >>>
33
+ >>> # Train the model
34
+ >>> model.fit((X_train, y_train))
35
+ >>>
36
+ >>> # Make predictions
37
+ >>> X_test = torch.randn(20, 1, 128)
38
+ >>> predictions = model.predict(X_test)
39
+ >>> probabilities = model.predict_proba(X_test)
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ data_augmentation: bool = False,
45
+ device: str | None = None,
46
+ compile: bool = False,
47
+ batch_size: int = 512,
48
+ early_stopping_rounds: int | None = 10,
49
+ min_delta: float = 0.0,
50
+ val_split_ratio: float = 0.2,
51
+ stratify: bool = True,
52
+ # LightGBM kwargs
53
+ **lgbm_kwargs,
54
+ ) -> None:
55
+ """Initializes Embedding Based Gradient Boosting Classification model.
56
+
57
+ Args:
58
+ data_augmentation : bool
59
+ Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
60
+ device : str | None
61
+ Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
62
+ compile: bool
63
+ Whether to compile the frozen embedding model. Default: False
64
+ batch_size : int
65
+ Batch size for embedding calculations. Default: 512
66
+ early_stopping_rounds: int | None
67
+ Number of rounds without improvement of all metrics for Early Stopping. Default: 10
68
+ min_delta: float
69
+ Minimum improvement in score to keep training. Default 0.0
70
+ val_split_ratio : float
71
+ Proportion of training data to use for validation, if validation data are not provided. Default: 0.2
72
+ stratify : bool
73
+ Whether to stratify the train/validation split by class labels. Default: True
74
+ **lgbm_kwargs
75
+ Additional keyword arguments to pass to LightGBM's LGBMClassifier.
76
+ Common options include n_estimators, max_depth, learning_rate, random_state, etc.
77
+ """
78
+ super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
79
+
80
+ # Early Stopping callback
81
+ self.early_stopping_rounds = early_stopping_rounds
82
+ self.min_delta = min_delta
83
+
84
+ # Data split parameters:
85
+ self.val_split_ratio = val_split_ratio
86
+ self.stratify = stratify
87
+
88
+ # Extract random_state for train_val_split if provided
89
+ self.random_state = lgbm_kwargs.get("random_state", None)
90
+
91
+ self.head = LGBMClassifier(**lgbm_kwargs)
92
+
93
+ @torch.inference_mode()
94
+ def fit(
95
+ self,
96
+ train_data: tuple[torch.Tensor, torch.Tensor],
97
+ val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
98
+ ) -> None:
99
+ """Train the LightGBM classifier on embedded time series data.
100
+
101
+ This method generates embeddings for the training data using the embedding
102
+ model, then trains the LightGBM classifier on these embeddings.
103
+
104
+ Args:
105
+ train_data: Tuple of (X_train, y_train) where X_train is the input time
106
+ series data (torch.Tensor) and y_train is a torch.Tensor
107
+ of class labels.
108
+ val_data: Optional tuple of (X_val, y_val) for validation where X_train is the input time
109
+ series data (torch.Tensor) and y_train is a torch.Tensor
110
+ of class labels. If None, validation data will be automatically split from train_data (20% split).
111
+ """
112
+
113
+ (X_train, y_train), (X_val, y_val) = self._create_train_val_datasets(
114
+ train_data=train_data,
115
+ val_data=val_data,
116
+ val_split_ratio=self.val_split_ratio,
117
+ stratify=self.stratify,
118
+ seed=self.random_state,
119
+ )
120
+
121
+ self.emb_model.eval()
122
+ X_train = X_train.to(self.device)
123
+ X_val = X_val.to(self.device)
124
+ embeddings_train = self.emb_model(X_train).cpu().numpy()
125
+ embeddings_val = self.emb_model(X_val).cpu().numpy()
126
+
127
+ y_train = y_train.detach().cpu().numpy() if isinstance(y_train, torch.Tensor) else y_train
128
+ y_val = y_val.detach().cpu().numpy() if isinstance(y_val, torch.Tensor) else y_val
129
+
130
+ self.head.fit(
131
+ embeddings_train,
132
+ y_train,
133
+ eval_set=[(embeddings_val, y_val)],
134
+ callbacks=[early_stopping(stopping_rounds=self.early_stopping_rounds, min_delta=self.min_delta)]
135
+ if self.early_stopping_rounds is not None
136
+ else None,
137
+ )
138
+
139
+ def save_model(self, path: str) -> None:
140
+ """This method saves the trained LightGBM classifier head (joblib format) and embedding information.
141
+
142
+ Args:
143
+ path: File path where the model should be saved (e.g., 'model.joblib').
144
+ """
145
+
146
+ payload = {
147
+ "data_augmentation": self.data_augmentation,
148
+ "compile": self._compile,
149
+ "batch_size": self.batch_size,
150
+ "early_stopping_rounds": self.early_stopping_rounds,
151
+ "min_delta": self.min_delta,
152
+ "val_split_ratio": self.val_split_ratio,
153
+ "stratify": self.stratify,
154
+ "head": self.head,
155
+ }
156
+ joblib.dump(payload, path)
157
+
158
+ @classmethod
159
+ def load_model(cls, path: str) -> "TirexGBMClassifier":
160
+ """Load a saved model from file.
161
+
162
+ This reconstructs the model with the embedding configuration and loads
163
+ the trained LightGBM classifier from a checkpoint file created by save_model().
164
+
165
+ Args:
166
+ path: File path to the saved model checkpoint.
167
+ Returns:
168
+ TirexGBMClassifier: The loaded model with trained Gradient Boosting, ready for inference.
169
+ """
170
+ checkpoint = joblib.load(path)
171
+
172
+ # Create new instance with saved configuration
173
+ model = cls(
174
+ data_augmentation=checkpoint["data_augmentation"],
175
+ compile=checkpoint["compile"],
176
+ batch_size=checkpoint["batch_size"],
177
+ early_stopping_rounds=checkpoint["early_stopping_rounds"],
178
+ min_delta=checkpoint["min_delta"],
179
+ val_split_ratio=checkpoint["val_split_ratio"],
180
+ stratify=checkpoint["stratify"],
181
+ )
182
+
183
+ # Load the trained LightGBM head
184
+ model.head = checkpoint["head"]
185
+
186
+ # Extract random_state from the loaded head if available
187
+ model.random_state = getattr(model.head, "random_state", None)
188
+
189
+ return model
190
+
191
+ def _create_train_val_datasets(
192
+ self,
193
+ train_data: tuple[torch.Tensor, torch.Tensor],
194
+ val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
195
+ val_split_ratio: float = 0.2,
196
+ stratify: bool = True,
197
+ seed: int | None = None,
198
+ ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
199
+ X_train, y_train = train_data
200
+
201
+ if val_data is None:
202
+ train_data, val_data = train_val_split(
203
+ train_data=train_data, val_split_ratio=val_split_ratio, stratify=stratify, seed=seed
204
+ )
205
+ X_train, y_train = train_data
206
+ X_val, y_val = val_data
207
+ else:
208
+ X_val, y_val = val_data
209
+
210
+ return (X_train, y_train), (X_val, y_val)
@@ -1,12 +1,15 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
1
4
  from dataclasses import asdict
2
5
 
3
6
  import torch
4
7
 
5
- from .embedding import TiRexEmbedding
6
- from .trainer import TrainConfig, Trainer
8
+ from ..trainer import TrainConfig, Trainer
9
+ from .base_classifier import BaseTirexClassifier
7
10
 
8
11
 
9
- class TirexClassifierTorch(torch.nn.Module):
12
+ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
10
13
  """
11
14
  A PyTorch classifier that combines time series embeddings with a linear classification head.
12
15
 
@@ -18,7 +21,7 @@ class TirexClassifierTorch(torch.nn.Module):
18
21
  >>> import torch
19
22
  >>> from tirex.models.classification import TirexClassifierTorch
20
23
  >>>
21
- >>> # Create model with TIREX embeddings
24
+ >>> # Create model with TiRex embeddings
22
25
  >>> model = TirexClassifierTorch(
23
26
  ... data_augmentation=True,
24
27
  ... max_epochs=2,
@@ -43,8 +46,9 @@ class TirexClassifierTorch(torch.nn.Module):
43
46
  self,
44
47
  data_augmentation: bool = False,
45
48
  device: str | None = None,
49
+ compile: bool = False,
46
50
  # Training parameters
47
- max_epochs: int = 50,
51
+ max_epochs: int = 10,
48
52
  lr: float = 1e-4,
49
53
  weight_decay: float = 0.01,
50
54
  batch_size: int = 512,
@@ -62,9 +66,11 @@ class TirexClassifierTorch(torch.nn.Module):
62
66
 
63
67
  Args:
64
68
  data_augmentation : bool | None
65
- Whether to use data_augmentation for embeddings (stats and first-order differences of the original data). Default: False
69
+ Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
66
70
  device : str | None
67
- Device to run the model on. If None, uses CUDA if available, else CPU. Default: None
71
+ Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
72
+ compile: bool
73
+ Whether to compile the frozen embedding model. Default: False
68
74
  max_epochs : int
69
75
  Maximum number of training epochs. Default: 50
70
76
  lr : float
@@ -91,15 +97,9 @@ class TirexClassifierTorch(torch.nn.Module):
91
97
  Dropout probability for the classification head. If None, no dropout is used. Default: None
92
98
  """
93
99
 
94
- super().__init__()
95
-
96
- if device is None:
97
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
98
- self.device = device
100
+ torch.nn.Module.__init__(self)
99
101
 
100
- # Create embedding model
101
- self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
102
- self.data_augmentation = data_augmentation
102
+ super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
103
103
 
104
104
  # Head parameters
105
105
  self.dropout = dropout
@@ -221,6 +221,7 @@ class TirexClassifierTorch(torch.nn.Module):
221
221
  {
222
222
  "head_state_dict": self.head.state_dict(), # need to save only head, embedding is frozen
223
223
  "data_augmentation": self.data_augmentation,
224
+ "compile": self._compile,
224
225
  "emb_dim": self.emb_dim,
225
226
  "num_classes": self.num_classes,
226
227
  "dropout": self.dropout,
@@ -248,6 +249,7 @@ class TirexClassifierTorch(torch.nn.Module):
248
249
 
249
250
  model = cls(
250
251
  data_augmentation=checkpoint["data_augmentation"],
252
+ compile=checkpoint["compile"],
251
253
  dropout=checkpoint["dropout"],
252
254
  max_epochs=train_config_dict.get("max_epochs", 50),
253
255
  lr=train_config_dict.get("lr", 1e-4),
@@ -1,12 +1,14 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
1
4
  import joblib
2
- import numpy as np
3
5
  import torch
4
6
  from sklearn.ensemble import RandomForestClassifier
5
7
 
6
- from .embedding import TiRexEmbedding
8
+ from .base_classifier import BaseTirexClassifier
7
9
 
8
10
 
9
- class TirexRFClassifier:
11
+ class TirexRFClassifier(BaseTirexClassifier):
10
12
  """
11
13
  A Random Forest classifier that uses time series embeddings as features.
12
14
 
@@ -15,7 +17,6 @@ class TirexRFClassifier:
15
17
  time series, which are then used to train the Random Forest.
16
18
 
17
19
  Example:
18
- >>> import numpy as np
19
20
  >>> from tirex.models.classification import TirexRFClassifier
20
21
  >>>
21
22
  >>> # Create model with custom Random Forest parameters
@@ -43,6 +44,7 @@ class TirexRFClassifier:
43
44
  self,
44
45
  data_augmentation: bool = False,
45
46
  device: str | None = None,
47
+ compile: bool = False,
46
48
  batch_size: int = 512,
47
49
  # Random Forest parameters
48
50
  **rf_kwargs,
@@ -51,24 +53,18 @@ class TirexRFClassifier:
51
53
 
52
54
  Args:
53
55
  data_augmentation : bool
54
- Whether to use data_augmentation for embeddings (stats and first-order differences of the original data). Default: False
56
+ Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
55
57
  device : str | None
56
58
  Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
59
+ compile: bool
60
+ Whether to compile the frozen embedding model. Default: False
57
61
  batch_size : int
58
62
  Batch size for embedding calculations. Default: 512
59
63
  **rf_kwargs
60
64
  Additional keyword arguments to pass to sklearn's RandomForestClassifier.
61
65
  Common options include n_estimators, max_depth, min_samples_split, random_state, etc.
62
66
  """
63
-
64
- # Set device
65
- if device is None:
66
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
67
- self.device = device
68
-
69
- self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
70
- self.data_augmentation = data_augmentation
71
-
67
+ super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
72
68
  self.head = RandomForestClassifier(**rf_kwargs)
73
69
 
74
70
  @torch.inference_mode()
@@ -80,7 +76,7 @@ class TirexRFClassifier:
80
76
 
81
77
  Args:
82
78
  train_data: Tuple of (X_train, y_train) where X_train is the input time
83
- series data (torch.Tensor) and y_train is a numpy array
79
+ series data (torch.Tensor) and y_train is a torch.Tensor
84
80
  of class labels.
85
81
  """
86
82
  X_train, y_train = train_data
@@ -88,36 +84,11 @@ class TirexRFClassifier:
88
84
  if isinstance(y_train, torch.Tensor):
89
85
  y_train = y_train.detach().cpu().numpy()
90
86
 
87
+ self.emb_model.eval()
88
+ X_train = X_train.to(self.device)
91
89
  embeddings = self.emb_model(X_train).cpu().numpy()
92
90
  self.head.fit(embeddings, y_train)
93
91
 
94
- @torch.inference_mode()
95
- def predict(self, x: torch.Tensor) -> torch.Tensor:
96
- """Predict class labels for input time series data.
97
-
98
- Args:
99
- x: Input time series data as torch.Tensor or np.ndarray with shape
100
- (batch_size, num_variates, seq_len).
101
- Returns:
102
- torch.Tensor: Predicted class labels with shape (batch_size,).
103
- """
104
-
105
- embeddings = self.emb_model(x).cpu().numpy()
106
- return torch.from_numpy(self.head.predict(embeddings)).long()
107
-
108
- @torch.inference_mode()
109
- def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
110
- """Predict class probabilities for input time series data.
111
-
112
- Args:
113
- x: Input time series data as torch.Tensor or np.ndarray with shape
114
- (batch_size, num_variates, seq_len).
115
- Returns:
116
- torch.Tensor: Class probabilities with shape (batch_size, num_classes).
117
- """
118
- embeddings = self.emb_model(x).cpu().numpy()
119
- return torch.from_numpy(self.head.predict_proba(embeddings))
120
-
121
92
  def save_model(self, path: str) -> None:
122
93
  """This method saves the trained Random Forest classifier head and embedding information in joblib format
123
94
 
@@ -126,6 +97,8 @@ class TirexRFClassifier:
126
97
  """
127
98
  payload = {
128
99
  "data_augmentation": self.data_augmentation,
100
+ "compile": self._compile,
101
+ "batch_size": self.batch_size,
129
102
  "head": self.head,
130
103
  }
131
104
  joblib.dump(payload, path)
@@ -147,6 +120,8 @@ class TirexRFClassifier:
147
120
  # Create new instance with saved configuration
148
121
  model = cls(
149
122
  data_augmentation=checkpoint["data_augmentation"],
123
+ compile=checkpoint["compile"],
124
+ batch_size=checkpoint["batch_size"],
150
125
  )
151
126
 
152
127
  # Load the trained Random Forest head
@@ -1,3 +1,6 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
1
4
  from dataclasses import dataclass
2
5
 
3
6
  import torch
@@ -1,3 +1,6 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
1
4
  import numpy as np
2
5
  import torch
3
6
  from sklearn.model_selection import train_test_split
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.11.28
3
+ Version: 2025.12.2
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -55,7 +55,7 @@ License: NXAI COMMUNITY LICENSE AGREEMENT
55
55
 
56
56
  Project-URL: Repository, https://github.com/rozsasarpi/tirex-mirror
57
57
  Project-URL: Issues, https://github.com/rozsasarpi/tirex-mirror/issues
58
- Keywords: TiRex,xLSTM,Time Series,Zero-shot,Deep Learning
58
+ Keywords: TiRex,xLSTM,Time Series,Zero-shot,Deep Learning,Classification,Timeseries-Classification
59
59
  Classifier: Programming Language :: Python :: 3
60
60
  Classifier: Operating System :: OS Independent
61
61
  Requires-Python: >=3.11
@@ -86,6 +86,7 @@ Requires-Dist: fev>=0.6.0; extra == "test"
86
86
  Requires-Dist: pytest; extra == "test"
87
87
  Provides-Extra: classification
88
88
  Requires-Dist: scikit-learn; extra == "classification"
89
+ Requires-Dist: lightgbm[scikit-learn]; extra == "classification"
89
90
  Provides-Extra: all
90
91
  Requires-Dist: xlstm; extra == "all"
91
92
  Requires-Dist: ninja; extra == "all"
@@ -98,6 +99,7 @@ Requires-Dist: datasets; extra == "all"
98
99
  Requires-Dist: pytest; extra == "all"
99
100
  Requires-Dist: fev>=0.6.0; extra == "all"
100
101
  Requires-Dist: scikit-learn; extra == "all"
102
+ Requires-Dist: lightgbm[scikit-learn]; extra == "all"
101
103
  Dynamic: license-file
102
104
 
103
105
  # tirex-mirror