tirex-mirror 2025.11.29__tar.gz → 2025.12.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/PKG-INFO +8 -3
  2. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/pyproject.toml +7 -6
  3. tirex_mirror-2025.12.16/src/tirex/models/base/base_classifier.py +36 -0
  4. tirex_mirror-2025.12.16/src/tirex/models/base/base_regressor.py +23 -0
  5. tirex_mirror-2025.12.16/src/tirex/models/base/base_tirex.py +109 -0
  6. tirex_mirror-2025.12.16/src/tirex/models/classification/__init__.py +8 -0
  7. tirex_mirror-2025.12.16/src/tirex/models/classification/gbm_classifier.py +188 -0
  8. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/classification/linear_classifier.py +25 -24
  9. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/classification/rf_classifier.py +18 -43
  10. {tirex_mirror-2025.11.29/src/tirex/models/classification → tirex_mirror-2025.12.16/src/tirex/models}/embedding.py +9 -7
  11. tirex_mirror-2025.12.16/src/tirex/models/regression/__init__.py +8 -0
  12. tirex_mirror-2025.12.16/src/tirex/models/regression/gbm_regressor.py +181 -0
  13. tirex_mirror-2025.12.16/src/tirex/models/regression/linear_regressor.py +250 -0
  14. tirex_mirror-2025.12.16/src/tirex/models/regression/rf_regressor.py +130 -0
  15. {tirex_mirror-2025.11.29/src/tirex/models/classification → tirex_mirror-2025.12.16/src/tirex/models}/trainer.py +32 -10
  16. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/util.py +82 -0
  17. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex_mirror.egg-info/PKG-INFO +8 -3
  18. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex_mirror.egg-info/SOURCES.txt +14 -3
  19. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex_mirror.egg-info/requires.txt +7 -1
  20. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_embedding.py +1 -3
  21. tirex_mirror-2025.12.16/tests/test_gbm_classifier.py +247 -0
  22. tirex_mirror-2025.12.16/tests/test_gbm_regressor.py +175 -0
  23. tirex_mirror-2025.12.16/tests/test_linear_classifier.py +290 -0
  24. tirex_mirror-2025.12.16/tests/test_linear_regressor.py +217 -0
  25. tirex_mirror-2025.12.16/tests/test_rf_classifier.py +228 -0
  26. tirex_mirror-2025.12.16/tests/test_rf_regressor.py +163 -0
  27. tirex_mirror-2025.11.29/src/tirex/models/classification/__init__.py +0 -8
  28. tirex_mirror-2025.11.29/src/tirex/models/classification/utils.py +0 -81
  29. tirex_mirror-2025.11.29/tests/test_linear_classifier.py +0 -170
  30. tirex_mirror-2025.11.29/tests/test_rf_classifier.py +0 -123
  31. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/LICENSE +0 -0
  32. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/LICENSE_MIRROR.txt +0 -0
  33. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/MANIFEST.in +0 -0
  34. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/NOTICE.txt +0 -0
  35. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/README.md +0 -0
  36. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/setup.cfg +0 -0
  37. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/__init__.py +0 -0
  38. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/api_adapter/__init__.py +0 -0
  39. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/api_adapter/forecast.py +0 -0
  40. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/api_adapter/gluon.py +0 -0
  41. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/api_adapter/hf_data.py +0 -0
  42. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/api_adapter/standard_adapter.py +0 -0
  43. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/base.py +0 -0
  44. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/__init__.py +0 -0
  45. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/patcher.py +0 -0
  46. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/slstm/block.py +0 -0
  47. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/slstm/cell.py +0 -0
  48. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/slstm/layer.py +0 -0
  49. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex/models/tirex.py +0 -0
  50. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  51. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  52. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_chronos_zs.py +0 -0
  53. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_compile.py +0 -0
  54. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_forecast.py +0 -0
  55. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_forecast_adapter.py +0 -0
  56. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_load_model.py +0 -0
  57. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_patcher.py +0 -0
  58. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_slstm_torch_vs_cuda.py +0 -0
  59. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_standard_adapter.py +0 -0
  60. {tirex_mirror-2025.11.29 → tirex_mirror-2025.12.16}/tests/test_util_freq.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tirex-mirror
