tirex-mirror 2025.11.26-py3-none-any.whl → 2025.11.29-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tirex/models/classification/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # classification/__init__.py
+ from .linear_classifier import TirexClassifierTorch
+ from .rf_classifier import TirexRFClassifier
+
+ __all__ = [
+     "TirexClassifierTorch",
+     "TirexRFClassifier",
+ ]
tirex/models/classification/embedding.py ADDED
@@ -0,0 +1,125 @@
+ import inspect
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from tirex import load_model
+
+ from .utils import nanmax, nanmin, nanstd
+
+
+ class TiRexEmbedding(nn.Module):
+     def __init__(self, device: str | None = None, data_augmentation: bool = False, batch_size: int = 512) -> None:
+         super().__init__()
+         self.data_augmentation = data_augmentation
+         self.number_of_patches = 8
+         self.batch_size = batch_size
+
+         if device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.device = device
+
+         self.model = load_model(path="NX-AI/TiRex", device=self.device)
+
+     def _gen_emb_batched(self, data: torch.Tensor) -> torch.Tensor:
+         batches = list(torch.split(data, self.batch_size))
+         embedding_list = []
+         for batch in batches:
+             embedding = self.model._embed_context(batch)
+             embedding_list.append(embedding.cpu())
+         return torch.cat(embedding_list, dim=0)
+
+     def _calculate_n_patches(self, data: torch.Tensor) -> int:
+         _, _, n_steps = data.shape
+         n_patches = -(-n_steps // self.model.config.input_patch_size)
+         return n_patches
+
+     def forward(self, data: torch.Tensor) -> torch.Tensor:
+         n_patches = self._calculate_n_patches(data)
+
+         embedding = torch.stack(
+             [self._gen_emb_batched(var_slice) for var_slice in torch.unbind(data, dim=1)], dim=1
+         )  # Stack in case of multivar
+         embedding = self.process_embedding(embedding, n_patches)
+
+         if self.data_augmentation:
+             # Difference Embedding
+             diff_data = torch.diff(data, dim=-1, prepend=data[..., :0])
+             n_patches = self._calculate_n_patches(diff_data)
+
+             diff_embedding = torch.stack(
+                 [self._gen_emb_batched(var_slice) for var_slice in torch.unbind(diff_data, dim=1)], dim=1
+             )
+             diff_embedding = self.process_embedding(diff_embedding, n_patches)
+             embedding = torch.cat((diff_embedding, embedding), dim=-1)
+
+             # Stats Embedding
+             stat_features = self._generate_stats_features(data)
+             normalized_stats = self._normalize_stats(stat_features)
+             normalized_stats = normalized_stats.to(embedding.device)
+
+             # Concat all together
+             embedding = torch.cat((embedding, normalized_stats), dim=-1)
+
+         return embedding
+
+     def process_embedding(self, embedding: torch.Tensor, n_patches: int) -> torch.Tensor:
+         # embedding shape: (bs, var_dim, n_patches, n_layer, emb_dim)
+         embedding = embedding[:, :, -n_patches:, :, :]
+         embedding = torch.mean(embedding, dim=2)  # sequence
+         embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
+         embedding = torch.transpose(embedding, 1, -2).flatten(start_dim=-2)  # var
+         embedding = torch.transpose(embedding, 1, -2).flatten(start_dim=-2)  # layer
+         embedding = F.layer_norm(embedding, (embedding.shape[-1],))
+         return embedding
+
+     def _normalize_stats(self, stat_features: torch.Tensor) -> torch.Tensor:
+         dataset_mean = torch.nanmean(stat_features, dim=0, keepdim=True)
+         dataset_std = nanstd(stat_features, dim=0, keepdim=True)
+         stat_features = (stat_features - dataset_mean) / (dataset_std + 1e-8)
+         stat_features = torch.nan_to_num(stat_features, nan=0.0)
+
+         stat_features = (stat_features - stat_features.nanmean(dim=-1, keepdim=True)) / (
+             stat_features.std(dim=-1, keepdim=True) + 1e-8
+         )
+         return stat_features
+
+     def _generate_stats_features(self, data: torch.Tensor) -> torch.Tensor:
+         bs, variates, n_steps = data.shape
+
+         patch_size = max(1, n_steps // self.number_of_patches)
+         n_full_patches = n_steps // patch_size
+         n_remain = n_steps % patch_size
+
+         # [batch, variates, n_patches, patch_size]
+         patches = data[..., : n_full_patches * patch_size].unfold(-1, patch_size, patch_size)
+
+         # Stats for full patches
+         patch_means = torch.nanmean(patches, dim=-1)
+         patch_stds = nanstd(patches, dim=-1)
+         patch_maxes = nanmax(patches, dim=-1)
+         patch_mins = nanmin(patches, dim=-1)
+
+         stats = [patch_means, patch_stds, patch_maxes, patch_mins]
+
+         # Handle last smaller patch if needed
+         if n_remain > 0:
+             self._handle_remaining_patch(data, stats, n_full_patches * patch_size)
+
+         stats = torch.stack(stats, dim=-1)  # [batch, variates, n_patches(+1), 4]
+         return stats.flatten(start_dim=1)  # [batch, variates * n_patches * 4]
+
+     def _handle_remaining_patch(self, data: torch.Tensor, stats: list[torch.Tensor], full_patch_length: int) -> None:
+         last_patch = data[..., full_patch_length:]
+
+         mean_last = last_patch.mean(dim=-1, keepdim=True)
+         std_last = last_patch.std(dim=-1, keepdim=True)
+         max_last = last_patch.max(dim=-1, keepdim=True).values  # .values needed: Tensor.max(dim=...) returns (values, indices)
+         min_last = last_patch.min(dim=-1, keepdim=True).values
+
+         stats[0] = torch.cat([stats[0], mean_last], dim=-1)
+         stats[1] = torch.cat([stats[1], std_last], dim=-1)
+         stats[2] = torch.cat([stats[2], max_last], dim=-1)
+         stats[3] = torch.cat([stats[3], min_last], dim=-1)
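The module above turns each series into one fixed-size feature vector: per-variate TiRex hidden states are mean-pooled over the most recent patches, L2-normalized, flattened across variates and layers, and (with data_augmentation) concatenated with first-difference embeddings and per-patch summary statistics. A minimal usage sketch, assuming the package is installed with the classification extra and the NX-AI/TiRex weights can be loaded; the exact feature dimension depends on the checkpoint:

import torch

from tirex.models.classification.embedding import TiRexEmbedding

# (batch, variates, steps): 16 univariate series of length 256
series = torch.randn(16, 1, 256)

emb_model = TiRexEmbedding(data_augmentation=True)  # loads NX-AI/TiRex at construction time
features = emb_model(series)

# One fixed-size vector per series, usable by any downstream classifier
print(features.shape)  # (16, feature_dim)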
tirex/models/classification/linear_classifier.py ADDED
@@ -0,0 +1,274 @@
+ from dataclasses import asdict
+
+ import torch
+
+ from .embedding import TiRexEmbedding
+ from .trainer import TrainConfig, Trainer
+
+
+ class TirexClassifierTorch(torch.nn.Module):
+     """
+     A PyTorch classifier that combines time series embeddings with a linear classification head.
+
+     This model uses a pre-trained TiRex embedding model to generate feature representations from time series
+     data, followed by a linear layer (with optional dropout) for classification. The embedding backbone
+     is frozen during training, and only the classification head is trained.
+
+     Example:
+         >>> import torch
+         >>> from tirex.models.classification import TirexClassifierTorch
+         >>>
+         >>> # Create model with TiRex embeddings
+         >>> model = TirexClassifierTorch(
+         ...     data_augmentation=True,
+         ...     max_epochs=2,
+         ...     lr=1e-4,
+         ...     batch_size=32
+         ... )
+         >>>
+         >>> # Prepare data
+         >>> X_train = torch.randn(100, 1, 128)  # 100 samples, 1 variate, sequence length 128
+         >>> y_train = torch.randint(0, 3, (100,))  # 3 classes
+         >>>
+         >>> # Train the model
+         >>> metrics = model.fit((X_train, y_train))  # doctest: +ELLIPSIS
+         Epoch 1, Train Loss: ...
+         >>> # Make predictions
+         >>> X_test = torch.randn(20, 1, 128)
+         >>> predictions = model.predict(X_test)
+         >>> probabilities = model.predict_proba(X_test)
+     """
+
+     def __init__(
+         self,
+         data_augmentation: bool = False,
+         device: str | None = None,
+         # Training parameters
+         max_epochs: int = 50,
+         lr: float = 1e-4,
+         weight_decay: float = 0.01,
+         batch_size: int = 512,
+         val_split_ratio: float = 0.2,
+         stratify: bool = True,
+         patience: int = 7,
+         delta: float = 0.001,
+         log_every_n_steps: int = 5,
+         seed: int | None = None,
+         class_weights: torch.Tensor | None = None,
+         # Head parameters
+         dropout: float | None = None,
+     ) -> None:
+         """Initializes the embedding-based linear classification model.
+
+         Args:
+             data_augmentation : bool
+                 Whether to augment the embeddings with summary statistics and first-order differences of the original data. Default: False
+             device : str | None
+                 Device to run the model on. If None, uses CUDA if available, else CPU. Default: None
+             max_epochs : int
+                 Maximum number of training epochs. Default: 50
+             lr : float
+                 Learning rate for the optimizer. Default: 1e-4
+             weight_decay : float
+                 Weight decay coefficient. Default: 0.01
+             batch_size : int
+                 Batch size for training and embedding calculations. Default: 512
+             val_split_ratio : float
+                 Proportion of the training data used for validation if validation data are not provided. Default: 0.2
+             stratify : bool
+                 Whether to stratify the train/validation split by class labels. Default: True
+             patience : int
+                 Number of epochs to wait for improvement before early stopping. Default: 7
+             delta : float
+                 Minimum change in validation loss to qualify as an improvement. Default: 0.001
+             log_every_n_steps : int
+                 Frequency of logging during training. Default: 5
+             seed : int | None
+                 Random seed for reproducibility. If None, no seed is set. Default: None
+             class_weights : torch.Tensor | None
+                 Per-class loss weights; must be a tensor with one entry per class. Default: None
+             dropout : float | None
+                 Dropout probability for the classification head. If None, no dropout is used. Default: None
+         """
+
+         super().__init__()
+
+         if device is None:
+             device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         self.device = device
+
+         # Create embedding model
+         self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
+         self.data_augmentation = data_augmentation
+
+         # Head parameters
+         self.dropout = dropout
+         self.head = None
+         self.emb_dim = None
+         self.num_classes = None
+
+         # Train config
+         train_config = TrainConfig(
+             max_epochs=max_epochs,
+             log_every_n_steps=log_every_n_steps,
+             device=self.device,
+             lr=lr,
+             weight_decay=weight_decay,
+             class_weights=class_weights,
+             batch_size=batch_size,
+             val_split_ratio=val_split_ratio,
+             stratify=stratify,
+             patience=patience,
+             delta=delta,
+             seed=seed,
+         )
+         self.trainer = Trainer(self, train_config=train_config)
+
+     def _init_classifier(self, emb_dim: int, num_classes: int, dropout: float | None) -> torch.nn.Module:
+         if dropout:
+             return torch.nn.Sequential(torch.nn.Dropout(p=dropout), torch.nn.Linear(emb_dim, num_classes))
+         else:
+             return torch.nn.Linear(emb_dim, num_classes)
+
+     @torch.inference_mode()
+     def _identify_head_dims(self, x: torch.Tensor, y: torch.Tensor) -> None:
+         self.emb_model.eval()
+         sample_emb = self.emb_model(x[:1])
+         self.emb_dim = sample_emb.shape[-1]
+         self.num_classes = len(torch.unique(y))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the embedding model and classification head.
+
+         Args:
+             x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+         Returns:
+             torch.Tensor: Logits for each class with shape (batch_size, num_classes).
+         Raises:
+             RuntimeError: If the classification head has not been initialized via fit().
+         """
+         if self.head is None:
+             raise RuntimeError("Head not initialized. Call fit() first to automatically build the head.")
+
+         embedding = self.emb_model(x).to(self.device)
+         return self.head(embedding)
+
+     def fit(
+         self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None = None
+     ) -> dict[str, float]:
+         """Train the classification head on the provided data.
+
+         This method initializes the classification head based on the data dimensions,
+         then trains it on the provided data. The embedding model remains frozen.
+
+         Args:
+             train_data: Tuple of (X_train, y_train) where X_train is the input time series
+                 data and y_train are the corresponding class labels.
+             val_data: Optional tuple of (X_val, y_val) for validation. If None and
+                 val_split_ratio > 0, validation data will be split from train_data.
+
+         Returns:
+             dict[str, float]: Dictionary containing final training and validation losses.
+         """
+         X_train, y_train = train_data
+
+         self._identify_head_dims(X_train, y_train)
+         self.head = self._init_classifier(self.emb_dim, self.num_classes, self.dropout)
+         self.head = self.head.to(self.trainer.device)
+
+         return self.trainer.fit(train_data, val_data=val_data)
+
+     @torch.inference_mode()
+     def predict(self, x: torch.Tensor) -> torch.Tensor:
+         """Predict class labels for input time series data.
+
+         Args:
+             x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+         Returns:
+             torch.Tensor: Predicted class labels with shape (batch_size,).
+         """
+         self.eval()
+         x = x.to(self.device)
+         logits = self.forward(x)
+         return torch.argmax(logits, dim=1)
+
+     @torch.inference_mode()
+     def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
+         """Predict class probabilities for input time series data.
+
+         Args:
+             x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+         Returns:
+             torch.Tensor: Class probabilities with shape (batch_size, num_classes).
+         """
+         self.eval()
+         x = x.to(self.device)
+         logits = self.forward(x)
+         return torch.softmax(logits, dim=1)
+
+     def save_model(self, path: str) -> None:
+         """Save the trained classification head.
+
+         This function saves the trained classification head weights (.pt format), the embedding configuration,
+         the model dimensions, and the device information. The embedding model itself is not
+         saved because it uses a pre-trained backbone that can be reloaded.
+
+         Args:
+             path: File path where the model should be saved (e.g., 'model.pt').
+         """
+         train_config_dict = asdict(self.trainer.train_config)
+         torch.save(
+             {
+                 "head_state_dict": self.head.state_dict(),  # only the head is saved; the embedding backbone is frozen
+                 "data_augmentation": self.data_augmentation,
+                 "emb_dim": self.emb_dim,
+                 "num_classes": self.num_classes,
+                 "dropout": self.dropout,
+                 "train_config": train_config_dict,
+             },
+             path,
+         )
+
+     @classmethod
+     def load_model(cls, path: str) -> "TirexClassifierTorch":
+         """Load a saved model from file.
+
+         This reconstructs the model architecture and loads the trained weights from
+         a checkpoint file created by save_model().
+
+         Args:
+             path: File path to the saved model checkpoint.
+         Returns:
+             TirexClassifierTorch: The loaded model with trained weights, ready for inference.
+         """
+         checkpoint = torch.load(path)
+
+         # Extract train_config if available, otherwise use defaults
+         train_config_dict = checkpoint.get("train_config", {})
+
+         model = cls(
+             data_augmentation=checkpoint["data_augmentation"],
+             dropout=checkpoint["dropout"],
+             max_epochs=train_config_dict.get("max_epochs", 50),
+             lr=train_config_dict.get("lr", 1e-4),
+             weight_decay=train_config_dict.get("weight_decay", 0.01),
+             batch_size=train_config_dict.get("batch_size", 512),
+             val_split_ratio=train_config_dict.get("val_split_ratio", 0.2),
+             stratify=train_config_dict.get("stratify", True),
+             patience=train_config_dict.get("patience", 7),
+             delta=train_config_dict.get("delta", 0.001),
+             log_every_n_steps=train_config_dict.get("log_every_n_steps", 5),
+             seed=train_config_dict.get("seed", None),
+             class_weights=train_config_dict.get("class_weights", None),
+         )
+
+         # Initialize head with dimensions
+         model.emb_dim = checkpoint["emb_dim"]
+         model.num_classes = checkpoint["num_classes"]
+         model.head = model._init_classifier(model.emb_dim, model.num_classes, model.dropout)
+
+         # Load the trained weights
+         model.head.load_state_dict(checkpoint["head_state_dict"])
+         model.to(model.device)
+
+         return model
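Beyond the docstring example, save_model/load_model persist only the linear head plus the configuration needed to rebuild it; the frozen TiRex backbone is re-created when the checkpoint is loaded. A sketch of the round trip (the checkpoint file name is illustrative):

import torch

from tirex.models.classification import TirexClassifierTorch

X = torch.randn(64, 1, 128)
y = torch.randint(0, 2, (64,))

clf = TirexClassifierTorch(max_epochs=2, batch_size=32)
clf.fit((X, y))
clf.save_model("tirex_head.pt")  # stores head weights + train config only

restored = TirexClassifierTorch.load_model("tirex_head.pt")
preds = restored.predict(torch.randn(8, 1, 128))  # ready for inference, no retraining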
tirex/models/classification/rf_classifier.py ADDED
@@ -0,0 +1,155 @@
+ import joblib
+ import numpy as np
+ import torch
+ from sklearn.ensemble import RandomForestClassifier
+
+ from .embedding import TiRexEmbedding
+
+
+ class TirexRFClassifier:
+     """
+     A Random Forest classifier that uses time series embeddings as features.
+
+     This classifier combines a pre-trained embedding model for feature extraction with a scikit-learn
+     Random Forest classifier. The embedding model generates fixed-size feature vectors from variable-length
+     time series, which are then used to train the Random Forest.
+
+     Example:
+         >>> import torch
+         >>> from tirex.models.classification import TirexRFClassifier
+         >>>
+         >>> # Create model with custom Random Forest parameters
+         >>> model = TirexRFClassifier(
+         ...     data_augmentation=True,
+         ...     n_estimators=50,
+         ...     max_depth=10,
+         ...     random_state=42
+         ... )
+         >>>
+         >>> # Prepare data (can use NumPy arrays or PyTorch tensors)
+         >>> X_train = torch.randn(100, 1, 128)  # 100 samples, 1 variate, sequence length 128
+         >>> y_train = torch.randint(0, 3, (100,))  # 3 classes
+         >>>
+         >>> # Train the model
+         >>> model.fit((X_train, y_train))
+         >>>
+         >>> # Make predictions
+         >>> X_test = torch.randn(20, 1, 128)
+         >>> predictions = model.predict(X_test)
+         >>> probabilities = model.predict_proba(X_test)
+     """
+
+     def __init__(
+         self,
+         data_augmentation: bool = False,
+         device: str | None = None,
+         batch_size: int = 512,
+         # Random Forest parameters
+         **rf_kwargs,
+     ) -> None:
+         """Initializes the embedding-based Random Forest classification model.
+
+         Args:
+             data_augmentation : bool
+                 Whether to augment the embeddings with summary statistics and first-order differences of the original data. Default: False
+             device : str | None
+                 Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
+             batch_size : int
+                 Batch size for embedding calculations. Default: 512
+             **rf_kwargs
+                 Additional keyword arguments to pass to sklearn's RandomForestClassifier.
+                 Common options include n_estimators, max_depth, min_samples_split, random_state, etc.
+         """
+
+         # Set device
+         if device is None:
+             device = "cuda:0" if torch.cuda.is_available() else "cpu"
+         self.device = device
+
+         self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
+         self.data_augmentation = data_augmentation
+
+         self.head = RandomForestClassifier(**rf_kwargs)
+
+     @torch.inference_mode()
+     def fit(self, train_data: tuple[torch.Tensor, torch.Tensor]) -> None:
+         """Train the Random Forest classifier on embedded time series data.
+
+         This method generates embeddings for the training data using the embedding
+         model, then trains the Random Forest on these embeddings.
+
+         Args:
+             train_data: Tuple of (X_train, y_train) where X_train is the input time
+                 series data (torch.Tensor) and y_train contains the corresponding
+                 class labels (np.ndarray or torch.Tensor).
+         """
+         X_train, y_train = train_data
+
+         if isinstance(y_train, torch.Tensor):
+             y_train = y_train.detach().cpu().numpy()
+
+         embeddings = self.emb_model(X_train).cpu().numpy()
+         self.head.fit(embeddings, y_train)
+
+     @torch.inference_mode()
+     def predict(self, x: torch.Tensor) -> torch.Tensor:
+         """Predict class labels for input time series data.
+
+         Args:
+             x: Input time series data as torch.Tensor or np.ndarray with shape
+                 (batch_size, num_variates, seq_len).
+         Returns:
+             torch.Tensor: Predicted class labels with shape (batch_size,).
+         """
+
+         embeddings = self.emb_model(x).cpu().numpy()
+         return torch.from_numpy(self.head.predict(embeddings)).long()
+
+     @torch.inference_mode()
+     def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
+         """Predict class probabilities for input time series data.
+
+         Args:
+             x: Input time series data as torch.Tensor or np.ndarray with shape
+                 (batch_size, num_variates, seq_len).
+         Returns:
+             torch.Tensor: Class probabilities with shape (batch_size, num_classes).
+         """
+         embeddings = self.emb_model(x).cpu().numpy()
+         return torch.from_numpy(self.head.predict_proba(embeddings))
+
+     def save_model(self, path: str) -> None:
+         """Save the trained Random Forest classifier head and embedding configuration in joblib format.
+
+         Args:
+             path: File path where the model should be saved (e.g., 'model.joblib').
+         """
+         payload = {
+             "data_augmentation": self.data_augmentation,
+             "head": self.head,
+         }
+         joblib.dump(payload, path)
+
+     @classmethod
+     def load_model(cls, path: str) -> "TirexRFClassifier":
+         """Load a saved model from file.
+
+         This reconstructs the model with the embedding configuration and loads
+         the trained Random Forest classifier from a checkpoint file created by save_model().
+
+         Args:
+             path: File path to the saved model checkpoint.
+         Returns:
+             TirexRFClassifier: The loaded model with trained Random Forest, ready for inference.
+         """
+         checkpoint = joblib.load(path)
+
+         # Create new instance with saved configuration
+         model = cls(
+             data_augmentation=checkpoint["data_augmentation"],
+         )
+
+         # Load the trained Random Forest head
+         model.head = checkpoint["head"]
+
+         return model
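The Random Forest variant keeps the same embedding front end but delegates classification to scikit-learn, so persistence goes through joblib and stores only the fitted forest plus the data_augmentation flag. A short sketch (file name illustrative):

import torch

from tirex.models.classification import TirexRFClassifier

X = torch.randn(64, 1, 128)
y = torch.randint(0, 2, (64,))

rf = TirexRFClassifier(n_estimators=100, random_state=0)
rf.fit((X, y))
rf.save_model("tirex_rf.joblib")  # fitted forest + embedding flag, no backbone weights

restored = TirexRFClassifier.load_model("tirex_rf.joblib")
proba = restored.predict_proba(X)  # shape (64, 2)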
tirex/models/classification/trainer.py ADDED
@@ -0,0 +1,171 @@
+ from dataclasses import dataclass
+
+ import torch
+ from torch.utils.data import DataLoader, TensorDataset
+
+ from .utils import EarlyStopping, set_seed, train_val_split
+
+
+ @dataclass
+ class TrainConfig:
+     # Training loop parameters
+     max_epochs: int
+     log_every_n_steps: int
+     device: str
+
+     # Optimizer parameters
+     lr: float
+     weight_decay: float
+
+     # Loss parameters
+     class_weights: torch.Tensor | None
+
+     # Data loading parameters
+     batch_size: int
+     val_split_ratio: float
+     stratify: bool
+
+     # Early stopping parameters
+     patience: int
+     delta: float
+
+     # Reproducibility
+     seed: int | None
+
+     def __post_init__(self) -> None:
+         if self.max_epochs <= 0:
+             raise ValueError(f"max_epochs must be positive, got {self.max_epochs}")
+
+         if self.log_every_n_steps <= 0:
+             raise ValueError(f"log_every_n_steps must be positive, got {self.log_every_n_steps}")
+
+         if self.lr <= 0:
+             raise ValueError(f"lr (learning rate) must be positive, got {self.lr}")
+
+         if self.weight_decay < 0:
+             raise ValueError(f"weight_decay must be non-negative, got {self.weight_decay}")
+
+         if self.batch_size <= 0:
+             raise ValueError(f"batch_size must be positive, got {self.batch_size}")
+
+         if not (0 < self.val_split_ratio < 1):
+             raise ValueError(f"val_split_ratio must be in (0, 1), got {self.val_split_ratio}")
+
+         if self.patience <= 0:
+             raise ValueError(f"patience must be positive, got {self.patience}")
+
+         if self.delta < 0:
+             raise ValueError(f"delta must be non-negative, got {self.delta}")
+
+
+ class Trainer:
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         train_config: TrainConfig,
+     ) -> None:
+         self.device = train_config.device
+         self.train_config = train_config
+
+         self.model = model.to(self.device)
+         class_weights = (
+             self.train_config.class_weights.to(self.device) if self.train_config.class_weights is not None else None
+         )
+         self.loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights).to(self.device)
+
+         self.optimizer = None
+         self.early_stopper = EarlyStopping(patience=self.train_config.patience, delta=self.train_config.delta)
+
+     def fit(
+         self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None = None
+     ) -> dict[str, float]:
+         if self.train_config.seed is not None:
+             set_seed(self.train_config.seed)
+
+         self._freeze_embedding()
+
+         if self.optimizer is None:
+             self.optimizer = torch.optim.AdamW(
+                 self.model.parameters(), lr=self.train_config.lr, weight_decay=self.train_config.weight_decay
+             )
+
+         train_loader, val_loader = self._create_data_loaders(train_data, val_data)
+
+         for epoch in range(self.train_config.max_epochs):
+             train_loss = self._train_epoch(train_loader)
+             val_loss = self._validate_epoch(val_loader)
+
+             self._log_epoch_metrics(epoch, train_loss, val_loss)
+
+             stop_training = self.early_stopper(epoch=epoch + 1, val_loss=val_loss)
+             if stop_training:
+                 break
+
+         return {"train_loss": train_loss, "val_loss": val_loss}
+
+     def _freeze_embedding(self) -> None:
+         if hasattr(self.model, "emb_model"):
+             for param in self.model.emb_model.parameters():
+                 param.requires_grad = False
+
+     def _train_epoch(self, train_loader: DataLoader) -> float:
+         train_loss = []
+         self.model.train()
+         for batch in train_loader:
+             x, y = batch
+             x = x.to(self.device)
+             y = y.to(self.device)
+
+             self.optimizer.zero_grad()
+             y_hat = self.model.head(x)  # Only the classification head is involved; embeddings are precomputed
+             loss = self.loss_fn(y_hat, y)
+
+             loss.backward()
+             self.optimizer.step()
+
+             train_loss.append(loss.detach())
+
+         return torch.stack(train_loss).mean().item()
+
+     @torch.inference_mode()
+     def _validate_epoch(self, val_loader: DataLoader) -> float:
+         self.model.eval()
+         val_loss = []
+         for batch in val_loader:
+             x, y = batch
+             x = x.to(self.device)
+             y = y.to(self.device)
+
+             y_hat = self.model.head(x)  # Only the classification head is involved; embeddings are precomputed
+             loss = self.loss_fn(y_hat, y)
+
+             val_loss.append(loss.detach())
+         return torch.stack(val_loss).mean().item()
+
+     def _create_data_loaders(
+         self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None
+     ) -> tuple[DataLoader, DataLoader]:
+         if val_data is None:
+             train_data, val_data = train_val_split(
+                 train_data, self.train_config.val_split_ratio, self.train_config.stratify, self.train_config.seed
+             )
+
+         train_embeddings = self.model.emb_model(train_data[0])
+         val_embeddings = self.model.emb_model(val_data[0])
+
+         train_loader = DataLoader(
+             TensorDataset(train_embeddings, train_data[1]),
+             batch_size=self.train_config.batch_size,
+             shuffle=True,
+         )
+
+         val_loader = DataLoader(
+             TensorDataset(val_embeddings, val_data[1]),
+             batch_size=self.train_config.batch_size,
+             shuffle=False,
+         )
+         return train_loader, val_loader
+
+     def _log_epoch_metrics(self, epoch: int, train_loss: float, val_loss: float) -> None:
+         if epoch % self.train_config.log_every_n_steps == 0:
+             print(f"Epoch {epoch + 1}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")
tirex/models/classification/utils.py ADDED
@@ -0,0 +1,81 @@
+ import numpy as np
+ import torch
+ from sklearn.model_selection import train_test_split
+
+
+ # Remove once this issue is resolved: https://github.com/pytorch/pytorch/issues/61474
+ def nanmax(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+     min_value = torch.finfo(tensor.dtype).min
+     output = tensor.nan_to_num(min_value).max(dim=dim, keepdim=keepdim)
+     return output.values
+
+
+ def nanmin(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+     max_value = torch.finfo(tensor.dtype).max
+     output = tensor.nan_to_num(max_value).min(dim=dim, keepdim=keepdim)
+     return output.values
+
+
+ def nanvar(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+     tensor_mean = tensor.nanmean(dim=dim, keepdim=True)
+     output = (tensor - tensor_mean).square().nanmean(dim=dim, keepdim=keepdim)
+     return output
+
+
+ def nanstd(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+     output = nanvar(tensor, dim=dim, keepdim=keepdim)
+     output = output.sqrt()
+     return output
+
+
+ def train_val_split(
+     train_data: tuple[torch.Tensor, torch.Tensor],
+     val_split_ratio: float,
+     stratify: bool,
+     seed: int | None,
+ ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+     idx_train, idx_val = train_test_split(
+         np.arange(len(train_data[0])),
+         test_size=val_split_ratio,
+         random_state=seed,
+         shuffle=True,
+         stratify=train_data[1] if stratify else None,
+     )
+
+     return (
+         (train_data[0][idx_train], train_data[1][idx_train]),
+         (train_data[0][idx_val], train_data[1][idx_val]),
+     )
+
+
+ def set_seed(seed: int) -> None:
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+     np.random.seed(seed)
+
+
+ class EarlyStopping:
+     def __init__(
+         self,
+         patience: int = 7,
+         delta: float = 0.0001,
+     ) -> None:
+         self.patience: int = patience
+         self.delta: float = delta
+
+         self.best: float = np.inf
+         self.wait_count: int = 0
+         self.early_stop: bool = False
+
+     def __call__(self, epoch: int, val_loss: float) -> bool:
+         improved = val_loss < (self.best - self.delta)
+         if improved:
+             self.best = val_loss
+             self.wait_count = 0
+         else:
+             self.wait_count += 1
+             if self.wait_count >= self.patience:
+                 self.early_stop = True
+                 print(f"Early stopping triggered at epoch {epoch}.")
+         return self.early_stop
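The nan* helpers work around the missing NaN-aware reductions tracked in pytorch/pytorch#61474 by masking NaNs before reducing; nanstd is a population-style standard deviation built on nanvar. A quick sketch of their behavior:

import torch

from tirex.models.classification.utils import nanmax, nanmin, nanstd

x = torch.tensor([[1.0, float("nan"), 3.0],
                  [4.0, 5.0, 6.0]])

print(nanmax(x, dim=-1))  # tensor([3., 6.])  -- NaNs are ignored
print(nanmin(x, dim=-1))  # tensor([1., 4.])
print(nanstd(x, dim=-1))  # tensor([1.0000, 0.8165])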
tirex/models/tirex.py CHANGED
@@ -58,6 +58,25 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
      def register_name(cls):
          return "TiRex"

+     def _adjust_context_length(self, context: torch.Tensor, min_context: int, max_context: int):
+         pad_len = 0
+         if context.shape[-1] > max_context:
+             context = context[..., -max_context:]
+         if context.shape[-1] < min_context:
+             pad_len = min_context - context.shape[-1]
+             pad = torch.full(
+                 (context.shape[0], pad_len),
+                 fill_value=torch.nan,
+                 device=context.device,
+                 dtype=context.dtype,
+             )
+             context = torch.concat((pad, context), dim=1)
+         return context, pad_len
+
+     # ===============================
+     # Forecasting Functions
+     # ===============================
+
      @torch.inference_mode()
      def _forecast_quantiles(
          self,
@@ -106,18 +125,7 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):

      def _forecast_single_step(self, context: torch.Tensor, new_patch_count: int = 1) -> torch.Tensor:
          max_context, min_context = self.config.train_ctx_len, self.config.train_ctx_len
-
-         if context.shape[-1] > max_context:
-             context = context[..., -max_context:]
-         if context.shape[-1] < min_context:
-             pad = torch.full(
-                 (context.shape[0], min_context - context.shape[-1]),
-                 fill_value=torch.nan,
-                 device=context.device,
-                 dtype=context.dtype,
-             )
-             context = torch.concat((pad, context), dim=1)
-
+         context, _ = self._adjust_context_length(context, min_context, max_context)
          input_token, tokenizer_state = self.tokenizer.input_transform(context)
          prediction = self._forward_model_tokenized(input_token=input_token, new_patch_count=new_patch_count)
          predicted_token = prediction[:, :, -new_patch_count:, :].to(input_token)  # predicted token
@@ -161,16 +169,58 @@ class TiRexZero(nn.Module, PretrainedModel, ForecastModel):
          # quantile_preds: [batch_size, num_quantiles, num_token, output_patch_size]
          return quantile_preds

-     def _forward_model(self, input: torch.Tensor) -> torch.Tensor:
+     def _forward_model(self, input: torch.Tensor, return_all_hidden: bool = False) -> torch.Tensor:
          hidden_states = self.input_patch_embedding(input)
+         all_hidden_states = []

      for block in self.blocks:
              hidden_states = block(hidden_states)
-
+             if return_all_hidden:
+                 all_hidden_states.append(hidden_states)
          hidden_states = self.out_norm(hidden_states)

+         if return_all_hidden:
+             return self.output_patch_embedding(hidden_states), torch.stack(all_hidden_states, dim=-2)
+
          return self.output_patch_embedding(hidden_states)

+     # ===============================
+     # Context Embedding Functions
+     # ===============================
+     @torch.inference_mode()
+     def _embed_context(
+         self,
+         context: torch.Tensor,
+         max_context: int | None = None,
+     ) -> torch.Tensor:
+         input_embeds, padded_token = self._prepare_context_for_embedding(context, max_context)
+         _, hidden_states = self._forward_model(input_embeds, return_all_hidden=True)
+         # Shape: [batch_size, num_tokens, num_layers, hidden_dim]
+         return hidden_states[:, padded_token:, :, :]
+
+     def _prepare_context_for_embedding(
+         self, context: torch.Tensor, max_context: int | None
+     ) -> tuple[torch.Tensor, int]:
+         max_context = self.config.train_ctx_len if max_context is None else max_context
+         min_context = max(self.config.train_ctx_len, max_context)
+
+         device = self.input_patch_embedding.hidden_layer.weight.device
+         context = context.to(
+             device=device,
+             dtype=torch.float32,
+         )
+
+         context, pad_len = self._adjust_context_length(context, min_context, max_context)
+
+         padded_token = pad_len // self.tokenizer.patch_size
+         input_token, _ = self.tokenizer.input_transform(context)
+
+         input_mask = torch.isnan(input_token).logical_not().to(input_token.dtype)
+         input_token = torch.nan_to_num(input_token, nan=self.config.nan_mask_value)
+         input_embeds = torch.cat((input_token, input_mask), dim=2)
+
+         return input_embeds, padded_token
+
      def on_load_checkpoint(self, checkpoint: dict) -> None:
          # rename keys of state_dict, because the block_stack was moved directly into the tirex model
          checkpoint["state_dict"] = {k.replace("block_stack.", ""): v for k, v in checkpoint["state_dict"].items()}
tirex_mirror-2025.11.29.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tirex-mirror
- Version: 2025.11.26
+ Version: 2025.11.29
  Summary: Unofficial mirror of NX-AI/tirex for packaging
  Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
  License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -84,6 +84,8 @@ Requires-Dist: datasets; extra == "hfdataset"
  Provides-Extra: test
  Requires-Dist: fev>=0.6.0; extra == "test"
  Requires-Dist: pytest; extra == "test"
+ Provides-Extra: classification
+ Requires-Dist: scikit-learn; extra == "classification"
  Provides-Extra: all
  Requires-Dist: xlstm; extra == "all"
  Requires-Dist: ninja; extra == "all"
@@ -95,6 +97,7 @@ Requires-Dist: gluonts; extra == "all"
  Requires-Dist: datasets; extra == "all"
  Requires-Dist: pytest; extra == "all"
  Requires-Dist: fev>=0.6.0; extra == "all"
+ Requires-Dist: scikit-learn; extra == "all"
  Dynamic: license-file

  # tirex-mirror
tirex_mirror-2025.11.29.dist-info/RECORD CHANGED
@@ -8,14 +8,20 @@ tirex/api_adapter/hf_data.py,sha256=TRyys2xKIGZS0Yhq2Eb61lWCMg5CWWn1yRlLIN1mU7o,
  tirex/api_adapter/standard_adapter.py,sha256=vdlxNs8mTUtPgK_5WMqYqNdMj8W44igqWsAgtggt_xk,2809
  tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
  tirex/models/patcher.py,sha256=8T4c3PZnOAsEpahhrjtt7S7405WUjN6g3cV33E55PD4,1911
- tirex/models/tirex.py,sha256=URt-MClXu0zdUHACQ96Zu3Ytdb52vbeG_SXnj5C4tI8,7522
+ tirex/models/tirex.py,sha256=wtdjrdE1TuoueIvhlKf-deLH3oKyQuGmd2k15My7SWA,9710
+ tirex/models/classification/__init__.py,sha256=qVn84uBosWaHm9wr0FoYAXNvmajyyB3_OmpeHNzDH4g,194
+ tirex/models/classification/embedding.py,sha256=TJlchUaBhz8Pf1mLpTdDVqmkGVOiaCI55sSoqk2tSXE,5259
+ tirex/models/classification/linear_classifier.py,sha256=yE2sekKw2LCLGiITTsTktOJw28RjjOWgsJE2PakAjN8,11142
+ tirex/models/classification/rf_classifier.py,sha256=wzMiF5TELGl3V-i94MPOEpNn2yZgrCdPfuzpcl8A18M,5816
+ tirex/models/classification/trainer.py,sha256=JzM1XtRoRI3fn6Sbu7V-9IuiKVy454O73uNrMNgCREs,5759
+ tirex/models/classification/utils.py,sha256=db9056u6uIVhm0qmHDoOo3K5f-ZiuDytDLcLOg-zFb0,2599
  tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
  tirex/models/slstm/cell.py,sha256=Otyil_AjpJbUckkINWGHxlqP14J5epm_J_zdWPzvD2g,7290
  tirex/models/slstm/layer.py,sha256=hrDydQJIAHf5W0A0Rt0hXG4yKXrOSY-HPL0UbigR6Q8,2867
- tirex_mirror-2025.11.26.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
- tirex_mirror-2025.11.26.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
- tirex_mirror-2025.11.26.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
- tirex_mirror-2025.11.26.dist-info/METADATA,sha256=t-Qg4PeJHaY2ZGN6CFBSiwz8dig4Ts1kT_qiLlk4xZE,11494
- tirex_mirror-2025.11.26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- tirex_mirror-2025.11.26.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
- tirex_mirror-2025.11.26.dist-info/RECORD,,
+ tirex_mirror-2025.11.29.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
+ tirex_mirror-2025.11.29.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
+ tirex_mirror-2025.11.29.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
+ tirex_mirror-2025.11.29.dist-info/METADATA,sha256=6MeKMufnn5yIl1BKxzQKW22jeR7ZU1U9BEuFd5fBJAg,11624
+ tirex_mirror-2025.11.29.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ tirex_mirror-2025.11.29.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
+ tirex_mirror-2025.11.29.dist-info/RECORD,,