tirex-mirror 2025.12.2__py3-none-any.whl → 2025.12.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tirex/models/base/base_classifier.py +36 -0
- tirex/models/base/base_regressor.py +23 -0
- tirex/models/{classification/heads/base_classifier.py → base/base_tirex.py} +42 -26
- tirex/models/classification/__init__.py +4 -4
- tirex/models/classification/{heads/gbm_classifier.py → gbm_classifier.py} +4 -26
- tirex/models/classification/{heads/linear_classifier.py → linear_classifier.py} +11 -12
- tirex/models/classification/{heads/rf_classifier.py → rf_classifier.py} +3 -3
- tirex/models/{classification/embedding.py → embedding.py} +2 -3
- tirex/models/regression/__init__.py +8 -0
- tirex/models/regression/gbm_regressor.py +181 -0
- tirex/models/regression/linear_regressor.py +250 -0
- tirex/models/regression/rf_regressor.py +130 -0
- tirex/models/{classification/trainer.py → trainer.py} +29 -10
- tirex/util.py +82 -0
- {tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/METADATA +6 -3
- tirex_mirror-2025.12.16.dist-info/RECORD +34 -0
- tirex/models/classification/utils.py +0 -84
- tirex_mirror-2025.12.2.dist-info/RECORD +0 -29
- {tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/WHEEL +0 -0
- {tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/licenses/LICENSE +0 -0
- {tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/licenses/LICENSE_MIRROR.txt +0 -0
- {tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/licenses/NOTICE.txt +0 -0
- {tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/top_level.txt +0 -0

tirex/models/base/base_classifier.py
ADDED
@@ -0,0 +1,36 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+
+from .base_tirex import BaseTirexEmbeddingModel
+
+
+class BaseTirexClassifier(BaseTirexEmbeddingModel):
+    """Abstract base class for TiRex classification models."""
+
+    @torch.inference_mode()
+    def predict(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict class labels for input time series data.
+
+        Args:
+            x: Input time series data as torch.Tensor with shape
+                (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Predicted class labels with shape (batch_size,).
+        """
+        emb = self._compute_embeddings(x)
+        return torch.from_numpy(self.head.predict(emb)).long()
+
+    @torch.inference_mode()
+    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict class probabilities for input time series data.
+
+        Args:
+            x: Input time series data as torch.Tensor with shape
+                (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Class probabilities with shape (batch_size, num_classes).
+        """
+        emb = self._compute_embeddings(x)
+        return torch.from_numpy(self.head.predict_proba(emb))
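
The two methods above assume a scikit-learn-style head whose predict/predict_proba take NumPy embeddings and return NumPy arrays; the base class only converts the result back to torch tensors. A minimal sketch of that conversion path, using sklearn's LogisticRegression as a hypothetical stand-in head on random embeddings (illustrative only, not part of the package):

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

emb = np.random.randn(8, 16)                            # stand-in for _compute_embeddings(x)
head = LogisticRegression().fit(emb, np.random.randint(0, 2, size=8))

labels = torch.from_numpy(head.predict(emb)).long()     # what predict() returns: shape (8,)
proba = torch.from_numpy(head.predict_proba(emb))       # what predict_proba() returns: shape (8, 2)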

tirex/models/base/base_regressor.py
ADDED
@@ -0,0 +1,23 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import torch
+
+from .base_tirex import BaseTirexEmbeddingModel
+
+
+class BaseTirexRegressor(BaseTirexEmbeddingModel):
+    """Abstract base class for TiRex regression models."""
+
+    @torch.inference_mode()
+    def predict(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict values for input time series data.
+
+        Args:
+            x: Input time series data as torch.Tensor with shape
+                (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Predicted values with shape (batch_size,).
+        """
+        emb = self._compute_embeddings(x)
+        return torch.from_numpy(self.head.predict(emb)).float()

tirex/models/{classification/heads/base_classifier.py → base/base_tirex.py}
RENAMED
@@ -3,15 +3,18 @@
 
 from abc import ABC, abstractmethod
 
+import numpy as np
 import torch
 
+from tirex.util import train_val_split
+
 from ..embedding import TiRexEmbedding
 
 
-class BaseTirexClassifier(ABC):
-    """Abstract base class for TiRex
+class BaseTirexEmbeddingModel(ABC):
+    """Abstract base class for TiRex models.
 
-    This base class provides common functionality for all TiRex
+    This base class provides common functionality for all TiRex classifier and regression models,
     including embedding model initialization and a consistent interface.
 
     """

@@ -19,7 +22,10 @@ class BaseTirexClassifier(ABC):
     def __init__(
         self, data_augmentation: bool = False, device: str | None = None, compile: bool = False, batch_size: int = 512
     ) -> None:
-        """Initializes a TiRex
+        """Initializes a base TiRex model.
+
+        This base class initializes the embedding model and common configuration
+        used by both classification and regression models.
 
         Args:
             data_augmentation : bool

@@ -49,45 +55,55 @@ class BaseTirexClassifier(ABC):
 
     @abstractmethod
    def fit(self, train_data: tuple[torch.Tensor, torch.Tensor]) -> None:
+        """Abstract method for model training"""
         pass
 
-
-
-        """Predict class labels for input time series data.
+    def _compute_embeddings(self, x: torch.Tensor) -> np.ndarray:
+        """Compute embeddings for input time series data.
 
         Args:
             x: Input time series data as torch.Tensor with shape
                 (batch_size, num_variates, seq_len).
-        Returns:
-            torch.Tensor: Predicted class labels with shape (batch_size,).
-        """
-        self.emb_model.eval()
-        x = x.to(self.device)
-        embeddings = self.emb_model(x).cpu().numpy()
-        return torch.from_numpy(self.head.predict(embeddings)).long()
-
-    @torch.inference_mode()
-    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
-        """Predict class probabilities for input time series data.
 
-        Args:
-            x: Input time series data as torch.Tensor with shape
-                (batch_size, num_variates, seq_len).
         Returns:
-
+            np.ndarray: Embeddings with shape (batch_size, embedding_dim).
         """
         self.emb_model.eval()
         x = x.to(self.device)
-
-
+        return self.emb_model(x).cpu().numpy()
+
+    def _create_train_val_datasets(
+        self,
+        train_data: tuple[torch.Tensor, torch.Tensor],
+        val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
+        val_split_ratio: float = 0.2,
+        stratify: bool = False,
+        seed: int | None = None,
+    ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+        if val_data is None:
+            train_data, val_data = train_val_split(
+                train_data=train_data, val_split_ratio=val_split_ratio, stratify=stratify, seed=seed
+            )
+        return train_data, val_data
 
     @abstractmethod
     def save_model(self, path: str) -> None:
-        """
+        """Save model to file.
+
+        Args:
+            path: File path where the model should be saved.
+        """
         pass
 
     @classmethod
     @abstractmethod
     def load_model(cls, path: str):
-        """
+        """Load model from file.
+
+        Args:
+            path: File path to the saved model checkpoint.
+
+        Returns:
+            Instance of the model class with loaded weights and configuration.
+        """
         pass
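
The new _create_train_val_datasets helper is a thin pass-through: an explicitly supplied val_data tuple is returned unchanged, otherwise train_val_split (added to tirex/util.py further down in this diff) carves the validation set out of train_data. A short usage sketch of that split, with illustrative shapes and seed:

import torch
from tirex.util import train_val_split  # new home of the helper; see tirex/util.py below

X, y = torch.randn(100, 1, 128), torch.randint(0, 3, (100,))
(X_tr, y_tr), (X_va, y_va) = train_val_split((X, y), val_split_ratio=0.2, stratify=True, seed=42)
assert len(X_va) == 20  # 20% of 100 samples, stratified on y as the classifiers request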

tirex/models/classification/__init__.py
CHANGED
@@ -1,8 +1,8 @@
 # Copyright (c) NXAI GmbH.
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
-from .
-from .
-from .
+from .gbm_classifier import TirexGBMClassifier
+from .linear_classifier import TirexLinearClassifier
+from .rf_classifier import TirexRFClassifier
 
-__all__ = ["
+__all__ = ["TirexLinearClassifier", "TirexRFClassifier", "TirexGBMClassifier"]
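
For downstream code, the classification reshuffle amounts to new module paths plus one rename: the hunk headers below show the linear head was previously exposed as TirexClassifierTorch. A hedged migration sketch, with old paths taken from the rename list at the top of this diff:

# 2025.12.2 layout (heads/ subpackage; the linear model was named TirexClassifierTorch):
# from tirex.models.classification.heads.linear_classifier import TirexClassifierTorch

# 2025.12.16 layout (flat modules, re-exported from the package __init__):
from tirex.models.classification import TirexGBMClassifier, TirexLinearClassifier, TirexRFClassifier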

tirex/models/classification/{heads/gbm_classifier.py → gbm_classifier.py}
RENAMED
@@ -5,8 +5,7 @@ import joblib
 import torch
 from lightgbm import LGBMClassifier, early_stopping
 
-from ..
-from .base_classifier import BaseTirexClassifier
+from ..base.base_classifier import BaseTirexClassifier
 
 
 class TirexGBMClassifier(BaseTirexClassifier):

@@ -118,11 +117,11 @@ class TirexGBMClassifier(BaseTirexClassifier):
             seed=self.random_state,
         )
 
-        self.emb_model.eval()
         X_train = X_train.to(self.device)
         X_val = X_val.to(self.device)
-
-
+
+        embeddings_train = self._compute_embeddings(X_train)
+        embeddings_val = self._compute_embeddings(X_val)
 
         y_train = y_train.detach().cpu().numpy() if isinstance(y_train, torch.Tensor) else y_train
         y_val = y_val.detach().cpu().numpy() if isinstance(y_val, torch.Tensor) else y_val

@@ -187,24 +186,3 @@ class TirexGBMClassifier(BaseTirexClassifier):
         model.random_state = getattr(model.head, "random_state", None)
 
         return model
-
-    def _create_train_val_datasets(
-        self,
-        train_data: tuple[torch.Tensor, torch.Tensor],
-        val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
-        val_split_ratio: float = 0.2,
-        stratify: bool = True,
-        seed: int | None = None,
-    ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
-        X_train, y_train = train_data
-
-        if val_data is None:
-            train_data, val_data = train_val_split(
-                train_data=train_data, val_split_ratio=val_split_ratio, stratify=stratify, seed=seed
-            )
-            X_train, y_train = train_data
-            X_val, y_val = val_data
-        else:
-            X_val, y_val = val_data
-
-        return (X_train, y_train), (X_val, y_val)

tirex/models/classification/{heads/linear_classifier.py → linear_classifier.py}
RENAMED
@@ -5,11 +5,11 @@ from dataclasses import asdict
 
 import torch
 
-from ..
-from
+from ..base.base_classifier import BaseTirexClassifier
+from ..trainer import TrainConfig, Trainer, TrainingMetrics
 
 
-class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
+class TirexLinearClassifier(BaseTirexClassifier, torch.nn.Module):
     """
     A PyTorch classifier that combines time series embeddings with a linear classification head.
 

@@ -19,10 +19,10 @@ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
 
     Example:
         >>> import torch
-        >>> from tirex.models.classification import
+        >>> from tirex.models.classification import TirexLinearClassifier
        >>>
         >>> # Create model with TiRex embeddings
-        >>> model =
+        >>> model = TirexLinearClassifier(
         ...     data_augmentation=True,
         ...     max_epochs=2,
         ...     lr=1e-4,

@@ -72,7 +72,7 @@ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
            compile: bool
                Whether to compile the frozen embedding model. Default: False
            max_epochs : int
-               Maximum number of training epochs. Default:
+               Maximum number of training epochs. Default: 10
            lr : float
                Learning rate for the optimizer. Default: 1e-4
            weight_decay : float

@@ -115,6 +115,7 @@ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
             lr=lr,
             weight_decay=weight_decay,
             class_weights=class_weights,
+            task_type="classification",
             batch_size=batch_size,
             val_split_ratio=val_split_ratio,
             stratify=stratify,

@@ -132,9 +133,7 @@ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
 
     @torch.inference_mode()
     def _identify_head_dims(self, x: torch.Tensor, y: torch.Tensor) -> None:
-        self.
-        sample_emb = self.emb_model(x[:1])
-        self.emb_dim = sample_emb.shape[-1]
+        self.emb_dim = self._compute_embeddings(x[:1]).shape[-1]
         self.num_classes = len(torch.unique(y))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -155,7 +154,7 @@ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
 
     def fit(
         self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None = None
-    ) ->
+    ) -> TrainingMetrics:
         """Train the classification head on the provided data.
 
         This method initializes the classification head based on the data dimensions,

@@ -231,7 +230,7 @@ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
         )
 
     @classmethod
-    def load_model(cls, path: str) -> "
+    def load_model(cls, path: str) -> "TirexLinearClassifier":
         """Load a saved model from file.
 
         This reconstructs the model architecture and loads the trained weights from

@@ -240,7 +239,7 @@ class TirexClassifierTorch(BaseTirexClassifier, torch.nn.Module):
         Args:
             path: File path to the saved model checkpoint.
         Returns:
-
+            TirexLinearClassifier: The loaded model with trained weights, ready for inference.
         """
         checkpoint = torch.load(path)
 

tirex/models/classification/{heads/rf_classifier.py → rf_classifier.py}
RENAMED
@@ -5,7 +5,7 @@ import joblib
 import torch
 from sklearn.ensemble import RandomForestClassifier
 
-from .base_classifier import BaseTirexClassifier
+from ..base.base_classifier import BaseTirexClassifier
 
 
 class TirexRFClassifier(BaseTirexClassifier):

@@ -84,9 +84,9 @@ class TirexRFClassifier(BaseTirexClassifier):
         if isinstance(y_train, torch.Tensor):
             y_train = y_train.detach().cpu().numpy()
 
-        self.emb_model.eval()
         X_train = X_train.to(self.device)
-        embeddings = self.
+        embeddings = self._compute_embeddings(X_train)
+
         self.head.fit(embeddings, y_train)
 
     def save_model(self, path: str) -> None:

tirex/models/{classification/embedding.py → embedding.py}
RENAMED
@@ -6,8 +6,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from tirex import load_model
-
-from .utils import nanmax, nanmin, nanstd
+from tirex.util import nanmax, nanmin, nanstd
 
 
 class TiRexEmbedding(nn.Module):

@@ -20,7 +19,7 @@ class TiRexEmbedding(nn.Module):
         self.batch_size = batch_size
 
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
         self.device = device
         self._compile = compile
 

tirex/models/regression/__init__.py
ADDED
@@ -0,0 +1,8 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+from .gbm_regressor import TirexGBMRegressor
+from .linear_regressor import TirexLinearRegressor
+from .rf_regressor import TirexRFRegressor
+
+__all__ = ["TirexLinearRegressor", "TirexRFRegressor", "TirexGBMRegressor"]

tirex/models/regression/gbm_regressor.py
ADDED
@@ -0,0 +1,181 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import joblib
+import torch
+from lightgbm import LGBMRegressor, early_stopping
+
+from ..base.base_regressor import BaseTirexRegressor
+
+
+class TirexGBMRegressor(BaseTirexRegressor):
+    """
+    A Gradient Boosting regressor that uses time series embeddings as features.
+
+    This regressor combines a pre-trained embedding model for feature extraction with a
+    Gradient Boosting regressor.
+
+    Example:
+        >>> import torch
+        >>> from tirex.models.regression import TirexGBMRegressor
+        >>>
+        >>> # Create model with custom LightGBM parameters
+        >>> model = TirexGBMRegressor(
+        ...     data_augmentation=True,
+        ...     n_estimators=50,
+        ...     random_state=42
+        ... )
+        >>>
+        >>> # Prepare data (can use NumPy arrays or PyTorch tensors)
+        >>> X_train = torch.randn(100, 1, 128)  # 100 samples, 1 number of variates, 128 sequence length
+        >>> y_train = torch.randn(100,)  # target values
+        >>>
+        >>> # Train the model
+        >>> model.fit((X_train, y_train))
+        >>>
+        >>> # Make predictions
+        >>> X_test = torch.randn(20, 1, 128)
+        >>> predictions = model.predict(X_test)
+    """
+
+    def __init__(
+        self,
+        data_augmentation: bool = False,
+        device: str | None = None,
+        compile: bool = False,
+        batch_size: int = 512,
+        early_stopping_rounds: int | None = 10,
+        min_delta: float = 0.0,
+        val_split_ratio: float = 0.2,
+        # LightGBM kwargs
+        **lgbm_kwargs,
+    ) -> None:
+        """Initializes Embedding Based Gradient Boosting Regression model.
+
+        Args:
+            data_augmentation : bool
+                Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
+            device : str | None
+                Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
+            compile: bool
+                Whether to compile the frozen embedding model. Default: False
+            batch_size : int
+                Batch size for embedding calculations. Default: 512
+            early_stopping_rounds: int | None
+                Number of rounds without improvement of all metrics for Early Stopping. Default: 10
+            min_delta: float
+                Minimum improvement in score to keep training. Default 0.0
+            val_split_ratio : float
+                Proportion of training data to use for validation, if validation data are not provided. Default: 0.2
+            **lgbm_kwargs
+                Additional keyword arguments to pass to LightGBM's LGBMRegressor.
+                Common options include n_estimators, max_depth, learning_rate, random_state, etc.
+        """
+        super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
+
+        # Early Stopping callback
+        self.early_stopping_rounds = early_stopping_rounds
+        self.min_delta = min_delta
+
+        # Data split parameters:
+        self.val_split_ratio = val_split_ratio
+
+        # Extract random_state for train_val_split if provided
+        self.random_state = lgbm_kwargs.get("random_state", None)
+
+        self.head = LGBMRegressor(**lgbm_kwargs)
+
+    @torch.inference_mode()
+    def fit(
+        self,
+        train_data: tuple[torch.Tensor, torch.Tensor],
+        val_data: tuple[torch.Tensor, torch.Tensor] | None = None,
+    ) -> None:
+        """Train the LightGBM regressor on embedded time series data.
+
+        This method generates embeddings for the training data using the embedding
+        model, then trains the LightGBM regressor on these embeddings.
+
+        Args:
+            train_data: Tuple of (X_train, y_train) where X_train is the input time
+                series data (torch.Tensor) and y_train is a torch.Tensor
+                of target values.
+            val_data: Optional tuple of (X_val, y_val) for validation where X_val is the input time
+                series data (torch.Tensor) and y_val is a torch.Tensor
+                of target values. If None, validation data will be automatically split from train_data (20% split).
+        """
+
+        (X_train, y_train), (X_val, y_val) = self._create_train_val_datasets(
+            train_data=train_data,
+            val_data=val_data,
+            val_split_ratio=self.val_split_ratio,
+            seed=self.random_state,
+        )
+
+        X_train = X_train.to(self.device)
+        X_val = X_val.to(self.device)
+
+        embeddings_train = self._compute_embeddings(X_train)
+        embeddings_val = self._compute_embeddings(X_val)
+
+        y_train = y_train.detach().cpu().numpy() if isinstance(y_train, torch.Tensor) else y_train
+        y_val = y_val.detach().cpu().numpy() if isinstance(y_val, torch.Tensor) else y_val
+
+        self.head.fit(
+            embeddings_train,
+            y_train,
+            eval_set=[(embeddings_val, y_val)],
+            callbacks=[early_stopping(stopping_rounds=self.early_stopping_rounds, min_delta=self.min_delta)]
+            if self.early_stopping_rounds is not None
+            else None,
+        )
+
+    def save_model(self, path: str) -> None:
+        """This method saves the trained LightGBM regressor head (joblib format) and embedding information.
+
+        Args:
+            path: File path where the model should be saved (e.g., 'model.joblib').
+        """
+
+        payload = {
+            "data_augmentation": self.data_augmentation,
+            "compile": self._compile,
+            "batch_size": self.batch_size,
+            "early_stopping_rounds": self.early_stopping_rounds,
+            "min_delta": self.min_delta,
+            "val_split_ratio": self.val_split_ratio,
+            "head": self.head,
+        }
+        joblib.dump(payload, path)
+
+    @classmethod
+    def load_model(cls, path: str) -> "TirexGBMRegressor":
+        """Load a saved model from file.
+
+        This reconstructs the model with the embedding configuration and loads
+        the trained LightGBM regressor from a checkpoint file created by save_model().
+
+        Args:
+            path: File path to the saved model checkpoint.
+        Returns:
+            TirexGBMRegressor: The loaded model with trained Gradient Boosting regressor, ready for inference.
+        """
+        checkpoint = joblib.load(path)
+
+        # Create new instance with saved configuration
+        model = cls(
+            data_augmentation=checkpoint["data_augmentation"],
+            compile=checkpoint["compile"],
+            batch_size=checkpoint["batch_size"],
+            early_stopping_rounds=checkpoint["early_stopping_rounds"],
+            min_delta=checkpoint["min_delta"],
+            val_split_ratio=checkpoint["val_split_ratio"],
+        )
+
+        # Load the trained LightGBM head
+        model.head = checkpoint["head"]
+
+        # Extract random_state from the loaded head if available
+        model.random_state = getattr(model.head, "random_state", None)
+
+        return model

tirex/models/regression/linear_regressor.py
ADDED
@@ -0,0 +1,250 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+from dataclasses import asdict
+
+import torch
+
+from ..base.base_regressor import BaseTirexRegressor
+from ..trainer import TrainConfig, Trainer, TrainingMetrics
+
+
+class TirexLinearRegressor(BaseTirexRegressor, torch.nn.Module):
+    """
+    A PyTorch regressor that combines time series embeddings with a linear regression head.
+
+    This model uses a pre-trained TiRex embedding model to generate feature representations from time series
+    data, followed by a linear layer (with optional dropout) for regression. The embedding backbone
+    is frozen during training, and only the regression head is trained.
+
+    Example:
+        >>> import torch
+        >>> from tirex.models.regression import TirexLinearRegressor
+        >>>
+        >>> # Create model with TiRex embeddings
+        >>> model = TirexLinearRegressor(
+        ...     data_augmentation=True,
+        ...     max_epochs=2,
+        ...     lr=1e-4,
+        ...     batch_size=32
+        ... )
+        >>>
+        >>> # Prepare data
+        >>> X_train = torch.randn(100, 1, 128)  # 100 samples, 1 number of variates, 128 sequence length
+        >>> y_train = torch.randn(100, 1)  # target values
+        >>>
+        >>> # Train the model
+        >>> metrics = model.fit((X_train, y_train))  # doctest: +ELLIPSIS
+        Epoch 1, Train Loss: ...
+        >>> # Make predictions
+        >>> X_test = torch.randn(20, 1, 128)
+        >>> predictions = model.predict(X_test)
+    """
+
+    def __init__(
+        self,
+        data_augmentation: bool = False,
+        device: str | None = None,
+        compile: bool = False,
+        # Training parameters
+        max_epochs: int = 10,
+        lr: float = 1e-4,
+        weight_decay: float = 0.01,
+        batch_size: int = 512,
+        val_split_ratio: float = 0.2,
+        patience: int = 7,
+        delta: float = 0.001,
+        log_every_n_steps: int = 5,
+        seed: int | None = None,
+        # Head parameters
+        dropout: float | None = None,
+    ) -> None:
+        """Initializes Embedding Based Linear Regression model.
+
+        Args:
+            data_augmentation : bool | None
+                Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
+            device : str | None
+                Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
+            compile: bool
+                Whether to compile the frozen embedding model. Default: False
+            max_epochs : int
+                Maximum number of training epochs. Default: 10
+            lr : float
+                Learning rate for the optimizer. Default: 1e-4
+            weight_decay : float
+                Weight decay coefficient. Default: 0.01
+            batch_size : int
+                Batch size for training and embedding calculations. Default: 512
+            val_split_ratio : float
+                Proportion of training data to use for validation, if validation data are not provided. Default: 0.2
+            patience : int
+                Number of epochs to wait for improvement before early stopping. Default: 7
+            delta : float
+                Minimum change in validation loss to qualify as an improvement. Default: 0.001
+            log_every_n_steps : int
+                Frequency of logging during training. Default: 5
+            seed : int | None
+                Random seed for reproducibility. If None, no seed is set. Default: None
+            dropout : float | None
+                Dropout probability for the regression head. If None, no dropout is used. Default: None
+        """
+
+        torch.nn.Module.__init__(self)
+
+        super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
+
+        # Head parameters
+        self.dropout = dropout
+        self.head = None
+        self.emb_dim = None
+        self.output_dim = None
+
+        # Train config
+        train_config = TrainConfig(
+            max_epochs=max_epochs,
+            log_every_n_steps=log_every_n_steps,
+            device=self.device,
+            lr=lr,
+            weight_decay=weight_decay,
+            class_weights=None,
+            task_type="regression",
+            batch_size=batch_size,
+            val_split_ratio=val_split_ratio,
+            patience=patience,
+            delta=delta,
+            seed=seed,
+        )
+        self.trainer = Trainer(self, train_config=train_config)
+
+    def _init_regressor(self, emb_dim: int, output_dim: int, dropout: float | None) -> torch.nn.Module:
+        if dropout:
+            return torch.nn.Sequential(torch.nn.Dropout(p=dropout), torch.nn.Linear(emb_dim, output_dim))
+        else:
+            return torch.nn.Linear(emb_dim, output_dim)
+
+    @torch.inference_mode()
+    def _identify_head_dims(self, x: torch.Tensor, y: torch.Tensor) -> None:
+        self.emb_dim = self._compute_embeddings(x[:1]).shape[-1]
+        self.output_dim = 1
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the embedding model and regression head.
+
+        Args:
+            x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Predicted values with shape (batch_size, 1).
+        Raises:
+            RuntimeError: If the regression head has not been initialized via fit().
+        """
+        if self.head is None:
+            raise RuntimeError("Head not initialized. Call fit() first to automatically build the head.")
+
+        embedding = self.emb_model(x).to(self.device)
+        return self.head(embedding)
+
+    def fit(
+        self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None = None
+    ) -> TrainingMetrics:
+        """Train the regression head on the provided data.
+
+        This method initializes the regression head based on the data dimensions,
+        then trains it on provided data. The embedding model remains frozen.
+
+        Args:
+            train_data: Tuple of (X_train, y_train) where X_train is the input time series
+                data and y_train are the corresponding target values.
+            val_data: Optional tuple of (X_val, y_val) for validation. If None and
+                val_split_ratio > 0, validation data will be split from train_data.
+
+        Returns:
+            dict[str, float]: Dictionary containing final training and validation losses.
+        """
+        X_train, y_train = train_data
+
+        self._identify_head_dims(X_train, y_train)
+        self.head = self._init_regressor(self.emb_dim, self.output_dim, self.dropout)
+        self.head = self.head.to(self.trainer.device)
+
+        return self.trainer.fit(train_data, val_data=val_data)
+
+    @torch.inference_mode()
+    def predict(self, x: torch.Tensor) -> torch.Tensor:
+        """Predict values for input time series data.
+
+        Args:
+            x: Input tensor of time series data with shape (batch_size, num_variates, seq_len).
+        Returns:
+            torch.Tensor: Predicted values with shape (batch_size, 1).
+        """
+        self.eval()
+        x = x.to(self.device)
+        return self.forward(x)
+
+    def save_model(self, path: str) -> None:
+        """Save the trained regression head.
+
+        This function saves the trained regression head weights (.pt format), embedding configuration,
+        model dimensions, and device information. The embedding model itself is not
+        saved as it uses a pre-trained backbone that can be reloaded.
+
+        Args:
+            path: File path where the model should be saved (e.g., 'model.pt').
+        """
+        train_config_dict = asdict(self.trainer.train_config)
+        torch.save(
+            {
+                "head_state_dict": self.head.state_dict(),  # need to save only head, embedding is frozen
+                "data_augmentation": self.data_augmentation,
+                "compile": self._compile,
+                "emb_dim": self.emb_dim,
+                "output_dim": self.output_dim,
+                "dropout": self.dropout,
+                "train_config": train_config_dict,
+            },
+            path,
+        )
+
+    @classmethod
+    def load_model(cls, path: str) -> "TirexLinearRegressor":
+        """Load a saved model from file.
+
+        This reconstructs the model architecture and loads the trained weights from
+        a checkpoint file created by save_model().
+
+        Args:
+            path: File path to the saved model checkpoint.
+        Returns:
+            TirexLinearRegressor: The loaded model with trained weights, ready for inference.
+        """
+        checkpoint = torch.load(path)
+
+        # Extract train_config if available, otherwise use defaults
+        train_config_dict = checkpoint.get("train_config", {})
+
+        model = cls(
+            data_augmentation=checkpoint["data_augmentation"],
+            compile=checkpoint["compile"],
+            dropout=checkpoint["dropout"],
+            max_epochs=train_config_dict.get("max_epochs", 50),
+            lr=train_config_dict.get("lr", 1e-4),
+            weight_decay=train_config_dict.get("weight_decay", 0.01),
+            batch_size=train_config_dict.get("batch_size", 512),
+            val_split_ratio=train_config_dict.get("val_split_ratio", 0.2),
+            patience=train_config_dict.get("patience", 7),
+            delta=train_config_dict.get("delta", 0.001),
+            log_every_n_steps=train_config_dict.get("log_every_n_steps", 5),
+            seed=train_config_dict.get("seed", None),
+        )
+
+        # Initialize head with dimensions
+        model.emb_dim = checkpoint["emb_dim"]
+        model.output_dim = checkpoint.get("output_dim", checkpoint.get("num_classes", 1))  # Backward compatibility
+        model.head = model._init_regressor(model.emb_dim, model.output_dim, model.dropout)
+
+        # Load the trained weights
+        model.head.load_state_dict(checkpoint["head_state_dict"])
+        model.to(model.device)
+
+        return model

tirex/models/regression/rf_regressor.py
ADDED
@@ -0,0 +1,130 @@
+# Copyright (c) NXAI GmbH.
+# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
+
+import joblib
+import torch
+from sklearn.ensemble import RandomForestRegressor
+
+from ..base.base_regressor import BaseTirexRegressor
+
+
+class TirexRFRegressor(BaseTirexRegressor):
+    """
+    A Random Forest regressor that uses time series embeddings as features.
+
+    This regressor combines a pre-trained embedding model for feature extraction with a scikit-learn
+    Random Forest regressor. The embedding model generates fixed-size feature vectors from variable-length
+    time series, which are then used to train the Random Forest.
+
+    Example:
+        >>> import torch
+        >>> from tirex.models.regression import TirexRFRegressor
+        >>>
+        >>> # Create model with custom Random Forest parameters
+        >>> model = TirexRFRegressor(
+        ...     data_augmentation=True,
+        ...     n_estimators=50,
+        ...     max_depth=10,
+        ...     random_state=42
+        ... )
+        >>>
+        >>> # Prepare data (can use NumPy arrays or PyTorch tensors)
+        >>> X_train = torch.randn(100, 1, 128)  # 100 samples, 1 number of variates, 128 sequence length
+        >>> y_train = torch.randn(100,)  # target values
+        >>>
+        >>> # Train the model
+        >>> model.fit((X_train, y_train))
+        >>>
+        >>> # Make predictions
+        >>> X_test = torch.randn(20, 1, 128)
+        >>> predictions = model.predict(X_test)
+    """
+
+    def __init__(
+        self,
+        data_augmentation: bool = False,
+        device: str | None = None,
+        compile: bool = False,
+        batch_size: int = 512,
+        # Random Forest parameters
+        **rf_kwargs,
+    ) -> None:
+        """Initializes Embedding Based Random Forest Regression model.
+
+        Args:
+            data_augmentation : bool
+                Whether to use data_augmentation for embeddings (sample statistics and first-order differences of the original data). Default: False
+            device : str | None
+                Device to run the embedding model on. If None, uses CUDA if available, else CPU. Default: None
+            compile: bool
+                Whether to compile the frozen embedding model. Default: False
+            batch_size : int
+                Batch size for embedding calculations. Default: 512
+            **rf_kwargs
+                Additional keyword arguments to pass to sklearn's RandomForestRegressor.
+                Common options include n_estimators, max_depth, min_samples_split, random_state, etc.
+        """
+        super().__init__(data_augmentation=data_augmentation, device=device, compile=compile, batch_size=batch_size)
+        self.head = RandomForestRegressor(**rf_kwargs)
+
+    @torch.inference_mode()
+    def fit(self, train_data: tuple[torch.Tensor, torch.Tensor]) -> None:
+        """Train the Random Forest regressor on embedded time series data.
+
+        This method generates embeddings for the training data using the embedding
+        model, then trains the Random Forest on these embeddings.
+
+        Args:
+            train_data: Tuple of (X_train, y_train) where X_train is the input time
+                series data (torch.Tensor) and y_train is a torch.Tensor
+                of target values.
+        """
+        X_train, y_train = train_data
+
+        if isinstance(y_train, torch.Tensor):
+            y_train = y_train.detach().cpu().numpy()
+
+        X_train = X_train.to(self.device)
+        embeddings = self._compute_embeddings(X_train)
+
+        self.head.fit(embeddings, y_train)
+
+    def save_model(self, path: str) -> None:
+        """This method saves the trained Random Forest regressor head and embedding information in joblib format
+
+        Args:
+            path: File path where the model should be saved (e.g., 'model.joblib').
+        """
+        payload = {
+            "data_augmentation": self.data_augmentation,
+            "compile": self._compile,
+            "batch_size": self.batch_size,
+            "head": self.head,
+        }
+        joblib.dump(payload, path)
+
+    @classmethod
+    def load_model(cls, path: str) -> "TirexRFRegressor":
+        """Load a saved model from file.
+
+        This reconstructs the model with the embedding configuration and loads
+        the trained Random Forest regressor from a checkpoint file created by save_model().
+
+        Args:
+            path: File path to the saved model checkpoint.
+        Returns:
+            TirexRFRegressor: The loaded model with trained Random Forest regressor, ready for inference.
+        """
+        checkpoint = joblib.load(path)
+
+        # Create new instance with saved configuration
+        model = cls(
+            data_augmentation=checkpoint["data_augmentation"],
+            compile=checkpoint["compile"],
+            batch_size=checkpoint["batch_size"],
+        )
+
+        # Load the trained Random Forest head
+        model.head = checkpoint["head"]
+
+        return model

tirex/models/{classification/trainer.py → trainer.py}
RENAMED
@@ -2,11 +2,18 @@
 # This software may be used and distributed according to the terms of the NXAI Community License Agreement.
 
 from dataclasses import dataclass
+from typing import TypedDict
 
 import torch
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader, TensorDataset
 
-from
+from ..util import EarlyStopping, set_seed, train_val_split
+
+
+class TrainingMetrics(TypedDict):
+    train_loss: float
+    val_loss: float
 
 
 @dataclass

@@ -22,11 +29,7 @@ class TrainConfig:
 
     # Loss parameters
     class_weights: torch.Tensor | None
-
-    # Data loading parameters
-    batch_size: int
-    val_split_ratio: float
-    stratify: bool
+    task_type: str
 
     # Earlystopping parameters
     patience: int

@@ -35,6 +38,11 @@ class TrainConfig:
     # Reproducability
     seed: int | None
 
+    # Data loading parameters
+    batch_size: int
+    val_split_ratio: float
+    stratify: bool = False
+
     def __post_init__(self) -> None:
         if self.max_epochs <= 0:
             raise ValueError(f"max_epochs must be positive, got {self.max_epochs}")

@@ -60,6 +68,12 @@ class TrainConfig:
         if self.delta < 0:
             raise ValueError(f"delta must be non-negative, got {self.delta}")
 
+        if self.task_type not in ["classification", "regression"]:
+            raise ValueError(f"task_type must be 'classification' or 'regression', got {self.task_type}")
+
+        if self.stratify and self.task_type == "regression":
+            raise ValueError("stratify=True is not valid for regression tasks")
+
 
 class Trainer:
     def __init__(

@@ -74,14 +88,19 @@ class Trainer:
         class_weights = (
             self.train_config.class_weights.to(self.device) if self.train_config.class_weights is not None else None
         )
-        self.
-
-        self.
+        if self.train_config.task_type == "classification":
+            self.loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights).to(self.device)
+        elif self.train_config.task_type == "regression":
+            self.loss_fn = torch.nn.MSELoss().to(self.device)
+        else:
+            raise ValueError(f"Unsupported task_type: {self.train_config.task_type}")
+
+        self.optimizer: Optimizer | None = None
         self.early_stopper = EarlyStopping(patience=self.train_config.patience, delta=self.train_config.delta)
 
     def fit(
         self, train_data: tuple[torch.Tensor, torch.Tensor], val_data: tuple[torch.Tensor, torch.Tensor] | None = None
-    ) ->
+    ) -> TrainingMetrics:
         if self.train_config.seed is not None:
             set_seed(self.train_config.seed)
 
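
The trainer now picks its loss from the new task_type field (CrossEntropyLoss for classification, MSELoss for regression) instead of being hard-wired for classification. A sketch of a regression config built from the fields visible above; the field names match the TrainConfig constructed in linear_regressor.py, the values are illustrative:

from tirex.models.trainer import TrainConfig

config = TrainConfig(
    max_epochs=10,
    log_every_n_steps=5,
    device="cpu",
    lr=1e-4,
    weight_decay=0.01,
    class_weights=None,       # only used by the classification loss
    task_type="regression",   # selects MSELoss; anything else raises in __post_init__
    batch_size=512,
    val_split_ratio=0.2,
    patience=7,
    delta=0.001,
    seed=None,
    # stratify keeps its new False default; stratify=True would raise for regression
)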
tirex/util.py
CHANGED
@@ -9,6 +9,7 @@ from typing import Literal, Optional
 
 import numpy as np
 import torch
+from sklearn.model_selection import train_test_split
 
 COLOR_CONTEXT = "#4a90d0"
 COLOR_FORECAST = "#d94e4e"

@@ -845,3 +846,84 @@
     ax.grid()
 
     return ax
+
+
+# ==== Classification and Regression Utilities ====
+
+
+# Remove after Issue will be solved: https://github.com/pytorch/pytorch/issues/61474
+def nanmax(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+    min_value = torch.finfo(tensor.dtype).min
+    output = tensor.nan_to_num(min_value).max(dim=dim, keepdim=keepdim)
+    return output.values
+
+
+def nanmin(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+    max_value = torch.finfo(tensor.dtype).max
+    output = tensor.nan_to_num(max_value).min(dim=dim, keepdim=keepdim)
+    return output.values
+
+
+def nanvar(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+    tensor_mean = tensor.nanmean(dim=dim, keepdim=True)
+    output = (tensor - tensor_mean).square().nanmean(dim=dim, keepdim=keepdim)
+    return output
+
+
+def nanstd(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+    output = nanvar(tensor, dim=dim, keepdim=keepdim)
+    output = output.sqrt()
+    return output
+
+
+def train_val_split(
+    train_data: tuple[torch.Tensor, torch.Tensor],
+    val_split_ratio: float,
+    stratify: bool = False,
+    seed: int | None = None,
+) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
+    idx_train, idx_val = train_test_split(
+        np.arange(len(train_data[0])),
+        test_size=val_split_ratio,
+        random_state=seed,
+        shuffle=True,
+        stratify=train_data[1] if stratify else None,
+    )
+
+    return (
+        (train_data[0][idx_train], train_data[1][idx_train]),
+        (train_data[0][idx_val], train_data[1][idx_val]),
+    )
+
+
+def set_seed(seed: int) -> None:
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+
+
+class EarlyStopping:
+    def __init__(
+        self,
+        patience: int = 7,
+        delta: float = 0.0001,
+    ) -> None:
+        self.patience: int = patience
+        self.delta: float = delta
+
+        self.best: float = np.inf
+        self.wait_count: int = 0
+        self.early_stop: bool = False
+
+    def __call__(self, epoch: int, val_loss: float) -> bool:
+        improved = val_loss < (self.best - self.delta)
+        if improved:
+            self.best = val_loss
+            self.wait_count = 0
+        else:
+            self.wait_count += 1
+            if self.wait_count >= self.patience:
+                self.early_stop = True
+                print(f"Early stopping triggered at epoch {epoch}.")
+        return self.early_stop
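
EarlyStopping above is a plain counter rather than a torch callback: __call__ returns True once val_loss has failed to beat the best value by more than delta for patience consecutive epochs. A quick sketch of that behavior:

from tirex.util import EarlyStopping

stopper = EarlyStopping(patience=2, delta=0.0)
for epoch, loss in enumerate([1.0, 0.9, 0.95, 0.93], start=1):
    if stopper(epoch, loss):
        break  # prints "Early stopping triggered at epoch 4." (two epochs without beating 0.9)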

{tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tirex-mirror
-Version: 2025.12.2
+Version: 2025.12.16
 Summary: Unofficial mirror of NX-AI/tirex for packaging
 Author-email: Arpad Rozsas <rozsasarpi@gmail.com>
 License: NXAI COMMUNITY LICENSE AGREEMENT

@@ -55,7 +55,7 @@ License: NXAI COMMUNITY LICENSE AGREEMENT
 
 Project-URL: Repository, https://github.com/rozsasarpi/tirex-mirror
 Project-URL: Issues, https://github.com/rozsasarpi/tirex-mirror/issues
-Keywords: TiRex,xLSTM,
+Keywords: TiRex,xLSTM,Timeseries,Zero-shot,Deep Learning,Timeseries-Forecasting,Timeseries-Classification,Timeseries-Regression
 Classifier: Programming Language :: Python :: 3
 Classifier: Operating System :: OS Independent
 Requires-Python: >=3.11

@@ -66,6 +66,7 @@ License-File: NOTICE.txt
 Requires-Dist: torch
 Requires-Dist: huggingface-hub
 Requires-Dist: numpy
+Requires-Dist: scikit-learn
 Provides-Extra: cuda
 Requires-Dist: xlstm; extra == "cuda"
 Requires-Dist: ninja; extra == "cuda"

@@ -84,9 +85,11 @@ Requires-Dist: datasets; extra == "hfdataset"
 Provides-Extra: test
 Requires-Dist: fev>=0.6.0; extra == "test"
 Requires-Dist: pytest; extra == "test"
+Requires-Dist: aeon; extra == "test"
 Provides-Extra: classification
-Requires-Dist: scikit-learn; extra == "classification"
 Requires-Dist: lightgbm[scikit-learn]; extra == "classification"
+Provides-Extra: regression
+Requires-Dist: lightgbm[scikit-learn]; extra == "regression"
 Provides-Extra: all
 Requires-Dist: xlstm; extra == "all"
 Requires-Dist: ninja; extra == "all"

tirex_mirror-2025.12.16.dist-info/RECORD
ADDED
@@ -0,0 +1,34 @@
+tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
+tirex/base.py,sha256=pGShyI3LwPAwkl0rIYqik5qb8uBAWqbQrVoVjq7UH_8,5096
+tirex/util.py,sha256=D1HaHiAfZAEiOA_q2saOpugn0gozepBcwXwvQHS85cg,34308
+tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/api_adapter/forecast.py,sha256=FnRgdI_4vJ7iFTqyWwxPE_C7MBB3nhLTnVMzTr5pihc,17581
+tirex/api_adapter/gluon.py,sha256=0KfKX7dGSooVIUumGtjK6Va2YgI6Isa_8kz9j0oPXUM,1883
+tirex/api_adapter/hf_data.py,sha256=TRyys2xKIGZS0Yhq2Eb61lWCMg5CWWn1yRlLIN1mU7o,1369
+tirex/api_adapter/standard_adapter.py,sha256=vdlxNs8mTUtPgK_5WMqYqNdMj8W44igqWsAgtggt_xk,2809
+tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
+tirex/models/embedding.py,sha256=quDT_xKaaMmhXvZxz7LV2r4ICou2pzcjam7YnfOzxIs,5457
+tirex/models/patcher.py,sha256=8T4c3PZnOAsEpahhrjtt7S7405WUjN6g3cV33E55PD4,1911
+tirex/models/tirex.py,sha256=wtdjrdE1TuoueIvhlKf-deLH3oKyQuGmd2k15My7SWA,9710
+tirex/models/trainer.py,sha256=lkNSt5aHzjoAQTRVPj6QlFVSP81jWbkMcKl-2FsPNhQ,6680
+tirex/models/base/base_classifier.py,sha256=QU_eiwTe8B3QRT6Xs83utnnvlo6iaTzIv8VCId_S900,1310
+tirex/models/base/base_regressor.py,sha256=RA9FfRgJeAuE_pbaP_4pGIiYWRmMs4zgkb9ajqAYfsg,786
+tirex/models/base/base_tirex.py,sha256=JS82k0qzi19LNl3aW45ty4AQAIZktvZZCuT9lhMlpcg,3646
+tirex/models/classification/__init__.py,sha256=uBoN5BvZV_bHDFFQpOQ8E0CKAyU7BYkmkQndb48aqTk,361
+tirex/models/classification/gbm_classifier.py,sha256=ylcMJKC63ps4h-qoVFB7JaEzyuS8X3Mqt2B-gMYzN4I,7523
+tirex/models/classification/linear_classifier.py,sha256=I7etrhctJgjTRJISIXVs2v_efXkJvTz3G9YOuAIcL5E,11366
+tirex/models/classification/rf_classifier.py,sha256=jo9otjwhMhNkKCmVUdDbkvb2gbCLt5eTYMKclWyeQks,5036
+tirex/models/regression/__init__.py,sha256=GjIcpJXA1n-RzP3JjwLknIu1FgEQX6hh4C6U1fHLoL0,352
+tirex/models/regression/gbm_regressor.py,sha256=8RVk4gmQ3o0N58zDE11CnNZIAWZa1HQrw234-2nHu5o,7162
+tirex/models/regression/linear_regressor.py,sha256=n4Nr6HtTQVQ-T45yMCcOuk9Nj3kGt5V9vyX5q2UkQ_k,10152
+tirex/models/regression/rf_regressor.py,sha256=03c39NCDKmHZTB7yfytOXREqF1Lfdcc0sWV2FJQ0Ouw,4985
+tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
+tirex/models/slstm/cell.py,sha256=Otyil_AjpJbUckkINWGHxlqP14J5epm_J_zdWPzvD2g,7290
+tirex/models/slstm/layer.py,sha256=hrDydQJIAHf5W0A0Rt0hXG4yKXrOSY-HPL0UbigR6Q8,2867
+tirex_mirror-2025.12.16.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
+tirex_mirror-2025.12.16.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
+tirex_mirror-2025.12.16.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
+tirex_mirror-2025.12.16.dist-info/METADATA,sha256=z_2g2gtkAJMFpzCnIxTVA1R0k10-Fnq20VpocikYpBM,11911
+tirex_mirror-2025.12.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tirex_mirror-2025.12.16.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
+tirex_mirror-2025.12.16.dist-info/RECORD,,

tirex/models/classification/utils.py
DELETED
@@ -1,84 +0,0 @@
-# Copyright (c) NXAI GmbH.
-# This software may be used and distributed according to the terms of the NXAI Community License Agreement.
-
-import numpy as np
-import torch
-from sklearn.model_selection import train_test_split
-
-
-# Remove after Issue will be solved: https://github.com/pytorch/pytorch/issues/61474
-def nanmax(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
-    min_value = torch.finfo(tensor.dtype).min
-    output = tensor.nan_to_num(min_value).max(dim=dim, keepdim=keepdim)
-    return output.values
-
-
-def nanmin(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
-    max_value = torch.finfo(tensor.dtype).max
-    output = tensor.nan_to_num(max_value).min(dim=dim, keepdim=keepdim)
-    return output.values
-
-
-def nanvar(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
-    tensor_mean = tensor.nanmean(dim=dim, keepdim=True)
-    output = (tensor - tensor_mean).square().nanmean(dim=dim, keepdim=keepdim)
-    return output
-
-
-def nanstd(tensor: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
-    output = nanvar(tensor, dim=dim, keepdim=keepdim)
-    output = output.sqrt()
-    return output
-
-
-def train_val_split(
-    train_data: tuple[torch.Tensor, torch.Tensor],
-    val_split_ratio: float,
-    stratify: bool,
-    seed: int | None,
-) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
-    idx_train, idx_val = train_test_split(
-        np.arange(len(train_data[0])),
-        test_size=val_split_ratio,
-        random_state=seed,
-        shuffle=True,
-        stratify=train_data[1] if stratify else None,
-    )
-
-    return (
-        (train_data[0][idx_train], train_data[1][idx_train]),
-        (train_data[0][idx_val], train_data[1][idx_val]),
-    )
-
-
-def set_seed(seed: int) -> None:
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-    np.random.seed(seed)
-
-
-class EarlyStopping:
-    def __init__(
-        self,
-        patience: int = 7,
-        delta: float = 0.0001,
-    ) -> None:
-        self.patience: int = patience
-        self.delta: float = delta
-
-        self.best: float = np.inf
-        self.wait_count: int = 0
-        self.early_stop: bool = False
-
-    def __call__(self, epoch: int, val_loss: float) -> bool:
-        improved = val_loss < (self.best - self.delta)
-        if improved:
-            self.best = val_loss
-            self.wait_count = 0
-        else:
-            self.wait_count += 1
-            if self.wait_count >= self.patience:
-                self.early_stop = True
-                print(f"Early stopping triggered at epoch {epoch}.")
-        return self.early_stop

tirex_mirror-2025.12.2.dist-info/RECORD
DELETED
@@ -1,29 +0,0 @@
-tirex/__init__.py,sha256=rfsOeCJ7eRqU3K3TOhfN5-4XUuZFqt11wBRxk5SoAWA,292
-tirex/base.py,sha256=pGShyI3LwPAwkl0rIYqik5qb8uBAWqbQrVoVjq7UH_8,5096
-tirex/util.py,sha256=ggfETirJ589Dr5o3QThxTnjhgNnMCk10bNoJghnoeoA,31672
-tirex/api_adapter/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/api_adapter/forecast.py,sha256=FnRgdI_4vJ7iFTqyWwxPE_C7MBB3nhLTnVMzTr5pihc,17581
-tirex/api_adapter/gluon.py,sha256=0KfKX7dGSooVIUumGtjK6Va2YgI6Isa_8kz9j0oPXUM,1883
-tirex/api_adapter/hf_data.py,sha256=TRyys2xKIGZS0Yhq2Eb61lWCMg5CWWn1yRlLIN1mU7o,1369
-tirex/api_adapter/standard_adapter.py,sha256=vdlxNs8mTUtPgK_5WMqYqNdMj8W44igqWsAgtggt_xk,2809
-tirex/models/__init__.py,sha256=YnTtPf5jGqvhfqoX8Ku7Yd0xohy0MmocE2ryrXVnQ1Q,135
-tirex/models/patcher.py,sha256=8T4c3PZnOAsEpahhrjtt7S7405WUjN6g3cV33E55PD4,1911
-tirex/models/tirex.py,sha256=wtdjrdE1TuoueIvhlKf-deLH3oKyQuGmd2k15My7SWA,9710
-tirex/models/classification/__init__.py,sha256=0PKH8Cr-VWT4x4SXzGAyj90jMJbxzlIyrBIiV1ZXfm4,377
-tirex/models/classification/embedding.py,sha256=wxIefIpccHe64Cb2JlmL0l0PhqYB_zdBHSh5auPqdlY,5452
-tirex/models/classification/trainer.py,sha256=CbYfdTNdj5iucchxNwgg35EF-MzrdjrMm9auNOLae4U,5895
-tirex/models/classification/utils.py,sha256=JbloQu4KlBHfdlH1NYYg7_VIoH-t2HBo5ovdyAEoCGs,2735
-tirex/models/classification/heads/base_classifier.py,sha256=mBJV7jShBAUQgMpidow8lVfqhufNHvEIwDnidtxA1Hw,3236
-tirex/models/classification/heads/gbm_classifier.py,sha256=V1CzfmJF6Tp2j14F3dAmRZ4wxBsuKoK1jQu4HbjttCc,8365
-tirex/models/classification/heads/linear_classifier.py,sha256=2BI2GqQvBiHjFJ7_dS8rGQdQ-pnpF7-9sZv_YRUL-8w,11351
-tirex/models/classification/heads/rf_classifier.py,sha256=Nsw3mxToa804txYzkvkRQwq34whEb5x-9_6ywskylBk,5063
-tirex/models/slstm/block.py,sha256=V91Amgz8WAOOHo4fK1UZxd4Dgbx4-X6kUBS6X4m0tKQ,2006
-tirex/models/slstm/cell.py,sha256=Otyil_AjpJbUckkINWGHxlqP14J5epm_J_zdWPzvD2g,7290
-tirex/models/slstm/layer.py,sha256=hrDydQJIAHf5W0A0Rt0hXG4yKXrOSY-HPL0UbigR6Q8,2867
-tirex_mirror-2025.12.2.dist-info/licenses/LICENSE,sha256=HlwHKnGTlE2oNm6734V-Vy62zlkWohnuZpYXSdkqDk4,7362
-tirex_mirror-2025.12.2.dist-info/licenses/LICENSE_MIRROR.txt,sha256=ulPZMcOZdN7JvISjiID3KUwovTjrPwiMv5ku9dM7nls,496
-tirex_mirror-2025.12.2.dist-info/licenses/NOTICE.txt,sha256=rcgDscFHb-uuZO3L0_vIxYhTYl-a2Rm0lBpp3_kKdFQ,147
-tirex_mirror-2025.12.2.dist-info/METADATA,sha256=XU9yHQPFUYg5qiqMSNoOJRGmSE9dY3GUhGfn8m4LouM,11783
-tirex_mirror-2025.12.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-tirex_mirror-2025.12.2.dist-info/top_level.txt,sha256=AOLDhfv0F_7nn3pFq0Kapg6Ky_28I_cGDXzQX3w9eO4,6
-tirex_mirror-2025.12.2.dist-info/RECORD,,

{tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/WHEEL
RENAMED
File without changes

{tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/licenses/LICENSE
RENAMED
File without changes

{tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/licenses/LICENSE_MIRROR.txt
RENAMED
File without changes

{tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/licenses/NOTICE.txt
RENAMED
File without changes

{tirex_mirror-2025.12.2.dist-info → tirex_mirror-2025.12.16.dist-info}/top_level.txt
RENAMED
File without changes