3
- Version: 2025.11.29
3
+ Version: 2025.12.16
4
4
  Summary: Unofficial mirror of NX-AI/tirex for packaging
5
5
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
6
6
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -55,7 +55,7 @@ License: NXAI COMMUNITY LICENSE AGREEMENT
55
55
 
56
56
  Project-URL: Repository, https://github.com/rozsasarpi/tirex-mirror
57
57
  Project-URL: Issues, https://github.com/rozsasarpi/tirex-mirror/issues
58
- Keywords: TiRex,xLSTM,Time Series,Zero-shot,Deep Learning
58
+ Keywords: TiRex,xLSTM,Timeseries,Zero-shot,Deep Learning,Timeseries-Forecasting,Timeseries-Classification,Timeseries-Regression
59
59
  Classifier: Programming Language :: Python :: 3
60
60
  Classifier: Operating System :: OS Independent
61
61
  Requires-Python: >=3.11
@@ -66,6 +66,7 @@ License-File: NOTICE.txt
66
66
  Requires-Dist: torch
67
67
  Requires-Dist: huggingface-hub
68
68
  Requires-Dist: numpy
69
+ Requires-Dist: scikit-learn
69
70
  Provides-Extra: cuda
70
71
  Requires-Dist: xlstm; extra == "cuda"
71
72
  Requires-Dist: ninja; extra == "cuda"
@@ -84,8 +85,11 @@ Requires-Dist: datasets; extra == "hfdataset"
84
85
  Provides-Extra: test
85
86
  Requires-Dist: fev>=0.6.0; extra == "test"
86
87
  Requires-Dist: pytest; extra == "test"
88
+ Requires-Dist: aeon; extra == "test"
87
89
  Provides-Extra: classification
88
- Requires-Dist: scikit-learn; extra == "classification"
90
+ Requires-Dist: lightgbm[scikit-learn]; extra == "classification"
91
+ Provides-Extra: regression
92
+ Requires-Dist: lightgbm[scikit-learn]; extra == "regression"
89
93
  Provides-Extra: all
90
94
  Requires-Dist: xlstm; extra == "all"
91
95
  Requires-Dist: ninja; extra == "all"
@@ -98,6 +102,7 @@ Requires-Dist: datasets; extra == "all"
98
102
  Requires-Dist: pytest; extra == "all"
99
103
  Requires-Dist: fev>=0.6.0; extra == "all"
100
104
  Requires-Dist: scikit-learn; extra == "all"
105
+ Requires-Dist: lightgbm[scikit-learn]; extra == "all"
101
106
  Dynamic: license-file
102
107
 
103
108
  # tirex-mirror
@@ -1,12 +1,12 @@
1
1
  [project]
2
2
  name = "tirex-mirror"
3
- version = "2025.11.29"
3
+ version = "2025.12.16"
4
4
  description = "Unofficial mirror of NX-AI/tirex for packaging"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.11"
7
7
  classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent",]
8
- keywords = [ "TiRex", "xLSTM", "Time Series", "Zero-shot", "Deep Learning",]
9
- dependencies = [ "torch", "huggingface-hub", "numpy",]
8
+ keywords = [ "TiRex", "xLSTM", "Timeseries", "Zero-shot", "Deep Learning", "Timeseries-Forecasting", "Timeseries-Classification", "Timeseries-Regression",]
9
+ dependencies = [ "torch", "huggingface-hub", "numpy", "scikit-learn",]
10
10
  [[project.authors]]
11
11
  name = "Arpad Rozsas"
12
12
  email = "rozsasarpi@gmail.com"
@@ -28,9 +28,10 @@ notebooks = [ "ipykernel", "matplotlib", "pandas", "python-dotenv",]
28
28
  plotting = [ "matplotlib",]
29
29
  gluonts = [ "gluonts", "pandas",]
30
30
  hfdataset = [ "datasets",]
