tirex-mirror 2025.11.28__py3-none-any.whl → 2025.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tirex/models/classification/__init__.py +7 -7
- tirex/models/classification/embedding.py +7 -4
- tirex/models/classification/heads/base_classifier.py +93 -0
- tirex/models/classification/heads/gbm_classifier.py +210 -0
- tirex/models/classification/{linear_classifier.py → heads/linear_classifier.py} +17 -15
- tirex/models/classification/{rf_classifier.py → heads/rf_classifier.py} +17 -42
- tirex/models/classification/trainer.py +3 -0
- tirex/models/classification/utils.py +3 -0
- {tirex_mirror-2025.11.28.dist-info → tirex_mirror-2025.12.2.dist-info}/METADATA +4 -2
- tirex_mirror-2025.12.2.dist-info/RECORD +29 -0
- tirex_mirror-2025.11.28.dist-info/RECORD +0 -27
- {tirex_mirror-2025.11.28.dist-info → tirex_mirror-2025.12.2.dist-info}/WHEEL +0 -0
- {tirex_mirror-2025.11.28.dist-info → tirex_mirror-2025.12.2.dist-info}/licenses/LICENSE +0 -0
- {tirex_mirror-2025.11.28.dist-info → tirex_mirror-2025.12.2.dist-info}/licenses/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.11.28.dist-info → tirex_mirror-2025.12.2.dist-info}/licenses/NOTICE.txt +0 -0
- {tirex_mirror-2025.11.28.dist-info → tirex_mirror-2025.12.2.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
#
|
|
2
|
-
|
|
3
|
-
from .rf_classifier import TirexRFClassifier
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
4
3
|
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
4
|
+
from .heads.gbm_classifier import TirexGBMClassifier
|
|
5
|
+
from .heads.linear_classifier import TirexClassifierTorch
|
|
6
|
+
from .heads.rf_classifier import TirexRFClassifier
|
|
7
|
+
|
|
8
|
+
__all__ = ["TirexClassifierTorch", "TirexRFClassifier", "TirexGBMClassifier"]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
2
3
|
|
|
3
|
-
import numpy as np
|
|
4
4
|
import torch
|
|
5
5
|
import torch.nn as nn
|
|
6
6
|
import torch.nn.functional as F
|
|
@@ -11,7 +11,9 @@ from .utils import nanmax, nanmin, nanstd
|
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
class TiRexEmbedding(nn.Module):
|
|
14
|
-
def __init__(
|
|
14
|
+
def __init__(
|
|
15
|
+
self, device: str | None = None, data_augmentation: bool = False, batch_size: int = 512, compile: bool = False
|
|
16
|
+
) -> None:
|
|
15
17
|
super().__init__()
|
|
16
18
|
self.data_augmentation = data_augmentation
|
|
17
19
|
self.number_of_patches = 8
|
|
@@ -20,8 +22,9 @@ class TiRexEmbedding(nn.Module):
|
|
|
20
22
|
if device is None:
|
|
21
23
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
22
24
|
self.device = device
|
|
25
|
+
self._compile = compile
|
|
23
26
|
|
|
24
|
-
self.model = load_model(path="NX-AI/TiRex", device=self.device)
|
|
27
|
+
self.model = load_model(path="NX-AI/TiRex", device=self.device, compile=self._compile)
|
|
25
28
|
|
|
26
29
|
def _gen_emb_batched(self, data: torch.Tensor) -> torch.Tensor:
|
|
27
30
|
batches = list(torch.split(data, self.batch_size))
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from ..embedding import TiRexEmbedding
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseTirexClassifier(ABC):
|
|
12
|
+
"""Abstract base class for TiRex classification models.
|
|
13
|
+
|
|
14
|
+
This base class provides common functionality for all TiRex classifiers,
|
|
15
|
+
including embedding model initialization and a consistent interface.
|
|
16
|
+
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self, data_augmentation: bool = False, device: str | None = None, compile: bool = False, batch_size: int = 512
|
|
21
|
+
) -> None:
|
|
22
|
+
"""Initializes a TiRex classification model.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
data_augmentation : bool
|
|
26
|
+
Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
|
|
27
|
+
device : str | None
|
|
28
|
+
Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
|
|
29
|
+
compile: bool
|
|
30
|
+
Whether to compile the frozen embedding model. Default: False
|
|
31
|
+
batch_size : int
|
|
32
|
+
Batch size for embedding calculations. Default: 512
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
# Set device
|
|
36
|
+
if device is None:
|
|
37
|
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
38
|
+
self.device = device
|
|
39
|
+
self._compile = compile
|
|
40
|
+
|
|
41
|
+
self.batch_size = batch_size
|
|
42
|
+
self.data_augmentation = data_augmentation
|
|
43
|
+
self.emb_model = TiRexEmbedding(
|
|
44
|
+
device=self.device,
|
|
45
|
+
data_augmentation=self.data_augmentation,
|
|
46
|
+
batch_size=self.batch_size,
|
|
47
|
+
compile=self._compile,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
@abstractmethod
|
|
51
|
+
def fit(self, train_data: tuple[torch.Tensor, torch.Tensor]) -> None:
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
@torch.inference_mode()
|
|
55
|
+
def predict(self, x: torch.Tensor) -> torch.Tensor:
|
|
56
|
+
"""Predict class labels for input time series data.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
x: Input time series data as torch.Tensor with shape
|
|
60
|
+
(batch_size, num_variates, seq_len).
|
|
61
|
+
Returns:
|
|
62
|
+
torch.Tensor: Predicted class labels with shape (batch_size,).
|
|
63
|
+
"""
|
|
64
|
+
self.emb_model.eval()
|
|
65
|
+
x = x.to(self.device)
|
|
66
|
+
embeddings = self.emb_model(x).cpu().numpy()
|
|
67
|
+
return torch.from_numpy(self.head.predict(embeddings)).long()
|
|
68
|
+
|
|
69
|
+
@torch.inference_mode()
|
|
70
|
+
def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
|
|
71
|
+
"""Predict class probabilities for input time series data.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
x: Input time series data as torch.Tensor with shape
|
|
75
|
+
(batch_size, num_variates, seq_len).
|
|
76
|
+
Returns:
|
|
77
|
+
torch.Tensor: Class probabilities with shape (batch_size, num_classes).
|
|
78
|
+
"""
|
|
79
|
+
self.emb_model.eval()
|
|
80
|
+
x = x.to(self.device)
|
|
81
|
+
embeddings = self.emb_model(x).cpu().numpy()
|
|
82
|
+
return torch.from_numpy(self.head.predict_proba(embeddings))
|
|
83
|
+
|
|
84
|
+
@abstractmethod
|
|
85
|
+
def save_model(self, path: str) -> None:
|
|
86
|
+
"""Saving model abstract method"""
|
|
87
|
+
pass
|
|
88
|
+
|
|
89
|
+
@classmethod
|
|
90
|
+
@abstractmethod
|
|
91
|
+
def load_model(cls, path: str):
|
|
92
|
+
"""Loading model abstract method"""
|
|
93
|
+
pass
|
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
+
|
|
4
|
+
import joblib
|
|
5
|
+
import torch
|
|
6
|
+
from lightgbm import LGBMClassifier, early_stopping
|
|
7
|
+
|
|
8
|
+
from ..utils import train_val_split
|
|
9
|
+
from .base_classifier import BaseTirexClassifier
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TirexGBMClassifier(BaseTirexClassifier):
|
|
13
|
+
"""
|
|
14
|
+
A Gradient Boosting classifier that uses time series embeddings as features.
|
|
15
|
+
|
|
16
|
+
This classifier combines a pre-trained embedding model for feature extraction with a
|
|
17
|
+
Gradient Boosting classifier.
|
|
18
|
+
|
|
19
|
+
Example:
|
|
20
|
+
>>> from tirex.models.classification import TirexGBMClassifier
|
|
21
|
+
>>>
|
|
22
|
+
>>> # Create model with custom LightGBM parameters
|
|
23
|
+
>>> model = TirexGBMClassifier(
|
|
24
|
+
... data_augmentation=True,
|
|
25
|
+
... n_estimators=50,
|
|
26
|
+
... random_state=42
|
|
27
|
+
... )
|
|
28
|
+
>>>
|
|
29
|
+
>>> # Prepare data (can use NumPy arrays or PyTorch tensors)
|
|
30
|
+
>>> X_train = torch.randn(100, 1, 128) # 100 samples, 1 number of variates, 128 sequence length
|
|
31
|
+
>>> y_train = torch.randint(0, 3, (100,)) # 3 classes
|
|
32
|
+
>>>
|
|
33
|
+
>>> # Train the model
|
|
34
|
+
>>> model.fit((X_train, y_train))
|
|
35
|
+
>>>
|
|
36
|
+
>>> # Make predictions
|
|
37
|
+
>>> X_test = torch.randn(20, 1, 128)
|
|
38
|
+
>>> predictions = model.predict(X_test)
|
|
39
|
+
>>> probabilities = model.predict_proba(X_test)
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
def __init__(
|
|
43
|
+
self,
|
|
44
|
+
data_augmentation: bool = False,
|
|
45
|
+
device: str | None = None,
|
|
46
|
+
compile: bool = False,
|
|
47
|
+
batch_size: int = 512,
|
|
48
|
+
early_stopping_rounds: int | None = 10,
|
|
49
|
+
min_delta: float = 0.0,
|
|
50
|
+
val_split_ratio: float = 0.2,
|
|
51
|
+
stratify: bool = True,
|
|
52
|
+
# LightGBM kwargs
|
|
53
|
+
**lgbm_kwargs,
|
|
54
|
+
) -> None:
|
|
55
|
+
"""Initializes Embedding Based Gradient Boosting Classification model.
|
|
56
|
+
|
|
57
|
+
Args:
|
|
58
|
+
data_augmentation : bool
|
|
59
|
+
Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
|
|
60
|
+
device : str | None
|
|
61
|
+
Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
|
|
62
|
+
compile: bool
|
|
63
|
+
Whether to compile the frozen embedding model. Default: False
|
|
64
|
+
batch_size : int
|
|
65
|
+
Batch size for embedding calculations. Default: 512
|
|
66
|
+
early_stopping_rounds: int | None
|
|
67
|
+
Number of rounds without improvement of all metrics for Early Stopping. Default: 10
|
|
68
|
+
min_delta: float
|
|
69
|
+
Minimum improvement in score to keep training. Default 0.0
|
|
70
|
+
val_split_ratio : float
|
|
71
|
+
Proportion of training data to use for validation, if validation data are not provided. Default: 0.2
|
|
72
|
+
stratify : bool
|
|
73
|
+
Whether to stratify the train/validation split by class labels. Default: True
|
|
74
|
+
**lgbm_kwargs
|
|
75
|
+
Additional keyword arguments to pass to LightGBM's LGBMClassifier.
|
|
76
|
+
Common options include n_estimators, max_depth, learning_rate, random_state, etc.
|
|
77
|
+
"""
|
|
78
|
+
super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
|
|
79
|
+
|
|
80
|
+
# Early Stopping callback
|
|
81
|
+
self.early_stopping_rounds = early_stopping_rounds
|
|
82
|
+
self.min_delta = min_delta
|
|
83
|
+
|
|
84
|
+
# Data split parameters:
|
|
85
|
+
self.val_split_ratio = val_split_ratio
|
|
86
|
+
self.stratify = stratify
|
|
87
|
+
|
|
88
|
+
# Extract random_state for train_val_split if provided
|
|
89
|
+
self.random_state = lgbm_kwargs.get("random_state", None)
|
|
90
|
+
|
|
91
|
+
self.head = LGBMClassifier(**lgbm_kwargs)
|
|
92
|
+
|
|
93
|
+
@torch.inference_mode()
|
|
94
|
+
def fit(
|
|
95
|
+
self,
|
|
96
|
+
train_data: tuple[torch.Tensor, torch.Tensor],
|
|
97
|
+
val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
|
|
98
|
+
) -> None:
|
|
99
|
+
"""Train the LightGBM classifier on embedded time series data.
|
|
100
|
+
|
|
101
|
+
This method generates embeddings for the training data using the embedding
|
|
102
|
+
model, then trains the LightGBM classifier on these embeddings.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
train_data: Tuple of (X_train, y_train) where X_train is the input time
|
|
106
|
+
series data (torch.Tensor) and y_train is a torch.Tensor
|
|
107
|
+
of class labels.
|
|
108
|
+
val_data: Optional tuple of (X_val, y_val) for validation where X_train is the input time
|
|
109
|
+
series data (torch.Tensor) and y_train is a torch.Tensor
|
|
110
|
+
of class labels. If None, validation data will be automatically split from train_data (20% split).
|
|
111
|
+
"""
|
|
112
|
+
|
|
113
|
+
(X_train, y_train), (X_val, y_val) = self._create_train_val_datasets(
|
|
114
|
+
train_data=train_data,
|
|
115
|
+
val_data=val_data,
|
|
116
|
+
val_split_ratio=self.val_split_ratio,
|
|
117
|
+
stratify=self.stratify,
|
|
118
|
+
seed=self.random_state,
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
self.emb_model.eval()
|
|
122
|
+
X_train = X_train.to(self.device)
|
|
123
|
+
X_val = X_val.to(self.device)
|
|
124
|
+
embeddings_train = self.emb_model(X_train).cpu().numpy()
|
|
125
|
+
embeddings_val = self.emb_model(X_val).cpu().numpy()
|
|
126
|
+
|
|
127
|
+
y_train = y_train.detach().cpu().numpy() if isinstance(y_train, torch.Tensor) else y_train
|
|
128
|
+
y_val = y_val.detach().cpu().numpy() if isinstance(y_val, torch.Tensor) else y_val
|
|
129
|
+
|
|
130
|
+
self.head.fit(
|
|
131
|
+
embeddings_train,
|
|
132
|
+
y_train,
|
|
133
|
+
eval_set=[(embeddings_val, y_val)],
|
|
134
|
+
callbacks=[early_stopping(stopping_rounds=self.early_stopping_rounds, min_delta=self.min_delta)]
|
|
135
|
+
if self.early_stopping_rounds is not None
|
|
136
|
+
else None,
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
def save_model(self, path: str) -> None:
|
|
140
|
+
"""This method saves the trained LightGBM classifier head (joblib format) and embedding information.
|
|
141
|
+
|
|
142
|
+
Args:
|
|
143
|
+
path: File path where the model should be saved (e.g., 'model.joblib').
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
payload = {
|
|
147
|
+
"data_augmentation": self.data_augmentation,
|
|
148
|
+
"compile": self._compile,
|
|
149
|
+
"batch_size": self.batch_size,
|
|
150
|
+
"early_stopping_rounds": self.early_stopping_rounds,
|
|
151
|
+
"min_delta": self.min_delta,
|
|
152
|
+
"val_split_ratio": self.val_split_ratio,
|
|
153
|
+
"stratify": self.stratify,
|
|
154
|
+
"head": self.head,
|
|
155
|
+
}
|
|
156
|
+
joblib.dump(payload, path)
|
|
157
|
+
|
|
158
|
+
@classmethod
|
|
159
|
+
def load_model(cls, path: str) -> "TirexGBMClassifier":
|
|
160
|
+
"""Load a saved model from file.
|
|
161
|
+
|
|
162
|
+
This reconstructs the model with the embedding configuration and loads
|
|
163
|
+
the trained LightGBM classifier from a checkpoint file created by save_model().
|
|
164
|
+
|
|
165
|
+
Args:
|
|
166
|
+
path: File path to the saved model checkpoint.
|
|
167
|
+
Returns:
|
|
168
|
+
TirexGBMClassifier: The loaded model with trained Gradient Boosting, ready for inference.
|
|
169
|
+
"""
|
|
170
|
+
checkpoint = joblib.load(path)
|
|
171
|
+
|
|
172
|
+
# Create new instance with saved configuration
|
|
173
|
+
model = cls(
|
|
174
|
+
data_augmentation=checkpoint["data_augmentation"],
|
|
175
|
+
compile=checkpoint["compile"],
|
|
176
|
+
batch_size=checkpoint["batch_size"],
|
|
177
|
+
early_stopping_rounds=checkpoint["early_stopping_rounds"],
|
|
178
|
+
min_delta=checkpoint["min_delta"],
|
|
179
|
+
val_split_ratio=checkpoint["val_split_ratio"],
|
|
180
|
+
stratify=checkpoint["stratify"],
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
# Load the trained LightGBM head
|
|
184
|
+
model.head = checkpoint["head"]
|
|
185
|
+
|
|
186
|
+
# Extract random_state from the loaded head if available
|
|
187
|
+
model.random_state = getattr(model.head, "random_state", None)
|
|
188
|
+
|
|
189
|
+
return model
|
|
190
|
+
|
|
191
|
+
def _create_train_val_datasets(
|
|
192
|
+
self,
|
|
193
|
+
train_data: tuple[torch.Tensor, torch.Tensor],
|
|
194
|
+
val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
|
|
195
|
+
val_split_ratio: float = 0.2,
|
|
196
|
+
stratify: bool = True,
|
|
197
|
+
seed: int | None = None,
|
|
198
|
+
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
|
199
|
+
X_train, y_train = train_data
|
|
200
|
+
|
|
201
|
+
if val_data is None:
|
|
202
|
+
train_data, val_data = train_val_split(
|
|
203
|
+
train_data=train_data, val_split_ratio=val_split_ratio, stratify=stratify, seed=seed
|
|
204
|
+
)
|
|
205
|
+
X_train, y_train = train_data
|
|
206
|
+
X_val, y_val = val_data
|
|
207
|
+
else:
|
|
208
|
+
X_val, y_val = val_data
|
|
209
|
+
|
|
210
|
+
return (X_train, y_train), (X_val, y_val)
|
|
@@ -1,12 +1,15 @@
|
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
+
|
|
1
4
|
from dataclasses import asdict
|
|
2
5
|
|
|
3
6
|
import torch
|
|
4
7
|
|
|
5
|
-
from
|
|
6
|
-
from .
|
|
8
|
+
from ..trainer import TrainConfig, Trainer
|
|
9
|
+
from .base_classifier import BaseTirexClassifier
|
|
7
10
|
|
|
8
11
|
|
|
9
|
-
class TirexClassifierTorch(torch.nn.Module):
|
|
12
|
+
class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
|
|
10
13
|
"""
|
|
11
14
|
A PyTorch classifier that combines time series embeddings with a linear classification head.
|
|
12
15
|
|
|
@@ -18,7 +21,7 @@ class TirexClassifierTorch(torch.nn.Module):
|
|
|
18
21
|
>>> import torch
|
|
19
22
|
>>> from tirex.models.classification import TirexClassifierTorch
|
|
20
23
|
>>>
|
|
21
|
-
>>> # Create model with
|
|
24
|
+
>>> # Create model with TiRex embeddings
|
|
22
25
|
>>> model = TirexClassifierTorch(
|
|
23
26
|
... data_augmentation=True,
|
|
24
27
|
... max_epochs=2,
|
|
@@ -43,8 +46,9 @@ class TirexClassifierTorch(torch.nn.Module):
|
|
|
43
46
|
self,
|
|
44
47
|
data_augmentation: bool = False,
|
|
45
48
|
device: str | None = None,
|
|
49
|
+
compile: bool = False,
|
|
46
50
|
# Training parameters
|
|
47
|
-
max_epochs: int =
|
|
51
|
+
max_epochs: int = 10,
|
|
48
52
|
lr: float = 1e-4,
|
|
49
53
|
weight_decay: float = 0.01,
|
|
50
54
|
batch_size: int = 512,
|
|
@@ -62,9 +66,11 @@ class TirexClassifierTorch(torch.nn.Module):
|
|
|
62
66
|
|
|
63
67
|
Args:
|
|
64
68
|
data_augmentation : bool | None
|
|
65
|
-
Whether to use data_augmentation for embeddings (
|
|
69
|
+
Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
|
|
66
70
|
device : str | None
|
|
67
|
-
Device to run the model on. If None, uses CUDA if available, else CPU. Default: None
|
|
71
|
+
Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
|
|
72
|
+
compile: bool
|
|
73
|
+
Whether to compile the frozen embedding model. Default: False
|
|
68
74
|
max_epochs : int
|
|
69
75
|
Maximum number of training epochs. Default: 50
|
|
70
76
|
lr : float
|
|
@@ -91,15 +97,9 @@ class TirexClassifierTorch(torch.nn.Module):
|
|
|
91
97
|
Dropout probability for the classification head. If None, no dropout is used. Default: None
|
|
92
98
|
"""
|
|
93
99
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
if device is None:
|
|
97
|
-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
98
|
-
self.device = device
|
|
100
|
+
torch.nn.Module.__init__(self)
|
|
99
101
|
|
|
100
|
-
|
|
101
|
-
self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
|
|
102
|
-
self.data_augmentation = data_augmentation
|
|
102
|
+
super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
|
|
103
103
|
|
|
104
104
|
# Head parameters
|
|
105
105
|
self.dropout = dropout
|
|
@@ -221,6 +221,7 @@ class TirexClassifierTorch(torch.nn.Module):
|
|
|
221
221
|
{
|
|
222
222
|
"head_state_dict": self.head.state_dict(), # need to save only head, embedding is frozen
|
|
223
223
|
"data_augmentation": self.data_augmentation,
|
|
224
|
+
"compile": self._compile,
|
|
224
225
|
"emb_dim": self.emb_dim,
|
|
225
226
|
"num_classes": self.num_classes,
|
|
226
227
|
"dropout": self.dropout,
|
|
@@ -248,6 +249,7 @@ class TirexClassifierTorch(torch.nn.Module):
|
|
|
248
249
|
|
|
249
250
|
model = cls(
|
|
250
251
|
data_augmentation=checkpoint["data_augmentation"],
|
|
252
|
+
compile=checkpoint["compile"],
|
|
251
253
|
dropout=checkpoint["dropout"],
|
|
252
254
|
max_epochs=train_config_dict.get("max_epochs", 50),
|
|
253
255
|
lr=train_config_dict.get("lr", 1e-4),
|
|
@@ -1,12 +1,14 @@
|
|
|
1
|
+
# Copyright (c) NXAI GmbH.
|
|
2
|
+
# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
|
|
3
|
+
|
|
1
4
|
import joblib
|
|
2
|
-
import numpy as np
|
|
3
5
|
import torch
|
|
4
6
|
from sklearn.ensemble import RandomForestClassifier
|
|
5
7
|
|
|
6
|
-
from .
|
|
8
|
+
from .base_classifier import BaseTirexClassifier
|
|
7
9
|
|
|
8
10
|
|
|
9
|
-
class TirexRFClassifier:
|
|
11
|
+
class TirexRFClassifier(BaseTirexClassifier):
|
|
10
12
|
"""
|
|
11
13
|
A Random Forest classifier that uses time series embeddings as features.
|
|
12
14
|
|
|
@@ -15,7 +17,6 @@ class TirexRFClassifier:
|
|
|
15
17
|
time series, which are then used to train the Random Forest.
|
|
16
18
|
|
|
17
19
|
Example:
|
|
18
|
-
>>> import numpy as np
|
|
19
20
|
>>> from tirex.models.classification import TirexRFClassifier
|
|
20
21
|
>>>
|
|
21
22
|
>>> # Create model with custom Random Forest parameters
|
|
@@ -43,6 +44,7 @@ class TirexRFClassifier:
|
|
|
43
44
|
self,
|
|
44
45
|
data_augmentation: bool = False,
|
|
45
46
|
device: str | None = None,
|
|
47
|
+
compile: bool = False,
|
|
46
48
|
batch_size: int = 512,
|
|
47
49
|
# Random Forest parameters
|
|
48
50
|
**rf_kwargs,
|
|
@@ -51,24 +53,18 @@ class TirexRFClassifier:
|
|
|
51
53
|
|
|
52
54
|
Args:
|
|
53
55
|
data_augmentation : bool
|
|
54
|
-
Whether to use data_augmentation for embeddings (
|
|
56
|
+
Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
|
|
55
57
|
device : str | None
|
|
56
58
|
Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
|
|
59
|
+
compile: bool
|
|
60
|
+
Whether to compile the frozen embedding model. Default: False
|
|
57
61
|
batch_size : int
|
|
58
62
|
Batch size for embedding calculations. Default: 512
|
|
59
63
|
**rf_kwargs
|
|
60
64
|
Additional keyword arguments to pass to sklearn's RandomForestClassifier.
|
|
61
65
|
Common options include n_estimators, max_depth, min_samples_split, random_state, etc.
|
|
62
66
|
"""
|
|
63
|
-
|
|
64
|
-
# Set device
|
|
65
|
-
if device is None:
|
|
66
|
-
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
67
|
-
self.device = device
|
|
68
|
-
|
|
69
|
-
self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
|
|
70
|
-
self.data_augmentation = data_augmentation
|
|
71
|
-
|
|
67
|
+
super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
|
|
72
68
|
self.head = RandomForestClassifier(**rf_kwargs)
|
|
73
69
|
|
|
74
70
|
@torch.inference_mode()
|
|
@@ -80,7 +76,7 @@ class TirexRFClassifier:
|
|
|
80
76
|
|
|
81
77
|
Args:
|
|
82
78
|
train_data: Tuple of (X_train, y_train) where X_train is the input time
|
|
83
|
-
series data (torch.Tensor) and y_train is a
|
|
79
|
+
series data (torch.Tensor) and y_train is a torch.Tensor
|
|
84
80
|
of class labels.
|
|
85
81
|
"""
|
|
86
82
|
X_train, y_train = train_data
|
|
@@ -88,36 +84,11 @@ class TirexRFClassifier:
|
|
|
88
84
|
if isinstance(y_train, torch.Tensor):
|
|
89
85
|
y_train = y_train.detach().cpu().numpy()
|
|
90
86
|
|
|
87
|
+
self.emb_model.eval()
|
|
88
|
+
X_train = X_train.to(self.device)
|
|
91
89
|
embeddings = self.emb_model(X_train).cpu().numpy()
|
|
92
90
|
self.head.fit(embeddings, y_train)
|
|
93
91
|
|
|
94
|
-
@torch.inference_mode()
|
|
95
|
-
def predict(self, x: torch.Tensor) -> torch.Tensor:
|
|
96
|
-
"""Predict class labels for input time series data.
|
|
97
|
-
|
|
98
|
-
Args:
|
|
99
|
-
x: Input time series data as torch.Tensor or np.ndarray with shape
|
|
100
|
-
(batch_size, num_variates, seq_len).
|
|
101
|
-
Returns:
|
|
102
|
-
torch.Tensor: Predicted class labels with shape (batch_size,).
|
|
103
|
-
"""
|
|
104
|
-
|
|
105
|
-
embeddings = self.emb_model(x).cpu().numpy()
|
|
106
|
-
return torch.from_numpy(self.head.predict(embeddings)).long()
|
|
107
|
-
|
|
108
|
-
@torch.inference_mode()
|
|
109
|
-
def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
|
|
110
|
-
"""Predict class probabilities for input time series data.
|
|
111
|
-
|
|
112
|
-
Args:
|
|
113
|
-
x: Input time series data as torch.Tensor or np.ndarray with shape
|
|
114
|
-
(batch_size, num_variates, seq_len).
|
|
115
|
-
Returns:
|
|
116
|
-
torch.Tensor: Class probabilities with shape (batch_size, num_classes).
|
|
117
|
-
"""
|
|
118
|
-
embeddings = self.emb_model(x).cpu().numpy()
|
|
119
|
-
return torch.from_numpy(self.head.predict_proba(embeddings))
|
|
120
|
-
|
|
121
92
|
def save_model(self, path: str) -> None:
|
|
122
93
|
"""This method saves the trained Random Forest classifier head and embedding information in joblib format
|
|
123
94
|
|
|
@@ -126,6 +97,8 @@ class TirexRFClassifier:
|
|
|
126
97
|
"""
|
|
127
98
|
payload = {
|
|
128
99
|
"data_augmentation": self.data_augmentation,
|
|
100
|
+
"compile": self._compile,
|
|
101
|
+
"batch_size": self.batch_size,
|
|
129
102
|
"head": self.head,
|
|
130
103
|
}
|
|
131
104
|
joblib.dump(payload, path)
|
|
@@ -147,6 +120,8 @@ class TirexRFClassifier:
|
|
|
147
120
|
# Create new instance with saved configuration
|
|
148
121
|
model = cls(
|
|
149
122
|
data_augmentation=checkpoint["data_augmentation"],
|
|
123
|
+
compile=checkpoint["compile"],
|
|
124
|
+
batch_size=checkpoint["batch_size"],
|
|
150
125
|
)
|
|
151
126
|
|
|
152
127
|
# Load the trained Random Forest head
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: tirex-mirror
|
|
3
|
-
Version: 2025.
|
|
3
|
+
Version: 2025.12.2
|
|
4
4
|
Summary: Unofficial mirror of NX-AI/tirex for packaging
|
|
5
5
|
Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
|
|
6
6
|
License: NXAI COMMUNITY LICENSE AGREEMENT
|
|
@@ -55,7 +55,7 @@ License: NXAI COMMUNITY LICENSE AGREEMENT
|
|
|
55
55
|
|
|
56
56
|
Project-URL: Repository, https://github.com/rozsasarpi/tirex-mirror
|
|
57
57
|
Project-URL: Issues, https://github.com/rozsasarpi/tirex-mirror/issues
|
|
58
|
-
Keywords: TiRex,xLSTM,Time Series,Zero-shot,Deep Learning
|
|
58
|
+
Keywords: TiRex,xLSTM,Time Series,Zero-shot,Deep Learning,Classification,Timeseries-Classification
|
|
59
59
|
Classifier: Programming Language :: Python :: 3
|
|
60
60
|
Classifier: Operating System :: OS Independent
|
|
61
61
|
Requires-Python: >=3.11
|
|
@@ -86,6 +86,7 @@ Requires-Dist: fev>=0.6.0; extra == "test"
|
|
|
86
86
|
Requires-Dist: pytest; extra == "test"
|
|
87
87
|
Provides-Extra: classification
|
|
88
88
|
Requires-Dist: scikit-learn; extra == "classification"
|
|
89
|
+
Requires-Dist: lightgbm[scikit-learn]; extra == "classification"
|
|
89
90
|
Provides-Extra: all
|
|
90
91
|
Requires-Dist: xlstm; extra == "all"
|
|
91
92
|
Requires-Dist: ninja; extra == "all"
|
|
@@ -98,6 +99,7 @@ Requires-Dist: datasets; extra == "all"
|
|
|
98
99
|
Requires-Dist: pytest; extra == "all"
|
|
99
100
|
Requires-Dist: fev>=0.6.0; extra == "all"
|
|
100
101
|
Requires-Dist: scikit-learn; extra == "all"
|
|
102
|
+
Requires-Dist: lightgbm[scikit-learn]; extra == "all"
|
|
101
103
|
Dynamic: license-file
|
|
102
104
|
|
|
103
105
|
# tirex-mirror
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
|
|
2
|
+
tirex/base.py,sha256=pGShyI3LwPAwkl0rIYqik5qb8uBAWqbQrVoVjq7UH_8,5096
|
|
3
|
+
tirex/util.py,sha256=ggfETirJ589Dr5o3QThxTnjhgNnMCk10bNoJghnoeoA,31672
|
|
4
|
+
tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
|
|
5
|
+
tirex/api_adapter/forecast.py,sha256=FnRgdI_4vJ7iFTqyWwxPE_C7MBB3nhLTnVMzTr5pihc,17581
|
|
6
|
+
tirex/api_adapter/gluon.py,sha256=0KfKX7dGSooVIUumGtjK6Va2YgI6Isa_8kz9j0oPXUM,1883
|
|
7
|
+
tirex/api_adapter/hf_data.py,sha256=TRyys2xKIGZS0Yhq2Eb61lWCMg5CWWn1yRlLIN1mU7o,1369
|
|
8
|
+
tirex/api_adapter/standard_adapter.py,sha256=vdlxNs8mTUtPgK_5WMqYqNdMj8W44igqWsAgtggt_xk,2809
|
|
9
|
+
tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
|
|
10
|
+
tirex/models/patcher.py,sha256=8T4c3PZnOAsEpahhrjtt7S7405WUjN6g3cV33E55PD4,1911
|
|
11
|
+
tirex/models/tirex.py,sha256=wtdjrdE1TuoueIvhlKf-deLH3oKyQuGmd2k15My7SWA,9710
|
|
12
|
+
tirex/models/classification/__init__.py,sha256=0PKH8Cr-VWT4x4SXzGAyj90jMJbxzlIyrBIiV1ZXfm4,377
|
|
13
|
+
tirex/models/classification/embedding.py,sha256=wxIefIpccHe64Cb2JlmL0l0PhqYB_zdBHSh5auPqdlY,5452
|
|
14
|
+
tirex/models/classification/trainer.py,sha256=CbYfdTNdj5iucchxNwgg35EF-MzrdjrMm9auNOLae4U,5895
|
|
15
|
+
tirex/models/classification/utils.py,sha256=JbloQu4KlBHfdlH1NYYg7_VIoH-t2HBo5ovdyAEoCGs,2735
|
|
16
|
+
tirex/models/classification/heads/base_classifier.py,sha256=mBJV7jShBAUQgMpidow8lVfqhufNHvEIwDnidtxA1Hw,3236
|
|
17
|
+
tirex/models/classification/heads/gbm_classifier.py,sha256=V1CzfmJF6Tp2j14F3dAmRZ4wxBsuKoK1jQu4HbjttCc,8365
|
|
18
|
+
tirex/models/classification/heads/linear_classifier.py,sha256=2BI2GqQvBiHjFJ7_dS8rGQdQ-pnpF7-9sZv_YRUL-8w,11351
|
|
19
|
+
tirex/models/classification/heads/rf_classifier.py,sha256=Nsw3mxToa804txYzkvkRQwq34whEb5x-9_6ywskylBk,5063
|
|
20
|
+
tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
|
|
21
|
+
tirex/models/slstm/cell.py,sha256=Otyil_AjpJbUckkINWGHxlqP14J5epm_J_zdWPzvD2g,7290
|
|
22
|
+
tirex/models/slstm/layer.py,sha256=hrDydQJIAHf5W0A0Rt0hXG4yKXrOSY-HPL0UbigR6Q8,2867
|
|
23
|
+
tirex_mirror-2025.12.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
|
|
24
|
+
tirex_mirror-2025.12.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
|
|
25
|
+
tirex_mirror-2025.12.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
|
|
26
|
+
tirex_mirror-2025.12.2.dist-info/METADATA,sha256=XU9yHQPFUYg5qiqMSNoOJRGmSE9dY3GUhGfn8m4LouM,11783
|
|
27
|
+
tirex_mirror-2025.12.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
28
|
+
tirex_mirror-2025.12.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
|
|
29
|
+
tirex_mirror-2025.12.2.dist-info/RECORD,,
|
|
@@ -1,27 +0,0 @@
|
|
|
1
|
-
tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
|
|
2
|
-
tirex/base.py,sha256=pGShyI3LwPAwkl0rIYqik5qb8uBAWqbQrVoVjq7UH_8,5096
|
|
3
|
-
tirex/util.py,sha256=ggfETirJ589Dr5o3QThxTnjhgNnMCk10bNoJghnoeoA,31672
|
|
4
|
-
tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
|
|
5
|
-
tirex/api_adapter/forecast.py,sha256=FnRgdI_4vJ7iFTqyWwxPE_C7MBB3nhLTnVMzTr5pihc,17581
|
|
6
|
-
tirex/api_adapter/gluon.py,sha256=0KfKX7dGSooVIUumGtjK6Va2YgI6Isa_8kz9j0oPXUM,1883
|
|
7
|
-
tirex/api_adapter/hf_data.py,sha256=TRyys2xKIGZS0Yhq2Eb61lWCMg5CWWn1yRlLIN1mU7o,1369
|
|
8
|
-
tirex/api_adapter/standard_adapter.py,sha256=vdlxNs8mTUtPgK_5WMqYqNdMj8W44igqWsAgtggt_xk,2809
|
|
9
|
-
tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
|
|
10
|
-
tirex/models/patcher.py,sha256=8T4c3PZnOAsEpahhrjtt7S7405WUjN6g3cV33E55PD4,1911
|
|
11
|
-
tirex/models/tirex.py,sha256=wtdjrdE1TuoueIvhlKf-deLH3oKyQuGmd2k15My7SWA,9710
|
|
12
|
-
tirex/models/classification/__init__.py,sha256=qVn84uBosWaHm9wr0FoYAXNvmajyyB3_OmpeHNzDH4g,194
|
|
13
|
-
tirex/models/classification/embedding.py,sha256=TJlchUaBhz8Pf1mLpTdDVqmkGVOiaCI55sSoqk2tSXE,5259
|
|
14
|
-
tirex/models/classification/linear_classifier.py,sha256=yE2sekKw2LCLGiITTsTktOJw28RjjOWgsJE2PakAjN8,11142
|
|
15
|
-
tirex/models/classification/rf_classifier.py,sha256=wzMiF5TELGl3V-i94MPOEpNn2yZgrCdPfuzpcl8A18M,5816
|
|
16
|
-
tirex/models/classification/trainer.py,sha256=JzM1XtRoRI3fn6Sbu7V-9IuiKVy454O73uNrMNgCREs,5759
|
|
17
|
-
tirex/models/classification/utils.py,sha256=db9056u6uIVhm0qmHDoOo3K5f-ZiuDytDLcLOg-zFb0,2599
|
|
18
|
-
tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
|
|
19
|
-
tirex/models/slstm/cell.py,sha256=Otyil_AjpJbUckkINWGHxlqP14J5epm_J_zdWPzvD2g,7290
|
|
20
|
-
tirex/models/slstm/layer.py,sha256=hrDydQJIAHf5W0A0Rt0hXG4yKXrOSY-HPL0UbigR6Q8,2867
|
|
21
|
-
tirex_mirror-2025.11.28.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
|
|
22
|
-
tirex_mirror-2025.11.28.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
|
|
23
|
-
tirex_mirror-2025.11.28.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
|
|
24
|
-
tirex_mirror-2025.11.28.dist-info/METADATA,sha256=CQV9y4KSrvo1hxsac_7nBYcvB8SXb88ix2m9nolO-gg,11624
|
|
25
|
-
tirex_mirror-2025.11.28.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
26
|
-
tirex_mirror-2025.11.28.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
|
|
27
|
-
tirex_mirror-2025.11.28.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
{tirex_mirror-2025.11.28.dist-info → tirex_mirror-2025.12.2.dist-info}/licenses/LICENSE_MIRROR.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|