torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +114 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +43 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +166 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +463 -405
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
- torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
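
For orientation before the line-by-line diff of torchTextClassifiers/torchTextClassifiers.py below, here is a minimal usage sketch of the reworked 0.1.0 API, assembled from the docstring examples visible in the diff. It assumes `tokenizer` is an already-trained BaseTokenizer instance (the concrete tokenizer classes under torchTextClassifiers/tokenizers/ are not shown in this diff), and the toy data values are illustrative only.

import numpy as np
from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers

# First column of X is the text; remaining columns are integer categorical codes.
X = np.array([["great product", "0", "1"],
              ["poor quality", "1", "0"]])
Y = np.array([1, 0])

model_config = ModelConfig(
    embedding_dim=10,
    categorical_vocabulary_sizes=[2, 2],
    categorical_embedding_dims=[4, 4],
    num_classes=2,
)
# tokenizer is assumed to be built and trained beforehand
ttc = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config)

training_config = TrainingConfig(num_epochs=1, batch_size=2, lr=1e-3)
ttc.train(X_train=X, y_train=Y, X_val=X, y_val=Y, training_config=training_config)

result = ttc.predict(X, top_k=1)
print(result["prediction"], result["confidence"])
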
@@ -1,7 +1,15 @@
 import logging
 import time
-import
-from typing import
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+try:
+    from captum.attr import LayerIntegratedGradients
+
+    HAS_CAPTUM = True
+except ImportError:
+    HAS_CAPTUM = False
+

 import numpy as np
 import pytorch_lightning as pl
@@ -12,9 +20,17 @@ from pytorch_lightning.callbacks import (
     ModelCheckpoint,
 )

-from .
-from .
-
+from torchTextClassifiers.dataset import TextClassificationDataset
+from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
+from torchTextClassifiers.model.components import (
+    AttentionConfig,
+    CategoricalForwardType,
+    CategoricalVariableNet,
+    ClassificationHead,
+    TextEmbedder,
+    TextEmbedderConfig,
+)
+from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput

 logger = logging.getLogger(__name__)

@@ -26,296 +42,256 @@ logging.basicConfig(
 )


+@dataclass
+class ModelConfig:
+    """Base configuration class for text classifiers."""
+
+    embedding_dim: int
+    categorical_vocabulary_sizes: Optional[List[int]] = None
+    categorical_embedding_dims: Optional[Union[List[int], int]] = None
+    num_classes: Optional[int] = None
+    attention_config: Optional[AttentionConfig] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
+        return cls(**data)
+
+
+@dataclass
+class TrainingConfig:
+    num_epochs: int
+    batch_size: int
+    lr: float
+    loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss())
+    optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+    scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
+    accelerator: str = "auto"
+    num_workers: int = 12
+    patience_early_stopping: int = 3
+    dataloader_params: Optional[dict] = None
+    trainer_params: Optional[dict] = None
+    optimizer_params: Optional[dict] = None
+    scheduler_params: Optional[dict] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        data = asdict(self)
+        # Serialize loss and scheduler as their class names
+        data["loss"] = self.loss.__class__.__name__
+        if self.scheduler is not None:
+            data["scheduler"] = self.scheduler.__name__
+        return data


 class torchTextClassifiers:
     """Generic text classifier framework supporting multiple architectures.
-
-
-
-
-
-    The
-
-    - Model training with validation
-    - Prediction and evaluation
-    - Model serialization and loading
-
-    Attributes:
-        config: Configuration object specific to the classifier type
-        classifier: The underlying classifier implementation
-
-    Example:
-        >>> from torchTextClassifiers import torchTextClassifiers
-        >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-        >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-        >>>
-        >>> # Create configuration
-        >>> config = FastTextConfig(
-        ...     embedding_dim=100,
-        ...     num_tokens=10000,
-        ...     min_count=1,
-        ...     min_n=3,
-        ...     max_n=6,
-        ...     len_word_ngrams=2,
-        ...     num_classes=2
-        ... )
-        >>>
-        >>> # Initialize classifier with wrapper
-        >>> wrapper = FastTextWrapper(config)
-        >>> classifier = torchTextClassifiers(wrapper)
-        >>>
-        >>> # Build and train
-        >>> classifier.build(X_train, y_train)
-        >>> classifier.train(X_train, y_train, X_val, y_val, num_epochs=10, batch_size=32)
-        >>>
-        >>> # Predict
-        >>> predictions = classifier.predict(X_test)
+
+    Given a tokenizer and model configuration, this class initializes:
+    - Text embedding layer (if needed)
+    - Categorical variable embedding network (if categorical variables are provided)
+    - Classification head
+    The resulting model can be trained using PyTorch Lightning and used for predictions.
+
     """
-
-    def __init__(
-        """Initialize the torchTextClassifiers instance.
-
-        Args:
-            classifier: An instance of a classifier wrapper that implements BaseClassifierWrapper
-
-        Example:
-            >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-            >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-            >>> config = FastTextConfig(embedding_dim=50, num_tokens=5000)
-            >>> wrapper = FastTextWrapper(config)
-            >>> classifier = torchTextClassifiers(wrapper)
-        """
-        self.classifier = classifier
-        self.config = classifier.config
-
-
-    def build_tokenizer(self, training_text: np.ndarray) -> None:
-        """Build tokenizer from training text data.
-
-        This method is kept for backward compatibility. It delegates to
-        prepare_text_features which handles the actual text preprocessing.
-
-        Args:
-            training_text: Array of text strings to build the tokenizer from
-
-        Example:
-            >>> import numpy as np
-            >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-            >>> classifier.build_tokenizer(texts)
-        """
-        self.classifier.prepare_text_features(training_text)
-
-    def prepare_text_features(self, training_text: np.ndarray) -> None:
-        """Prepare text features for the classifier.
-
-        This method handles text preprocessing which could involve tokenization,
-        vectorization, or other approaches depending on the classifier type.
-
-        Args:
-            training_text: Array of text strings to prepare features from
-
-        Example:
-            >>> import numpy as np
-            >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-            >>> classifier.prepare_text_features(texts)
-        """
-        self.classifier.prepare_text_features(training_text)
-
-    def build(
+
+    def __init__(
         self,
-
-
-
-
-
-        """Build the complete classifier from training data.
-
-        This method handles the full model building process including:
-        - Input validation and preprocessing
-        - Tokenizer creation from training text
-        - Model architecture initialization
-        - Lightning module setup (if enabled)
-
+        tokenizer: BaseTokenizer,
+        model_config: ModelConfig,
+    ):
+        """Initialize the torchTextClassifiers instance.
+
         Args:
-
-
-
-            **kwargs: Additional arguments passed to Lightning initialization
-
-        Raises:
-            ValueError: If y_train is None and num_classes is not set in config
-            ValueError: If label values are outside expected range
-
+            tokenizer: A tokenizer instance for text preprocessing
+            model_config: Configuration parameters for the text classification model
+
         Example:
-            >>>
-            >>>
-            >>>
+            >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
+            >>> # Assume tokenizer is a trained BaseTokenizer instance
+            >>> model_config = ModelConfig(
+            ...     embedding_dim=10,
+            ...     categorical_vocabulary_sizes=[30, 25],
+            ...     categorical_embedding_dims=[10, 5],
+            ...     num_classes=10,
+            ... )
+            >>> ttc = torchTextClassifiers(
+            ...     tokenizer=tokenizer,
+            ...     model_config=model_config,
+            ... )
         """
-
-
-
-
-
-
-
-        )
-
-        y_train = check_Y(y_train)
-        self.config.num_classes = len(np.unique(y_train))
-
-        if np.max(y_train) >= self.config.num_classes:
-            raise ValueError(
-                "y_train must contain values between 0 and num_classes-1"
+
+        self.model_config = model_config
+        self.tokenizer = tokenizer
+
+        if hasattr(self.tokenizer, "trained"):
+            if not self.tokenizer.trained:
+                raise RuntimeError(
+                    f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier."
                 )
+
+        self.vocab_size = tokenizer.vocab_size
+        self.embedding_dim = model_config.embedding_dim
+        self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
+        self.num_classes = model_config.num_classes
+
+        if self.tokenizer.output_vectorized:
+            self.text_embedder = None
+            logger.info(
+                "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization."
+            )
+            self.embedding_dim = self.tokenizer.output_dim
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-
-        self.
-
+            text_embedder_config = TextEmbedderConfig(
+                vocab_size=self.vocab_size,
+                embedding_dim=self.embedding_dim,
+                padding_idx=tokenizer.padding_idx,
+                attention_config=model_config.attention_config,
+            )
+            self.text_embedder = TextEmbedder(
+                text_embedder_config=text_embedder_config,
+            )
+
+        classif_head_input_dim = self.embedding_dim
+        if self.categorical_vocabulary_sizes:
+            self.categorical_var_net = CategoricalVariableNet(
+                categorical_vocabulary_sizes=self.categorical_vocabulary_sizes,
+                categorical_embedding_dims=model_config.categorical_embedding_dims,
+                text_embedding_dim=self.embedding_dim,
+            )
+
+            if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT:
+                classif_head_input_dim += self.categorical_var_net.output_dim
+
+        else:
+            self.categorical_var_net = None
+
+        self.classification_head = ClassificationHead(
+            input_dim=classif_head_input_dim,
+            num_classes=model_config.num_classes,
+        )
+
+        self.pytorch_model = TextClassificationModel(
+            text_embedder=self.text_embedder,
+            categorical_variable_net=self.categorical_var_net,
+            classification_head=self.classification_head,
+        )
+
     def train(
         self,
         X_train: np.ndarray,
         y_train: np.ndarray,
         X_val: np.ndarray,
         y_val: np.ndarray,
-
-        batch_size: int,
-        cpu_run: bool = False,
-        num_workers: int = 12,
-        patience_train: int = 3,
+        training_config: TrainingConfig,
         verbose: bool = False,
-        trainer_params: Optional[dict] = None,
-        **kwargs
     ) -> None:
         """Train the classifier using PyTorch Lightning.
-
+
         This method handles the complete training process including:
         - Data validation and preprocessing
         - Dataset and DataLoader creation
         - PyTorch Lightning trainer setup with callbacks
         - Model training with early stopping
         - Best model loading after training
-
+
         Args:
            X_train: Training input data
            y_train: Training labels
            X_val: Validation input data
            y_val: Validation labels
-
-
-
-
-            patience_train: Number of epochs to wait for improvement before early stopping
-            verbose: If True, print detailed training progress
-            trainer_params: Additional parameters to pass to PyTorch Lightning Trainer
-            **kwargs: Additional arguments passed to the build method
-
+            training_config: Configuration parameters for training
+            verbose: Whether to print training progress information
+
+
         Example:
-
-
-
-
-
-
-
+
+            >>> training_config = TrainingConfig(
+            ...     lr=1e-3,
+            ...     batch_size=4,
+            ...     num_epochs=1,
+            ... )
+            >>> ttc.train(
+            ...     X_train=X,
+            ...     y_train=Y,
+            ...     X_val=X,
+            ...     y_val=Y,
+            ...     training_config=training_config,
+            ... )
         """
         # Input validation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        )
-
+        X_train, y_train = self._check_XY(X_train, y_train)
+        X_val, y_val = self._check_XY(X_val, y_val)
+
+        if (
+            X_train["categorical_variables"] is not None
+            and X_val["categorical_variables"] is not None
+        ):
+            assert (
+                X_train["categorical_variables"].ndim > 1
+                and X_train["categorical_variables"].shape[1]
+                == X_val["categorical_variables"].shape[1]
+                or X_val["categorical_variables"].ndim == 1
+            ), "X_train and X_val must have the same number of columns."
+
         if verbose:
             logger.info("Starting training process...")
-
-
-        if cpu_run:
-            device = torch.device("cpu")
-        else:
+
+        if training_config.accelerator == "auto":
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-
+        else:
+            device = torch.device(training_config.accelerator)
+
+        self.device = device
+
+        optimizer_params = {"lr": training_config.lr}
+        if training_config.optimizer_params is not None:
+            optimizer_params.update(training_config.optimizer_params)
+
+        self.lightning_module = TextClassificationModule(
+            model=self.pytorch_model,
+            loss=training_config.loss,
+            optimizer=training_config.optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=training_config.scheduler,
+            scheduler_params=training_config.scheduler_params
+            if training_config.scheduler_params
+            else {},
+            scheduler_interval="epoch",
+        )
+
+        self.pytorch_model.to(self.device)
+
         if verbose:
             logger.info(f"Running on: {device}")
-
-
-
-        if
-
-            logger.info("Building the model...")
-            self.build(X_train, y_train, **kwargs)
-            if verbose:
-                end = time.time()
-                logger.info(f"Model built in {end - start:.2f} seconds.")
-
-        self.classifier.pytorch_model = self.classifier.pytorch_model.to(device)
-
-        # Create datasets and dataloaders using wrapper methods
-        train_dataset = self.classifier.create_dataset(
-            texts=training_text,
+
+        train_dataset = TextClassificationDataset(
+            texts=X_train["text"],
+            categorical_variables=X_train["categorical_variables"], # None if no cat vars
+            tokenizer=self.tokenizer,
             labels=y_train,
-            categorical_variables=train_categorical_variables,
         )
-        val_dataset =
-            texts=
+        val_dataset = TextClassificationDataset(
+            texts=X_val["text"],
+            categorical_variables=X_val["categorical_variables"], # None if no cat vars
+            tokenizer=self.tokenizer,
             labels=y_val,
-            categorical_variables=val_categorical_variables,
         )
-
-        train_dataloader =
-
-
-
-
+
+        train_dataloader = train_dataset.create_dataloader(
+            batch_size=training_config.batch_size,
+            num_workers=training_config.num_workers,
+            shuffle=True,
+            **training_config.dataloader_params if training_config.dataloader_params else {},
         )
-        val_dataloader =
-
-
-
-
+        val_dataloader = val_dataset.create_dataloader(
+            batch_size=training_config.batch_size,
+            num_workers=training_config.num_workers,
+            shuffle=False,
+            **training_config.dataloader_params if training_config.dataloader_params else {},
         )
-
+
         # Setup trainer
         callbacks = [
             ModelCheckpoint(
@@ -326,184 +302,266 @@ class torchTextClassifiers:
             ),
             EarlyStopping(
                 monitor="val_loss",
-                patience=
+                patience=training_config.patience_early_stopping,
                 mode="min",
             ),
             LearningRateMonitor(logging_interval="step"),
         ]
-
-
+
+        trainer_params = {
+            "accelerator": training_config.accelerator,
             "callbacks": callbacks,
-            "max_epochs": num_epochs,
+            "max_epochs": training_config.num_epochs,
             "num_sanity_val_steps": 2,
             "strategy": "auto",
             "log_every_n_steps": 1,
             "enable_progress_bar": True,
         }
-
-        if trainer_params is not None:
-
-
-        trainer = pl.Trainer(**
-
+
+        if training_config.trainer_params is not None:
+            trainer_params.update(training_config.trainer_params)
+
+        trainer = pl.Trainer(**trainer_params)
+
         torch.cuda.empty_cache()
         torch.set_float32_matmul_precision("medium")
-
+
         if verbose:
             logger.info("Launching training...")
             start = time.time()
-
-        trainer.fit(self.
-
+
+        trainer.fit(self.lightning_module, train_dataloader, val_dataloader)
+
         if verbose:
             end = time.time()
             logger.info(f"Training completed in {end - start:.2f} seconds.")
-
-        # Load best model using wrapper method
+
         best_model_path = trainer.checkpoint_callback.best_model_path
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        X
-
-
-
-
-
-
-
-
-
-
-        return
-
-    def
-        """
-
-        This method provides both predictions and explanations for the model's
-        decisions. Availability depends on the specific classifier implementation.
-
+
+        self.lightning_module = TextClassificationModule.load_from_checkpoint(
+            best_model_path,
+            model=self.pytorch_model,
+            loss=training_config.loss,
+        )
+
+        self.pytorch_model = self.lightning_module.model.to(self.device)
+
+        self.lightning_module.eval()
+
+    def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        X = self._check_X(X)
+        Y = self._check_Y(Y)
+
+        if X["text"].shape[0] != Y.shape[0]:
+            raise ValueError("X_train and y_train must have the same number of observations.")
+
+        return X, Y
+
+    @staticmethod
+    def _check_text_col(X):
+        assert isinstance(
+            X, np.ndarray
+        ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables."
+
+        try:
+            if X.ndim > 1:
+                text = X[:, 0].astype(str)
+            else:
+                text = X[:].astype(str)
+        except ValueError:
+            logger.error("The first column of X must be castable in string format.")
+
+        return text
+
+    def _check_categorical_variables(self, X: np.ndarray) -> None:
+        """Check if categorical variables in X match training configuration.
+
         Args:
-            X: Input data
-
-
-        Returns:
-            tuple: (predictions, explanations) where explanations format depends
-                on the classifier type
-
+            X: Input data to check
+
         Raises:
-
-
-        Example:
-            >>> predictions, explanations = classifier.predict_and_explain(X_test)
-            >>> print(f"Predictions: {predictions}")
-            >>> print(f"Explanations: {explanations}")
+            ValueError: If the number of categorical variables does not match
+                the training configuration
         """
-
-
+
+        assert self.categorical_var_net is not None
+
+        if X.ndim > 1:
+            num_cat_vars = X.shape[1] - 1
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            num_cat_vars = 0
+
+        if num_cat_vars != self.categorical_var_net.num_categorical_features:
+            raise ValueError(
+                f"X must have the same number of categorical variables as the number of embedding layers in the categorical net: ({self.categorical_var_net.num_categorical_features})."
+            )
+
+        try:
+            categorical_variables = X[:, 1:].astype(int)
+        except ValueError:
+            logger.error(
+                f"Columns {1} to {X.shape[1] - 1} of X_train must be castable in integer format."
+            )
+
+        for j in range(X.shape[1] - 1):
+            max_cat_value = categorical_variables[:, j].max()
+            if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]:
+                raise ValueError(
+                    f"Categorical variable at index {j} has value {max_cat_value} which exceeds the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}."
+                )
+
+        return categorical_variables
+
+    def _check_X(self, X: np.ndarray) -> np.ndarray:
+        text = self._check_text_col(X)
+
+        categorical_variables = None
+        if self.categorical_var_net is not None:
+            categorical_variables = self._check_categorical_variables(X)
+
+        return {"text": text, "categorical_variables": categorical_variables}
+
+    def _check_Y(self, Y):
+        assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
+        assert len(Y.shape) == 1 or (
+            len(Y.shape) == 2 and Y.shape[1] == 1
+        ), "Y must be a numpy array of shape (N,) or (N,1)."
+
+        try:
+            Y = Y.astype(int)
+        except ValueError:
+            logger.error("Y must be castable in integer format.")
+
+        if Y.max() >= self.num_classes or Y.min() < 0:
+            raise ValueError(
+                f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
+            )
+
+        return Y
+
+    def predict(
+        self,
+        X_test: np.ndarray,
+        top_k=1,
+        explain=False,
+    ):
         """
-        with open(filepath, "w") as f:
-            data = {
-                "config": self.config.to_dict(),
-            }
-
-            # Try to get wrapper class info for reconstruction
-            if hasattr(self.classifier.__class__, 'get_wrapper_class_info'):
-                data["wrapper_class_info"] = self.classifier.__class__.get_wrapper_class_info()
-            else:
-                # Fallback: store module and class name
-                data["wrapper_class_info"] = {
-                    "module": self.classifier.__class__.__module__,
-                    "class_name": self.classifier.__class__.__name__
-                }
-
-            json.dump(data, f, cls=NumpyJSONEncoder, indent=4)
-
-    @classmethod
-    def from_json(cls, filepath: str, wrapper_class: Optional[Type[BaseClassifierWrapper]] = None) -> "torchTextClassifiers":
-        """Load classifier configuration from JSON file.
-
-        This method creates a new classifier instance from a previously saved
-        configuration file. The classifier will need to be built and trained again.
-
         Args:
-
-
-
-
-        Returns:
-
-
-
-
-            FileNotFoundError: If the configuration file doesn't exist
-
-        Example:
-            >>> # Using saved wrapper class info
-            >>> classifier = torchTextClassifiers.from_json('my_classifier_config.json')
-            >>>
-            >>> # Or providing wrapper class explicitly
-            >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-            >>> classifier = torchTextClassifiers.from_json('config.json', FastTextWrapper)
+            X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
+            top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
+            explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
+
+        Returns: A dictionary containing the following fields:
+            - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
+            - confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores.
+            - if explain is True:
+                - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if explain:
+            return_offsets_mapping = True # to be passed to the tokenizer
+            return_word_ids = True
+            if self.pytorch_model.text_embedder is None:
+                raise RuntimeError(
+                    "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
+                )
+            else:
+                if not HAS_CAPTUM:
+                    raise ImportError(
+                        "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
+                    )
+                lig = LayerIntegratedGradients(
+                    self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
+                ) # initialize a Captum layer gradient integrator
+        else:
+            return_offsets_mapping = False
+            return_word_ids = False
+
+        X_test = self._check_X(X_test)
+        text = X_test["text"]
+        categorical_variables = X_test["categorical_variables"]
+
+        self.pytorch_model.eval().cpu()
+
+        tokenize_output = self.tokenizer.tokenize(
+            text.tolist(),
+            return_offsets_mapping=return_offsets_mapping,
+            return_word_ids=return_word_ids,
+        )
+
+        if not isinstance(tokenize_output, TokenizerOutput):
+            raise TypeError(
+                f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method."
+            )
+
+        encoded_text = tokenize_output.input_ids # (batch_size, seq_len)
+        attention_mask = tokenize_output.attention_mask # (batch_size, seq_len)
+
+        if categorical_variables is not None:
+            categorical_vars = torch.tensor(
+                categorical_variables, dtype=torch.float32
+            ) # (batch_size, num_categorical_features)
+        else:
+            categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32)
+
+        pred = self.pytorch_model(
+            encoded_text, attention_mask, categorical_vars
+        ) # forward pass, contains the prediction scores (len(text), num_classes)
+
+        label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities
+
+        label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
+
+        predictions = label_scores_topk.indices # get the top_k most likely predictions
+        confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores
+
+        if explain:
+            all_attributions = []
+            for k in range(top_k):
+                attributions = lig.attribute(
+                    (encoded_text, attention_mask, categorical_vars),
+                    target=torch.Tensor(predictions[:, k]).long(),
+                ) # (batch_size, seq_len)
+                attributions = attributions.sum(dim=-1)
+                all_attributions.append(attributions.detach().cpu())
+
+            all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len)
+
+            return {
+                "prediction": predictions,
+                "confidence": confidence,
+                "attributions": all_attributions,
+                "offset_mapping": tokenize_output.offset_mapping,
+                "word_ids": tokenize_output.word_ids,
+            }
+        else:
+            return {
+                "prediction": predictions,
+                "confidence": confidence,
+            }
+
+    def __repr__(self):
+        model_type = (
+            self.lightning_module.__repr__()
+            if hasattr(self, "lightning_module")
+            else self.pytorch_model.__repr__()
+        )
+
+        tokenizer_info = self.tokenizer.__repr__()
+
+        cat_forward_type = (
+            self.categorical_var_net.forward_type.name
+            if self.categorical_var_net is not None
+            else "None"
+        )
+
+        lines = [
+            "torchTextClassifiers(",
+            f" tokenizer = {tokenizer_info},",
+            f" model = {model_type},",
+            f" categorical_forward_type = {cat_forward_type},",
+            f" num_classes = {self.model_config.num_classes},",
+            f" embedding_dim = {self.embedding_dim},",
+            ")",
+        ]
+        return "\n".join(lines)