tirex-mirror 2025.11.26__tar.gz → 2025.11.28__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/PKG-INFO +4 -1
  2. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/pyproject.toml +3 -2
  3. tirex_mirror-2025.11.28/src/tirex/models/classification/__init__.py +8 -0
  4. tirex_mirror-2025.11.28/src/tirex/models/classification/embedding.py +125 -0
  5. tirex_mirror-2025.11.28/src/tirex/models/classification/linear_classifier.py +274 -0
  6. tirex_mirror-2025.11.28/src/tirex/models/classification/rf_classifier.py +155 -0
  7. tirex_mirror-2025.11.28/src/tirex/models/classification/trainer.py +171 -0
  8. tirex_mirror-2025.11.28/src/tirex/models/classification/utils.py +81 -0
  9. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/models/tirex.py +64 -14
  10. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex_mirror.egg-info/PKG-INFO +4 -1
  11. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex_mirror.egg-info/SOURCES.txt +9 -0
  12. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex_mirror.egg-info/requires.txt +4 -0
  13. tirex_mirror-2025.11.28/tests/test_embedding.py +76 -0
  14. tirex_mirror-2025.11.28/tests/test_linear_classifier.py +170 -0
  15. tirex_mirror-2025.11.28/tests/test_rf_classifier.py +123 -0
  16. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/LICENSE +0 -0
  17. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/LICENSE_MIRROR.txt +0 -0
  18. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/MANIFEST.in +0 -0
  19. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/NOTICE.txt +0 -0
  20. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/README.md +0 -0
  21. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/setup.cfg +0 -0
  22. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/__init__.py +0 -0
  23. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/api_adapter/__init__.py +0 -0
  24. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/api_adapter/forecast.py +0 -0
  25. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/api_adapter/gluon.py +0 -0
  26. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/api_adapter/hf_data.py +0 -0
  27. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/api_adapter/standard_adapter.py +0 -0
  28. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/base.py +0 -0
  29. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/models/__init__.py +0 -0
  30. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/models/patcher.py +0 -0
  31. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/models/slstm/block.py +0 -0
  32. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/models/slstm/cell.py +0 -0
  33. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/models/slstm/layer.py +0 -0
  34. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex/util.py +0 -0
  35. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex_mirror.egg-info/dependency_links.txt +0 -0
  36. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/src/tirex_mirror.egg-info/top_level.txt +0 -0
  37. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_chronos_zs.py +0 -0
  38. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_compile.py +0 -0
  39. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_forecast.py +0 -0
  40. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_forecast_adapter.py +0 -0
  41. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_load_model.py +0 -0
  42. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_patcher.py +0 -0
  43. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_slstm_torch_vs_cuda.py +0 -0
  44. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_standard_adapter.py +0 -0
  45. {tirex_mirror-2025.11.26 → tirex_mirror-2025.11.28}/tests/test_util_freq.py +0 -0
--- tirex_mirror-2025.11.26/PKG-INFO
+++ tirex_mirror-2025.11.28/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.11.26
+Version: 2025.11.28
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT
@@ -84,6 +84,8 @@ Requires-Dist: datasets; extra == "hfdataset"
 Provides-Extra: test
 Requires-Dist: fev>=0.6.0; extra == "test"
 Requires-Dist: pytest; extra == "test"
+Provides-Extra: classification
+Requires-Dist: scikit-learn; extra == "classification"
 Provides-Extra: all
 Requires-Dist: xlstm; extra == "all"
 Requires-Dist: ninja; extra == "all"
@@ -95,6 +97,7 @@ Requires-Dist: gluonts; extra == "all"
 Requires-Dist: datasets; extra == "all"
 Requires-Dist: pytest; extra == "all"
 Requires-Dist: fev>=0.6.0; extra == "all"
+Requires-Dist: scikit-learn; extra == "all"
 Dynamic: license-file
 
 # tirex-mirror
--- tirex_mirror-2025.11.26/pyproject.toml
+++ tirex_mirror-2025.11.28/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tirex-mirror"
-version = "2025.11.26"
+version = "2025.11.28"
 description = "Unofficial mirror of NX-AI/tirex for packaging"
 readme = "README.md"
 requires-python = ">=3.11"
@@ -29,7 +29,8 @@ plotting = [ "matplotlib",]
 gluonts = [ "gluonts", "pandas",]
 hfdataset = [ "datasets",]
 test = [ "fev>=0.6.0", "pytest",]
-all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "pandas", "python-dotenv", "gluonts", "datasets", "pytest", "fev>=0.6.0",]
+classification = [ "scikit-learn",]
+all = [ "xlstm", "ninja", "ipykernel", "matplotlib", "pandas", "python-dotenv", "gluonts", "datasets", "pytest", "fev>=0.6.0", "scikit-learn",]
 
 [tool.docformatter]
 diff = false
--- /dev/null
+++ tirex_mirror-2025.11.28/src/tirex/models/classification/__init__.py
@@ -0,0 +1,8 @@
+# classification/__init__.py
+from .linear_classifier import TirexClassifierTorch
+from .rf_classifier import TirexRFClassifier
+
+__all__ = [
+    "TirexClassifierTorch",
+    "TirexRFClassifier",
+]
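
The new subpackage sits behind the `classification` extra added in pyproject.toml above, since rf_classifier.py imports scikit-learn at module load. A minimal import sketch (the guard and error message are illustrative, not part of the package):

    # Assumes installation via the new extra: pip install "tirex-mirror[classification]"
    try:
        from tirex.models.classification import TirexClassifierTorch, TirexRFClassifier
    except ImportError as err:  # scikit-learn missing -> subpackage import fails
        raise SystemExit("install the 'classification' extra first") from err
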
--- /dev/null
+++ tirex_mirror-2025.11.28/src/tirex/models/classification/embedding.py
@@ -0,0 +1,125 @@
+import inspect
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from tirex import load_model
+
+from .utils import nanmax, nanmin, nanstd
+
+
+class TiRexEmbedding(nn.Module):
+    def __init__(self, device: str | None = None, data_augmentation: bool = False, batch_size: int = 512) -> None:
+        super().__init__()
+        self.data_augmentation = data_augmentation
+        self.number_of_patches = 8
+        self.batch_size = batch_size
+
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
+
+        self.model = load_model(path="NX-AI/TiRex", device=self.device)
+
+    def _gen_emb_batched(self, data: torch.Tensor) -> torch.Tensor:
+        batches = list(torch.split(data, self.batch_size))
+        embedding_list = []
+        for batch in batches:
+            embedding = self.model._embed_context(batch)
+            embedding_list.append(embedding.cpu())
+        return torch.cat(embedding_list, dim=0)
+
+    def _calculate_n_patches(self, data: torch.Tensor) -> int:
+        _, _, n_steps = data.shape
+        n_patches = -(-n_steps // self.model.config.input_patch_size)  # ceiling division
+        return n_patches
+
+    def forward(self, data: torch.Tensor) -> torch.Tensor:
+        n_patches = self._calculate_n_patches(data)
+
+        embedding = torch.stack(
+            [self._gen_emb_batched(var_slice) for var_slice in torch.unbind(data, dim=1)], dim=1
+        )  # Stack in case of multivar
+        embedding = self.process_embedding(embedding, n_patches)
+
+        if self.data_augmentation:
+            # Difference embedding (empty prepend: diff output is one step shorter than data)
+            diff_data = torch.diff(data, dim=-1, prepend=data[..., :0])
+            n_patches = self._calculate_n_patches(diff_data)
+
+            diff_embedding = torch.stack(
+                [self._gen_emb_batched(var_slice) for var_slice in torch.unbind(diff_data, dim=1)], dim=1
+            )
+            diff_embedding = self.process_embedding(diff_embedding, n_patches)
+            embedding = torch.cat((diff_embedding, embedding), dim=-1)
+
+            # Stats embedding
+            stat_features = self._generate_stats_features(data)
+            normalized_stats = self._normalize_stats(stat_features)
+            normalized_stats = normalized_stats.to(embedding.device)
+
+            # Concat all together
+            embedding = torch.cat((embedding, normalized_stats), dim=-1)
+
+        return embedding
+
+    def process_embedding(self, embedding: torch.Tensor, n_patches: int) -> torch.Tensor:
+        # embedding shape: (bs, var_dim, n_patches, n_layer, emb_dim)
+        embedding = embedding[:, :, -n_patches:, :, :]
+        embedding = torch.mean(embedding, dim=2)  # sequence
+        embedding = torch.nn.functional.normalize(embedding, p=2, dim=-1)
+        embedding = torch.transpose(embedding, 1, -2).flatten(start_dim=-2)  # var
+        embedding = torch.transpose(embedding, 1, -2).flatten(start_dim=-2)  # layer
+        embedding = F.layer_norm(embedding, (embedding.shape[-1],))
+        return embedding
+
+    def _normalize_stats(self, stat_features: torch.Tensor) -> torch.Tensor:
+        dataset_mean = torch.nanmean(stat_features, dim=0, keepdim=True)
+        dataset_std = nanstd(stat_features, dim=0, keepdim=True)
+        stat_features = (stat_features - dataset_mean) / (dataset_std + 1e-8)
+        stat_features = torch.nan_to_num(stat_features, nan=0.0)
+
+        stat_features = (stat_features - stat_features.nanmean(dim=-1, keepdim=True)) / (
+            stat_features.std(dim=-1, keepdim=True) + 1e-8
+        )
+        return stat_features
+
+    def _generate_stats_features(self, data: torch.Tensor) -> torch.Tensor:
+        bs, variates, n_steps = data.shape
+
+        patch_size = max(1, n_steps // self.number_of_patches)
+        n_full_patches = n_steps // patch_size
+        n_remain = n_steps % patch_size
+
+        # [batch, variates, n_patches, patch_size]
+        patches = data[..., : n_full_patches * patch_size].unfold(-1, patch_size, patch_size)
+
+        # Stats for full patches
+        patch_means = torch.nanmean(patches, dim=-1)
+        patch_stds = nanstd(patches, dim=-1)
+        patch_maxes = nanmax(patches, dim=-1)
+        patch_mins = nanmin(patches, dim=-1)
+
+        stats = [patch_means, patch_stds, patch_maxes, patch_mins]
+
+        # Handle last smaller patch if needed
+        if n_remain > 0:
+            self._handle_remaining_patch(data, stats, n_full_patches * patch_size)
+
+        stats = torch.stack(stats, dim=-1)  # [batch, variates, n_patches(+1), 4]
+        return stats.flatten(start_dim=1)  # [batch, variates * n_patches * 4]
+
+    def _handle_remaining_patch(self, data: torch.Tensor, stats: list[torch.Tensor], full_patch_length: int) -> None:
+        last_patch = data[..., full_patch_length:]
+
+        mean_last = last_patch.mean(dim=-1, keepdim=True)
+        std_last = last_patch.std(dim=-1, keepdim=True)
+        max_last = last_patch.max(dim=-1, keepdim=True).values  # .max/.min with dim return (values, indices)
+        min_last = last_patch.min(dim=-1, keepdim=True).values
+
+        stats[0] = torch.cat([stats[0], mean_last], dim=-1)
+        stats[1] = torch.cat([stats[1], std_last], dim=-1)
+        stats[2] = torch.cat([stats[2], max_last], dim=-1)
+        stats[3] = torch.cat([stats[3], min_last], dim=-1)
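
For orientation, a minimal usage sketch of the embedding module above (shapes follow the shape comments in the code; instantiation is assumed to download the NX-AI/TiRex checkpoint on first use):

    import torch
    from tirex.models.classification.embedding import TiRexEmbedding

    # 16 univariate series of length 256 -> one fixed-size vector per series
    emb_model = TiRexEmbedding(device="cpu", data_augmentation=False)
    emb_model.eval()
    with torch.inference_mode():
        x = torch.randn(16, 1, 256)  # (batch, variates, seq_len)
        z = emb_model(x)             # (batch, emb_dim)
    print(z.shape)
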
--- /dev/null
+++ tirex_mirror-2025.11.28/src/tirex/models/classification/linear_classifier.py
@@ -0,0 +1,274 @@
+from dataclasses import asdict
+
+import torch
+
+from .embedding import TiRexEmbedding
+from .trainer import TrainConfig, Trainer
+
+
+class TirexClassifierTorch(torch.nn.Module):
+    """
+    A PyTorch classifier that combines time series embeddings with a linear classification head.
+
+    This model uses a pre-trained TiRex embedding model to generate feature representations from time series
+    data, followed by a linear layer (with optional dropout) for classification. The embedding backbone
+    is frozen during training, and only the classification head is trained.
+
+    Example:
+        >>> import torch
+        >>> from tirex.models.classification import TirexClassifierTorch
+        >>>
+        >>> # Create model with TiRex embeddings
+        >>> model = TirexClassifierTorch(
+        ...     data_augmentation=True,
+        ...     max_epochs=2,
+        ...     lr=1e-4,
+        ...     batch_size=32
+        ... )
+        >>>
+        >>> # Prepare data
+        >>> X_train = torch.randn(100, 1, 128)  # 100 samples, 1 variate, sequence length 128
+        >>> y_train = torch.randint(0, 3, (100,))  # 3 classes
+        >>>
+        >>> # Train the model
+        >>> metrics = model.fit((X_train, y_train))  # doctest: +ELLIPSIS
+        Epoch 1, Train Loss: ...
+        >>> # Make predictions
+        >>> X_test = torch.randn(20, 1, 128)
+        >>> predictions = model.predict(X_test)
+        >>> probabilities = model.predict_proba(X_test)
+    """
+
+    def __init__(
+        self,
+        data_augmentation: bool = False,
+        device: str | None = None,
+        # Training parameters
+        max_epochs: int = 50,
+        lr: float = 1e-4,
+        weight_decay: float = 0.01,
+        batch_size: int = 512,
+        val_split_ratio: float = 0.2,
+        stratify: bool = True,
+        patience: int = 7,
+        delta: float = 0.001,
+        log_every_n_steps: int = 5,
+        seed: int | None = None,
+        class_weights: torch.Tensor | None = None,
+        # Head parameters
+        dropout: float | None = None,
+    ) -> None:
+        """Initializes the embedding-based linear classification model.
+
+        Args:
+            data_augmentation : bool
+                Whether to augment the embeddings with stats and first-order differences of the original data. Default: False
+            device : str | None
+                Device to run the model on. If None, uses CUDA if available, else CPU. Default: None
+            max_epochs : int
+                Maximum number of training epochs. Default: 50
+            lr : float
+                Learning rate for the optimizer. Default: 1e-4
+            weight_decay : float
+                Weight decay coefficient. Default: 0.01
+            batch_size : int
+                Batch size for training and embedding calculations. Default: 512
+            val_split_ratio : float
+                Proportion of training data to use for validation, if validation data are not provided. Default: 0.2
+            stratify : bool
+                Whether to stratify the train/validation split by class labels. Default: True
+            patience : int
+                Number of epochs to wait for improvement before early stopping. Default: 7
+            delta : float
+                Minimum change in validation loss to qualify as an improvement. Default: 0.001
+            log_every_n_steps : int
+                Frequency of logging during training. Default: 5
+            seed : int | None
+                Random seed for reproducibility. If None, no seed is set. Default: None
+            class_weights : torch.Tensor | None
+                Per-class loss weights; must be a tensor whose length equals the number of classes. Default: None
+            dropout : float | None
+                Dropout probability for the classification head. If None, no dropout is used. Default: None
+        """
+
+        super().__init__()
+
+        if device is None:
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        self.device = device
+
+        # Create embedding model
+        self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
+        self.data_augmentation = data_augmentation
+
+        # Head parameters
+        self.dropout = dropout
+        self.head = None
+        self.emb_dim = None
+        self.num_classes = None
+
+        # Train config
+        train_config = TrainConfig(
+            max_epochs=max_epochs,
+            log_every_n_steps=log_every_n_steps,
+            device=self.device,
+            lr=lr,
+            weight_decay=weight_decay,
+            class_weights=class_weights,
+            batch_size=batch_size,
+            val_split_ratio=val_split_ratio,
+            stratify=stratify,
+            patience=patience,
+            delta=delta,
+            seed=seed,
+        )
+        self.trainer = Trainer(self, train_config=train_config)
+
+    def _init_classifier(self, emb_dim: int, num_classes: int, dropout: float | None) -> torch.nn.Module:
+        if dropout:
+            return torch.nn.Sequential(torch.nn.Dropout(p=dropout), torch.nn.Linear(emb_dim, num_classes))
+        else:
+            return torch.nn.Linear(emb_dim, num_classes)
+
+    @torch.inference_mode()
+    def _identify_head_dims(self, x: torch.Tensor, y: torch.Tensor) -> None:
+        self.emb_model.eval()
+        sample_emb = self.emb_model(x[:1])
+        self.emb_dim = sample_emb.shape[-1]
+        self.num_classes = len(torch.unique(y))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the embedding model and classification head.
+
+        Args:
+            x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Logits for each class with shape (batch_size, num_classes).
+        Raises:
+            RuntimeError: If the classification head has not been initialized via fit().
+        """
+        if self.head is None:
+            raise RuntimeError("Head not initialized. Call fit() first to automatically build the head.")
+
+        embedding = self.emb_model(x).to(self.device)
+        return self.head(embedding)
+
+    def fit(
+        self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None = None
+    ) -> dict[str, float]:
+        """Train the classification head on the provided data.
+
+        This method initializes the classification head based on the data dimensions,
+        then trains it on the provided data. The embedding model remains frozen.
+
+        Args:
+            train_data: Tuple of (X_train, y_train) where X_train is the input time series
+                data and y_train are the corresponding class labels.
+            val_data: Optional tuple of (X_val, y_val) for validation. If None and
+                val_split_ratio > 0, validation data will be split from train_data.
+
+        Returns:
+            dict[str, float]: Dictionary containing final training and validation losses.
+        """
+        X_train, y_train = train_data
+
+        self._identify_head_dims(X_train, y_train)
+        self.head = self._init_classifier(self.emb_dim, self.num_classes, self.dropout)
+        self.head = self.head.to(self.trainer.device)
+
+        return self.trainer.fit(train_data, val_data=val_data)
+
+    @torch.inference_mode()
+    def predict(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict class labels for input time series data.
+
+        Args:
+            x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Predicted class labels with shape (batch_size,).
+        """
+        self.eval()
+        x = x.to(self.device)
+        logits = self.forward(x)
+        return torch.argmax(logits, dim=1)
+
+    @torch.inference_mode()
+    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict class probabilities for input time series data.
+
+        Args:
+            x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Class probabilities with shape (batch_size, num_classes).
+        """
+        self.eval()
+        x = x.to(self.device)
+        logits = self.forward(x)
+        return torch.softmax(logits, dim=1)
+
+    def save_model(self, path: str) -> None:
+        """Save the trained classification head.
+
+        This function saves the trained classification head weights (.pt format), the embedding
+        configuration, and the model dimensions. The embedding model itself is not
+        saved, as it uses a pre-trained backbone that can be reloaded.
+
+        Args:
+            path: File path where the model should be saved (e.g., 'model.pt').
+        """
+        train_config_dict = asdict(self.trainer.train_config)
+        torch.save(
+            {
+                "head_state_dict": self.head.state_dict(),  # only the head is saved; the embedding is frozen
+                "data_augmentation": self.data_augmentation,
+                "emb_dim": self.emb_dim,
+                "num_classes": self.num_classes,
+                "dropout": self.dropout,
+                "train_config": train_config_dict,
+            },
+            path,
+        )
+
+    @classmethod
+    def load_model(cls, path: str) -> "TirexClassifierTorch":
+        """Load a saved model from file.
+
+        This reconstructs the model architecture and loads the trained weights from
+        a checkpoint file created by save_model().
+
+        Args:
+            path: File path to the saved model checkpoint.
+        Returns:
+            TirexClassifierTorch: The loaded model with trained weights, ready for inference.
+        """
+        checkpoint = torch.load(path)
+
+        # Extract train_config if available, otherwise use defaults
+        train_config_dict = checkpoint.get("train_config", {})
+
+        model = cls(
+            data_augmentation=checkpoint["data_augmentation"],
+            dropout=checkpoint["dropout"],
+            max_epochs=train_config_dict.get("max_epochs", 50),
+            lr=train_config_dict.get("lr", 1e-4),
+            weight_decay=train_config_dict.get("weight_decay", 0.01),
+            batch_size=train_config_dict.get("batch_size", 512),
+            val_split_ratio=train_config_dict.get("val_split_ratio", 0.2),
+            stratify=train_config_dict.get("stratify", True),
+            patience=train_config_dict.get("patience", 7),
+            delta=train_config_dict.get("delta", 0.001),
+            log_every_n_steps=train_config_dict.get("log_every_n_steps", 5),
+            seed=train_config_dict.get("seed", None),
+            class_weights=train_config_dict.get("class_weights", None),
+        )
+
+        # Initialize head with dimensions
+        model.emb_dim = checkpoint["emb_dim"]
+        model.num_classes = checkpoint["num_classes"]
+        model.head = model._init_classifier(model.emb_dim, model.num_classes, model.dropout)
+
+        # Load the trained weights
+        model.head.load_state_dict(checkpoint["head_state_dict"])
+        model.to(model.device)
+
+        return model
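
A round-trip sketch of the save_model()/load_model() pair above (the path and data shapes are illustrative, and one epoch is only enough to exercise the API, not to train a useful head):

    import torch
    from tirex.models.classification import TirexClassifierTorch

    model = TirexClassifierTorch(max_epochs=1, batch_size=32)
    X, y = torch.randn(64, 1, 128), torch.randint(0, 3, (64,))
    model.fit((X, y))                       # builds the head from data dims, then trains it
    model.save_model("head_checkpoint.pt")  # persists head weights + config only

    restored = TirexClassifierTorch.load_model("head_checkpoint.pt")
    # with identical head weights and a frozen backbone, predictions should agree
    print(torch.equal(restored.predict(X), model.predict(X)))
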
--- /dev/null
+++ tirex_mirror-2025.11.28/src/tirex/models/classification/rf_classifier.py
@@ -0,0 +1,155 @@
+import joblib
+import numpy as np
+import torch
+from sklearn.ensemble import RandomForestClassifier
+
+from .embedding import TiRexEmbedding
+
+
+class TirexRFClassifier:
+    """
+    A Random Forest classifier that uses time series embeddings as features.
+
+    This classifier combines a pre-trained embedding model for feature extraction with a scikit-learn
+    Random Forest classifier. The embedding model generates fixed-size feature vectors from variable-length
+    time series, which are then used to train the Random Forest.
+
+    Example:
+        >>> import torch
+        >>> from tirex.models.classification import TirexRFClassifier
+        >>>
+        >>> # Create model with custom Random Forest parameters
+        >>> model = TirexRFClassifier(
+        ...     data_augmentation=True,
+        ...     n_estimators=50,
+        ...     max_depth=10,
+        ...     random_state=42
+        ... )
+        >>>
+        >>> # Prepare data (labels can be NumPy arrays or PyTorch tensors)
+        >>> X_train = torch.randn(100, 1, 128)  # 100 samples, 1 variate, sequence length 128
+        >>> y_train = torch.randint(0, 3, (100,))  # 3 classes
+        >>>
+        >>> # Train the model
+        >>> model.fit((X_train, y_train))
+        >>>
+        >>> # Make predictions
+        >>> X_test = torch.randn(20, 1, 128)
+        >>> predictions = model.predict(X_test)
+        >>> probabilities = model.predict_proba(X_test)
+    """
+
+    def __init__(
+        self,
+        data_augmentation: bool = False,
+        device: str | None = None,
+        batch_size: int = 512,
+        # Random Forest parameters
+        **rf_kwargs,
+    ) -> None:
+        """Initializes the embedding-based Random Forest classification model.
+
+        Args:
+            data_augmentation : bool
+                Whether to augment the embeddings with stats and first-order differences of the original data. Default: False
+            device : str | None
+                Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
+            batch_size : int
+                Batch size for embedding calculations. Default: 512
+            **rf_kwargs
+                Additional keyword arguments to pass to sklearn's RandomForestClassifier.
+                Common options include n_estimators, max_depth, min_samples_split, random_state, etc.
+        """
+
+        # Set device
+        if device is None:
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        self.device = device
+
+        self.emb_model = TiRexEmbedding(device=self.device, data_augmentation=data_augmentation, batch_size=batch_size)
+        self.data_augmentation = data_augmentation
+
+        self.head = RandomForestClassifier(**rf_kwargs)
+
+    @torch.inference_mode()
+    def fit(self, train_data: tuple[torch.Tensor, torch.Tensor]) -> None:
+        """Train the Random Forest classifier on embedded time series data.
+
+        This method generates embeddings for the training data using the embedding
+        model, then trains the Random Forest on these embeddings.
+
+        Args:
+            train_data: Tuple of (X_train, y_train) where X_train is the input time
+                series data (torch.Tensor) and y_train contains the class labels
+                (np.ndarray or torch.Tensor).
+        """
+        X_train, y_train = train_data
+
+        if isinstance(y_train, torch.Tensor):
+            y_train = y_train.detach().cpu().numpy()
+
+        embeddings = self.emb_model(X_train).cpu().numpy()
+        self.head.fit(embeddings, y_train)
+
+    @torch.inference_mode()
+    def predict(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict class labels for input time series data.
+
+        Args:
+            x: Input time series data as torch.Tensor with shape
+                (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Predicted class labels with shape (batch_size,).
+        """
+
+        embeddings = self.emb_model(x).cpu().numpy()
+        return torch.from_numpy(self.head.predict(embeddings)).long()
+
+    @torch.inference_mode()
+    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict class probabilities for input time series data.
+
+        Args:
+            x: Input time series data as torch.Tensor with shape
+                (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Class probabilities with shape (batch_size, num_classes).
+        """
+        embeddings = self.emb_model(x).cpu().numpy()
+        return torch.from_numpy(self.head.predict_proba(embeddings))
+
+    def save_model(self, path: str) -> None:
+        """Save the trained Random Forest head and the embedding configuration in joblib format.
+
+        Args:
+            path: File path where the model should be saved (e.g., 'model.joblib').
+        """
+        payload = {
+            "data_augmentation": self.data_augmentation,
+            "head": self.head,
+        }
+        joblib.dump(payload, path)
+
+    @classmethod
+    def load_model(cls, path: str) -> "TirexRFClassifier":
+        """Load a saved model from file.
+
+        This reconstructs the model with the embedding configuration and loads
+        the trained Random Forest classifier from a checkpoint file created by save_model().
+
+        Args:
+            path: File path to the saved model checkpoint.
+        Returns:
+            TirexRFClassifier: The loaded model with trained Random Forest, ready for inference.
+        """
+        checkpoint = joblib.load(path)
+
+        # Create new instance with saved configuration
+        model = cls(
+            data_augmentation=checkpoint["data_augmentation"],
+        )
+
+        # Load the trained Random Forest head
+        model.head = checkpoint["head"]
+
+        return model
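
A small end-to-end sketch of the Random Forest variant above (filenames, sizes, and forest settings are illustrative; instantiation is assumed to download the NX-AI/TiRex weights on first use):

    import torch
    from tirex.models.classification import TirexRFClassifier

    model = TirexRFClassifier(n_estimators=100, random_state=0)
    X, y = torch.randn(64, 1, 128), torch.randint(0, 2, (64,))
    model.fit((X, y))                   # embeds the series once, then fits the forest
    model.save_model("rf_head.joblib")  # joblib payload: fitted forest + embedding flag

    restored = TirexRFClassifier.load_model("rf_head.joblib")
    print(restored.predict_proba(torch.randn(4, 1, 128)).shape)  # (4, 2)
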