31
- test = [ "fev>=0.6.0", "pytest",]
32
- classification = [ "scikit-learn",]
33
- all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "pandas", "python-dotenv", "gluonts", "datasets", "pytest", "fev>=0.6.0", "scikit-learn",]
31
+ test = [ "fev>=0.6.0", "pytest", "aeon",]
32
+ classification = [ "lightgbm[scikit-learn]",]
33
+ regression = [ "lightgbm[scikit-learn]",]
34
+ all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "pandas", "python-dotenv", "gluonts", "datasets", "pytest", "fev>=0.6.0", "scikit-learn", "lightgbm[scikit-learn]",]
34
35
 
35
36
  [tool.docformatter]
36
37
  diff = false
@@ -0,0 +1,36 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ import torch
5
+
6
+ from .base_tirex import BaseTirexEmbeddingModel
7
+
8
+
9
+ class BaseTirexClassifier(BaseTirexEmbeddingModel):
10
+ """Abstract base class for TiRex classification models."""
11
+
12
+ @torch.inference_mode()
13
+ def predict(self, x: torch.Tensor) -> torch.Tensor:
14
+ """Predict class labels for input time series data.
15
+
16
+ Args:
17
+ x: Input time series data as torch.Tensor with shape
18
+ (batch_size, num_variates, seq_len).
19
+ Returns:
20
+ torch.Tensor: Predicted class labels with shape (batch_size,).
21
+ """
22
+ emb = self._compute_embeddings(x)
23
+ return torch.from_numpy(self.head.predict(emb)).long()
24
+
25
+ @torch.inference_mode()
26
+ def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
27
+ """Predict class probabilities for input time series data.
28
+
29
+ Args:
30
+ x: Input time series data as torch.Tensor with shape
31
+ (batch_size, num_variates, seq_len).
32
+ Returns:
33
+ torch.Tensor: Class probabilities with shape (batch_size, num_classes).
34
+ """
35
+ emb = self._compute_embeddings(x)
36
+ return torch.from_numpy(self.head.predict_proba(emb))
@@ -0,0 +1,23 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ import torch
5
+
6
+ from .base_tirex import BaseTirexEmbeddingModel
7
+
8
+
9
+ class BaseTirexRegressor(BaseTirexEmbeddingModel):
10
+ """Abstract base class for TiRex regression models."""
11
+
12
+ @torch.inference_mode()
13
+ def predict(self, x: torch.Tensor) -> torch.Tensor:
14
+ """Predict values for input time series data.
15
+
16
+ Args:
17
+ x: Input time series data as torch.Tensor with shape
18
+ (batch_size, num_variates, seq_len).
19
+ Returns:
20
+ torch.Tensor: Predicted values with shape (batch_size,).
21
+ """
22
+ emb = self._compute_embeddings(x)
23
+ return torch.from_numpy(self.head.predict(emb)).float()
@@ -0,0 +1,109 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ from abc import ABC, abstractmethod
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from tirex.util import train_val_split
10
+
11
+ from ..embedding import TiRexEmbedding
12
+
13
+
14
+ class BaseTirexEmbeddingModel(ABC):
15
+ """Abstract base class for TiRex models.
16
+
17
+ This base class provides common functionality for all TiRex classification and regression models,
18
+ including embedding model initialization and a consistent interface.
19
+
20
+ """
21
+
22
+ def __init__(
23
+ self, data_augmentation: bool = False, device: str | None = None, compile: bool = False, batch_size: int = 512
24
+ ) -> None:
25
+ """Initializes a base TiRex model.
26
+
27
+ This base class initializes the embedding model and common configuration
28
+ used by both classification and regression models.
29
+
30
+ Args:
31
+ data_augmentation : bool
32
+ Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
33
+ device : str | None
34
+ Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
35
+ compile: bool
36
+ Whether to compile the frozen embedding model. Default: False
37
+ batch_size : int
38
+ Batch size for embedding calculations. Default: 512
39
+ """
40
+
41
+ # Set device
42
+ if device is None:
43
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
44
+ self.device = device
45
+ self._compile = compile
46
+
47
+ self.batch_size = batch_size
48
+ self.data_augmentation = data_augmentation
49
+ self.emb_model = TiRexEmbedding(
50
+ device=self.device,
51
+ data_augmentation=self.data_augmentation,
52
+ batch_size=self.batch_size,
53
+ compile=self._compile,
54
+ )
55
+
56
+ @abstractmethod
57
+ def fit(self, train_data: tuple[torch.Tensor, torch.Tensor]) -> None:
58
+ """Abstract method for model training"""
59
+ pass
60
+
61
+ def _compute_embeddings(self, x: torch.Tensor) -> np.ndarray:
62
+ """Compute embeddings for input time series data.
63
+
64
+ Args:
65
+ x: Input time series data as torch.Tensor with shape
66
+ (batch_size, num_variates, seq_len).
67
+
68
+ Returns:
69
+ np.ndarray: Embeddings with shape (batch_size, embedding_dim).
70
+ """
71
+ self.emb_model.eval()
72
+ x = x.to(self.device)
73
+ return self.emb_model(x).cpu().numpy()
74
+
75
+ def _create_train_val_datasets(
76
+ self,
77
+ train_data: tuple[torch.Tensor, torch.Tensor],
78
+ val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
79
+ val_split_ratio: float = 0.2,
80
+ stratify: bool = False,
81
+ seed: int | None = None,
82
+ ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
83
+ if val_data is None:
84
+ train_data, val_data = train_val_split(
85
+ train_data=train_data, val_split_ratio=val_split_ratio, stratify=stratify, seed=seed
86
+ )
87
+ return train_data, val_data
88
+
89
+ @abstractmethod
90
+ def save_model(self, path: str) -> None:
91
+ """Save model to file.
92
+
93
+ Args:
94
+ path: File path where the model should be saved.
95
+ """
96
+ pass
97
+
98
+ @classmethod
99
+ @abstractmethod
100
+ def load_model(cls, path: str):
101
+ """Load model from file.
102
+
103
+ Args:
104
+ path: File path to the saved model checkpoint.
105
+
106
+ Returns:
107
+ Instance of the model class with loaded weights and configuration.
108
+ """
109
+ pass
@@ -0,0 +1,8 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ from .gbm_classifier import TirexGBMClassifier
5
+ from .linear_classifier import TirexLinearClassifier
6
+ from .rf_classifier import TirexRFClassifier
7
+
8
+ __all__ = ["TirexLinearClassifier", "TirexRFClassifier", "TirexGBMClassifier"]
@@ -0,0 +1,188 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
4
+ import joblib
5
+ import torch
6
+ from lightgbm import LGBMClassifier, early_stopping
7
+
8
+ from ..base.base_classifier import BaseTirexClassifier
9
+
10
+
11
+ class TirexGBMClassifier(BaseTirexClassifier):
12
+ """
13
+ A Gradient Boosting classifier that uses time series embeddings as features.
14
+
15
+ This classifier combines a pre-trained embedding model for feature extraction with a
16
+ Gradient Boosting classifier.
17
+
18
+ Example:
19
+ >>> from tirex.models.classification import TirexGBMClassifier
20
+ >>>
21
+ >>> # Create model with custom LightGBM parameters
22
+ >>> model = TirexGBMClassifier(
23
+ ... data_augmentation=True,
24
+ ... n_estimators=50,
25
+ ... random_state=42
26
+ ... )
27
+ >>>
28
+ >>> # Prepare data (can use NumPy arrays or PyTorch tensors)
29
+ >>> X_train = torch.randn(100, 1, 128) # 100 samples, 1 variate, 128 sequence length
30
+ >>> y_train = torch.randint(0, 3, (100,)) # 3 classes
31
+ >>>
32
+ >>> # Train the model
33
+ >>> model.fit((X_train, y_train))
34
+ >>>
35
+ >>> # Make predictions
36
+ >>> X_test = torch.randn(20, 1, 128)
37
+ >>> predictions = model.predict(X_test)
38
+ >>> probabilities = model.predict_proba(X_test)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ data_augmentation: bool = False,
44
+ device: str | None = None,
45
+ compile: bool = False,
46
+ batch_size: int = 512,
47
+ early_stopping_rounds: int | None = 10,
48
+ min_delta: float = 0.0,
49
+ val_split_ratio: float = 0.2,
50
+ stratify: bool = True,
51
+ # LightGBM kwargs
52
+ **lgbm_kwargs,
53
+ ) -> None:
54
+ """Initializes Embedding Based Gradient Boosting Classification model.
55
+
56
+ Args:
57
+ data_augmentation : bool
58
+ Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
59
+ device : str | None
60
+ Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
61
+ compile: bool
62
+ Whether to compile the frozen embedding model. Default: False
63
+ batch_size : int
64
+ Batch size for embedding calculations. Default: 512
65
+ early_stopping_rounds: int | None
66
+ Number of rounds without improvement of all metrics for Early Stopping. Default: 10
67
+ min_delta: float
68
+ Minimum improvement in score to keep training. Default: 0.0
69
+ val_split_ratio : float
70
+ Proportion of training data to use for validation, if validation data are not provided. Default: 0.2
71
+ stratify : bool
72
+ Whether to stratify the train/validation split by class labels. Default: True
73
+ **lgbm_kwargs
74
+ Additional keyword arguments to pass to LightGBM's LGBMClassifier.
75
+ Common options include n_estimators, max_depth, learning_rate, random_state, etc.
76
+ """
77
+ super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
78
+
79
+ # Early Stopping callback
80
+ self.early_stopping_rounds = early_stopping_rounds
81
+ self.min_delta = min_delta
82
+
83
+ # Data split parameters:
84
+ self.val_split_ratio = val_split_ratio
85
+ self.stratify = stratify
86
+
87
+ # Extract random_state for train_val_split if provided
88
+ self.random_state = lgbm_kwargs.get("random_state", None)
89
+
90
+ self.head = LGBMClassifier(**lgbm_kwargs)
91
+
92
+ @torch.inference_mode()
93
+ def fit(
94
+ self,
95
+ train_data: tuple[torch.Tensor, torch.Tensor],
96
+ val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
97
+ ) -> None:
98
+ """Train the LightGBM classifier on embedded time series data.
99
+
100
+ This method generates embeddings for the training data using the embedding
101
+ model, then trains the LightGBM classifier on these embeddings.
102
+
103
+ Args:
104
+ train_data: Tuple of (X_train, y_train) where X_train is the input time
105
+ series data (torch.Tensor) and y_train is a torch.Tensor
106
+ of class labels.
107
+ val_data: Optional tuple of (X_val, y_val) for validation where X_val is the input time
108
+ series data (torch.Tensor) and y_val is a torch.Tensor
109
+ of class labels. If None, validation data will be automatically split from train_data (20% split).
110
+ """
111
+
112
+ (X_train, y_train), (X_val, y_val) = self._create_train_val_datasets(
113
+ train_data=train_data,
114
+ val_data=val_data,
115
+ val_split_ratio=self.val_split_ratio,
116
+ stratify=self.stratify,
117
+ seed=self.random_state,
118
+ )
119
+
120
+ X_train = X_train.to(self.device)
121
+ X_val = X_val.to(self.device)
122
+
123
+ embeddings_train = self._compute_embeddings(X_train)
124
+ embeddings_val = self._compute_embeddings(X_val)
125
+
126
+ y_train = y_train.detach().cpu().numpy() if isinstance(y_train, torch.Tensor) else y_train
127
+ y_val = y_val.detach().cpu().numpy() if isinstance(y_val, torch.Tensor) else y_val
128
+
129
+ self.head.fit(
130
+ embeddings_train,
131
+ y_train,
132
+ eval_set=[(embeddings_val, y_val)],
133
+ callbacks=[early_stopping(stopping_rounds=self.early_stopping_rounds, min_delta=self.min_delta)]
134
+ if self.early_stopping_rounds is not None
135
+ else None,
136
+ )
137
+
138
+ def save_model(self, path: str) -> None:
139
+ """This method saves the trained LightGBM classifier head (joblib format) and embedding information.
140
+
141
+ Args:
142
+ path: File path where the model should be saved (e.g., 'model.joblib').
143
+ """
144
+
145
+ payload = {
146
+ "data_augmentation": self.data_augmentation,
147
+ "compile": self._compile,
148
+ "batch_size": self.batch_size,
149
+ "early_stopping_rounds": self.early_stopping_rounds,
150
+ "min_delta": self.min_delta,
151
+ "val_split_ratio": self.val_split_ratio,
152
+ "stratify": self.stratify,
153
+ "head": self.head,
154
+ }
155
+ joblib.dump(payload, path)
156
+
157
+ @classmethod
158
+ def load_model(cls, path: str) -> "TirexGBMClassifier":
159
+ """Load a saved model from file.
160
+
161
+ This reconstructs the model with the embedding configuration and loads
162
+ the trained LightGBM classifier from a checkpoint file created by save_model().
163
+
164
+ Args:
165
+ path: File path to the saved model checkpoint.
166
+ Returns:
167
+ TirexGBMClassifier: The loaded model with trained Gradient Boosting head, ready for inference.
168
+ """
169
+ checkpoint = joblib.load(path)
170
+
171
+ # Create new instance with saved configuration
172
+ model = cls(
173
+ data_augmentation=checkpoint["data_augmentation"],
174
+ compile=checkpoint["compile"],
175
+ batch_size=checkpoint["batch_size"],
176
+ early_stopping_rounds=checkpoint["early_stopping_rounds"],
177
+ min_delta=checkpoint["min_delta"],
178
+ val_split_ratio=checkpoint["val_split_ratio"],
179
+ stratify=checkpoint["stratify"],
180
+ )
181
+
182
+ # Load the trained LightGBM head
183
+ model.head = checkpoint["head"]
184
+
185
+ # Extract random_state from the loaded head if available
186
+ model.random_state = getattr(model.head, "random_state", None)
187
+
188
+ return model
@@ -1,12 +1,15 @@
1
+ # Copyright (c) NXAI GmbH.
2
+ # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
3
+
1
4
  from dataclasses import asdict
2
5
 
3
6
  import torch
4
7
 
5
- from .embedding import TiRexEmbedding
6
- from .trainer import TrainConfig, Trainer
8
+ from ..base.base_classifier import BaseTirexClassifier
9
+ from ..trainer import TrainConfig, Trainer, TrainingMetrics
7
10
 
8
11
 
9
- class TirexClassifierTorch(torch.nn.Module):
12
+ class TirexLinearClassifier(BaseTirexClassifier, torch.nn.Module):
10
13
  """
11
14
  A PyTorch classifier that combines time series embeddings with a linear classification head.
12
15
 
@@ -16,10 +19,10 @@ class TirexClassifierTorch(torch.nn.Module):
16
19
 
17
20
  Example:
18
21
  >>> import torch
19
- >>> from tirex.models.classification import TirexClassifierTorch
22
+ >>> from tirex.models.classification import TirexLinearClassifier
20
23
  >>>
21
- >>> # Create model with TIREX embeddings
22
- >>> model = TirexClassifierTorch(
24
+ >>> # Create model with TiRex embeddings
25
+ >>> model = TirexLinearClassifier(
23
26
  ... data_augmentation=True,
24
27
  ... max_epochs=2,
25
28
  ... lr=1e-4,
@@ -43,8 +46,9 @@ class TirexClassifierTorch(torch.nn.Module):
43
46
  self,
44
47
  data_augmentation: bool = False,
45
48
  device: str | None = None,
49
+ compile: bool = False,
46
50
  # Training parameters
47
- max_epochs: int = 50,
51
+ max_epochs: int = 10,
48
52
  lr: float = 1e-4,
49
53
  weight_decay: float = 0.01,
50
54
  batch_size: int = 512,
@@ -62,11 +66,13 @@ class TirexClassifierTorch(torch.nn.Module):
62
66
 
63
67
  Args:
64
68
  data_augmentation : bool | None
65
- Whether to use data_augmentation for embeddings (stats and first-order differences of the original data). Default: False
69
+ Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
66
70
  device : str | None
67
- Device to run the model on. If None, uses CUDA if available, else CPU. Default: None
71
+ Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
72
+ compile: bool
73
+ Whether to compile the frozen embedding model. Default: False
68
74
  max_epochs : int
69
- Maximum number of training epochs. Default: 50
75
+ Maximum number of training epochs. Default: 10
70
76
  lr : float
71
77
  Learning rate for the optimizer. Default: 1e-4
72
78
  weight_decay : float
@@ -91,15 +97,9 @@ class TirexClassifierTorch(torch.nn.Module):
91
97
  Dropout probability for the classification head. If None, no dropout is used. Default: None
92
98
  """
93
99
 
94
- super().__init__()
95
-
96
- if device is None:
97
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
98
- self.device = device
100
+ torch.nn.Module.__init__(self)
99
101
 
100
- # Create embedding model
101
- self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
102
- self.data_augmentation = data_augmentation
102
+ super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
103
103
 
104
104
  # Head parameters
105
105
  self.dropout = dropout
@@ -115,6 +115,7 @@ class TirexClassifierTorch(torch.nn.Module):
115
115
  lr=lr,
116
116
  weight_decay=weight_decay,
117
117
  class_weights=class_weights,
118
+ task_type="classification",
118
119
  batch_size=batch_size,
119
120
  val_split_ratio=val_split_ratio,
120
121
  stratify=stratify,
@@ -132,9 +133,7 @@ class TirexClassifierTorch(torch.nn.Module):
132
133
 
133
134
  @torch.inference_mode()
134
135
  def _identify_head_dims(self, x: torch.Tensor, y: torch.Tensor) -> None:
135
- self.emb_model.eval()
136
- sample_emb = self.emb_model(x[:1])
137
- self.emb_dim = sample_emb.shape[-1]
136
+ self.emb_dim = self._compute_embeddings(x[:1]).shape[-1]
138
137
  self.num_classes = len(torch.unique(y))
139
138
 
140
139
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -155,7 +154,7 @@ class TirexClassifierTorch(torch.nn.Module):
155
154
 
156
155
  def fit(
157
156
  self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None = None
158
- ) -> dict[str, float]:
157
+ ) -> TrainingMetrics:
159
158
  """Train the classification head on the provided data.
160
159
 
161
160
  This method initializes the classification head based on the data dimensions,
@@ -221,6 +220,7 @@ class TirexClassifierTorch(torch.nn.Module):
221
220
  {
222
221
  "head_state_dict": self.head.state_dict(), # need to save only head, embedding is frozen
223
222
  "data_augmentation": self.data_augmentation,
223
+ "compile": self._compile,
224
224
  "emb_dim": self.emb_dim,
225
225
  "num_classes": self.num_classes,
226
226
  "dropout": self.dropout,
@@ -230,7 +230,7 @@ class TirexClassifierTorch(torch.nn.Module):
230
230
  )
231
231
 
232
232
  @classmethod
233
- def load_model(cls, path: str) -> "TirexClassifierTorch":
233
+ def load_model(cls, path: str) -> "TirexLinearClassifier":
234
234
  """Load a saved model from file.
235
235
 
236
236
  This reconstructs the model architecture and loads the trained weights from
@@ -239,7 +239,7 @@ class TirexClassifierTorch(torch.nn.Module):
239
239
  Args:
240
240
  path: File path to the saved model checkpoint.
241
241
  Returns:
242
- TirexClassifierTorch: The loaded model with trained weights, ready for inference.
242
+ TirexLinearClassifier: The loaded model with trained weights, ready for inference.
243
243
  """
244
244
  checkpoint = torch.load(path)
245
245
 
@@ -248,6 +248,7 @@ class TirexClassifierTorch(torch.nn.Module):
248
248
 
249
249
  model = cls(
250
250
  data_augmentation=checkpoint["data_augmentation"],
251
+ compile=checkpoint["compile"],
251
252
  dropout=checkpoint["dropout"],
252
253
  max_epochs=train_config_dict.get("max_epochs", 50),
253
254
  lr=train_config_dict.get("lr", 1e-4),