torchtextclassifiers-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +68 -0
- torchTextClassifiers/classifiers/base.py +83 -0
- torchTextClassifiers/classifiers/fasttext/__init__.py +25 -0
- torchTextClassifiers/classifiers/fasttext/core.py +269 -0
- torchTextClassifiers/classifiers/fasttext/model.py +752 -0
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +346 -0
- torchTextClassifiers/classifiers/fasttext/wrapper.py +216 -0
- torchTextClassifiers/classifiers/simple_text_classifier.py +191 -0
- torchTextClassifiers/factories.py +34 -0
- torchTextClassifiers/torchTextClassifiers.py +509 -0
- torchTextClassifiers/utilities/__init__.py +3 -0
- torchTextClassifiers/utilities/checkers.py +108 -0
- torchTextClassifiers/utilities/preprocess.py +82 -0
- torchTextClassifiers/utilities/utils.py +346 -0
- torchtextclassifiers-0.0.1.dist-info/METADATA +187 -0
- torchtextclassifiers-0.0.1.dist-info/RECORD +17 -0
- torchtextclassifiers-0.0.1.dist-info/WHEEL +4 -0
torchTextClassifiers/torchTextClassifiers.py
@@ -0,0 +1,509 @@
import logging
import time
import json
from typing import Optional, Union, Type, List, Dict, Any

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)

from .utilities.checkers import check_X, check_Y, NumpyJSONEncoder
from .classifiers.base import BaseClassifierConfig, BaseClassifierWrapper


logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler()],
)


class torchTextClassifiers:
    """Generic text classifier framework supporting multiple architectures.

    This is the main class that provides a unified interface for different types
    of text classifiers. It acts as a high-level wrapper that delegates operations
    to specific classifier implementations while providing a consistent API.

    The class supports the full machine learning workflow including:
    - Building tokenizers from training data
    - Model training with validation
    - Prediction and evaluation
    - Model serialization and loading

    Attributes:
        config: Configuration object specific to the classifier type
        classifier: The underlying classifier implementation

    Example:
        >>> from torchTextClassifiers import torchTextClassifiers
        >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
        >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
        >>>
        >>> # Create configuration
        >>> config = FastTextConfig(
        ...     embedding_dim=100,
        ...     num_tokens=10000,
        ...     min_count=1,
        ...     min_n=3,
        ...     max_n=6,
        ...     len_word_ngrams=2,
        ...     num_classes=2
        ... )
        >>>
        >>> # Initialize classifier with wrapper
        >>> wrapper = FastTextWrapper(config)
        >>> classifier = torchTextClassifiers(wrapper)
        >>>
        >>> # Build and train
        >>> classifier.build(X_train, y_train)
        >>> classifier.train(X_train, y_train, X_val, y_val, num_epochs=10, batch_size=32)
        >>>
        >>> # Predict
        >>> predictions = classifier.predict(X_test)
    """

    def __init__(self, classifier: BaseClassifierWrapper):
        """Initialize the torchTextClassifiers instance.

        Args:
            classifier: An instance of a classifier wrapper that implements BaseClassifierWrapper

        Example:
            >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
            >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
            >>> config = FastTextConfig(embedding_dim=50, num_tokens=5000)
            >>> wrapper = FastTextWrapper(config)
            >>> classifier = torchTextClassifiers(wrapper)
        """
        self.classifier = classifier
        self.config = classifier.config

    def build_tokenizer(self, training_text: np.ndarray) -> None:
        """Build tokenizer from training text data.

        This method is kept for backward compatibility. It delegates to
        prepare_text_features which handles the actual text preprocessing.

        Args:
            training_text: Array of text strings to build the tokenizer from

        Example:
            >>> import numpy as np
            >>> texts = np.array(["Hello world", "This is a test", "Another example"])
            >>> classifier.build_tokenizer(texts)
        """
        self.classifier.prepare_text_features(training_text)

    def prepare_text_features(self, training_text: np.ndarray) -> None:
        """Prepare text features for the classifier.

        This method handles text preprocessing which could involve tokenization,
        vectorization, or other approaches depending on the classifier type.

        Args:
            training_text: Array of text strings to prepare features from

        Example:
            >>> import numpy as np
            >>> texts = np.array(["Hello world", "This is a test", "Another example"])
            >>> classifier.prepare_text_features(texts)
        """
        self.classifier.prepare_text_features(training_text)

    def build(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray = None,
        lightning=True,
        **kwargs
    ) -> None:
        """Build the complete classifier from training data.

        This method handles the full model building process including:
        - Input validation and preprocessing
        - Tokenizer creation from training text
        - Model architecture initialization
        - Lightning module setup (if enabled)

        Args:
            X_train: Training input data (text and optional categorical features)
            y_train: Training labels (optional, can be inferred if num_classes is set)
            lightning: Whether to initialize PyTorch Lightning components
            **kwargs: Additional arguments passed to Lightning initialization

        Raises:
            ValueError: If y_train is None and num_classes is not set in config
            ValueError: If label values are outside expected range

        Example:
            >>> X_train = np.array(["text sample 1", "text sample 2"])
            >>> y_train = np.array([0, 1])
            >>> classifier.build(X_train, y_train)
        """
        training_text, categorical_variables, no_cat_var = check_X(X_train)

        if y_train is not None:
            if self.config.num_classes is not None:
                if self.config.num_classes != len(np.unique(y_train)):
                    logger.warning(
                        f"Updating num_classes from {self.config.num_classes} to {len(np.unique(y_train))}"
                    )

            y_train = check_Y(y_train)
            self.config.num_classes = len(np.unique(y_train))

            if np.max(y_train) >= self.config.num_classes:
                raise ValueError(
                    "y_train must contain values between 0 and num_classes-1"
                )
        else:
            if self.config.num_classes is None:
                raise ValueError(
                    "Either num_classes must be provided at init or y_train must be provided here."
                )

        # Handle categorical variables
        if not no_cat_var:
            if hasattr(self.config, 'num_categorical_features') and self.config.num_categorical_features is not None:
                if self.config.num_categorical_features != categorical_variables.shape[1]:
                    logger.warning(
                        f"Updating num_categorical_features from {self.config.num_categorical_features} to {categorical_variables.shape[1]}"
                    )

            if hasattr(self.config, 'num_categorical_features'):
                self.config.num_categorical_features = categorical_variables.shape[1]

            categorical_vocabulary_sizes = np.max(categorical_variables, axis=0) + 1

            if hasattr(self.config, 'categorical_vocabulary_sizes') and self.config.categorical_vocabulary_sizes is not None:
                if self.config.categorical_vocabulary_sizes != list(categorical_vocabulary_sizes):
                    logger.warning(
                        "Overwriting categorical_vocabulary_sizes with values from training data."
                    )
            if hasattr(self.config, 'categorical_vocabulary_sizes'):
                self.config.categorical_vocabulary_sizes = list(categorical_vocabulary_sizes)

        self.classifier.prepare_text_features(training_text)
        self.classifier._build_pytorch_model()

        if lightning:
            self.classifier._check_and_init_lightning(**kwargs)

    def train(
        self,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: np.ndarray,
        y_val: np.ndarray,
        num_epochs: int,
        batch_size: int,
        cpu_run: bool = False,
        num_workers: int = 12,
        patience_train: int = 3,
        verbose: bool = False,
        trainer_params: Optional[dict] = None,
        **kwargs
    ) -> None:
        """Train the classifier using PyTorch Lightning.

        This method handles the complete training process including:
        - Data validation and preprocessing
        - Dataset and DataLoader creation
        - PyTorch Lightning trainer setup with callbacks
        - Model training with early stopping
        - Best model loading after training

        Args:
            X_train: Training input data
            y_train: Training labels
            X_val: Validation input data
            y_val: Validation labels
            num_epochs: Maximum number of training epochs
            batch_size: Batch size for training and validation
            cpu_run: If True, force training on CPU instead of GPU
            num_workers: Number of worker processes for data loading
            patience_train: Number of epochs to wait for improvement before early stopping
            verbose: If True, print detailed training progress
            trainer_params: Additional parameters to pass to PyTorch Lightning Trainer
            **kwargs: Additional arguments passed to the build method

        Example:
            >>> classifier.train(
            ...     X_train, y_train, X_val, y_val,
            ...     num_epochs=50,
            ...     batch_size=32,
            ...     patience_train=5,
            ...     verbose=True
            ... )
        """
        # Input validation
        training_text, train_categorical_variables, train_no_cat_var = check_X(X_train)
        val_text, val_categorical_variables, val_no_cat_var = check_X(X_val)
        y_train = check_Y(y_train)
        y_val = check_Y(y_val)

        # Consistency checks
        assert train_no_cat_var == val_no_cat_var, (
            "X_train and X_val must have the same number of categorical variables."
        )
        assert X_train.shape[0] == y_train.shape[0], (
            "X_train and y_train must have the same number of observations."
        )
        assert X_train.ndim > 1 and X_train.shape[1] == X_val.shape[1] or X_val.ndim == 1, (
            "X_train and X_val must have the same number of columns."
        )

        if verbose:
            logger.info("Starting training process...")

        # Device setup
        if cpu_run:
            device = torch.device("cpu")
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.classifier.device = device

        if verbose:
            logger.info(f"Running on: {device}")

        # Build model if not already built
        if self.classifier.pytorch_model is None:
            if verbose:
                start = time.time()
                logger.info("Building the model...")
            self.build(X_train, y_train, **kwargs)
            if verbose:
                end = time.time()
                logger.info(f"Model built in {end - start:.2f} seconds.")

        self.classifier.pytorch_model = self.classifier.pytorch_model.to(device)

        # Create datasets and dataloaders using wrapper methods
        train_dataset = self.classifier.create_dataset(
            texts=training_text,
            labels=y_train,
            categorical_variables=train_categorical_variables,
        )
        val_dataset = self.classifier.create_dataset(
            texts=val_text,
            labels=y_val,
            categorical_variables=val_categorical_variables,
        )

        train_dataloader = self.classifier.create_dataloader(
            dataset=train_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=True
        )
        val_dataloader = self.classifier.create_dataloader(
            dataset=val_dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False
        )

        # Setup trainer
        callbacks = [
            ModelCheckpoint(
                monitor="val_loss",
                save_top_k=1,
                save_last=False,
                mode="min",
            ),
            EarlyStopping(
                monitor="val_loss",
                patience=patience_train,
                mode="min",
            ),
            LearningRateMonitor(logging_interval="step"),
        ]

        train_params = {
            "callbacks": callbacks,
            "max_epochs": num_epochs,
            "num_sanity_val_steps": 2,
            "strategy": "auto",
            "log_every_n_steps": 1,
            "enable_progress_bar": True,
        }

        if trainer_params is not None:
            train_params.update(trainer_params)

        trainer = pl.Trainer(**train_params)

        torch.cuda.empty_cache()
        torch.set_float32_matmul_precision("medium")

        if verbose:
            logger.info("Launching training...")
            start = time.time()

        trainer.fit(self.classifier.lightning_module, train_dataloader, val_dataloader)

        if verbose:
            end = time.time()
            logger.info(f"Training completed in {end - start:.2f} seconds.")

        # Load best model using wrapper method
        best_model_path = trainer.checkpoint_callback.best_model_path
        self.classifier.load_best_model(best_model_path)

    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Make predictions on input data.

        Args:
            X: Input data for prediction (text and optional categorical features)
            **kwargs: Additional arguments passed to the underlying predictor

        Returns:
            np.ndarray: Predicted class labels

        Example:
            >>> X_test = np.array(["new text sample", "another sample"])
            >>> predictions = classifier.predict(X_test)
            >>> print(predictions)  # [0, 1]
        """
        return self.classifier.predict(X, **kwargs)

    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        """Validate the model on test data.

        Args:
            X: Input data for validation
            Y: True labels for validation
            **kwargs: Additional arguments passed to the validator

        Returns:
            float: Validation accuracy score

        Example:
            >>> accuracy = classifier.validate(X_test, y_test)
            >>> print(f"Accuracy: {accuracy:.3f}")
        """
        return self.classifier.validate(X, Y, **kwargs)

    def predict_and_explain(self, X: np.ndarray, **kwargs):
        """Make predictions with explanations (if supported).

        This method provides both predictions and explanations for the model's
        decisions. Availability depends on the specific classifier implementation.

        Args:
            X: Input data for prediction
            **kwargs: Additional arguments passed to the explainer

        Returns:
            tuple: (predictions, explanations) where explanations format depends
                on the classifier type

        Raises:
            NotImplementedError: If the classifier doesn't support explanations

        Example:
            >>> predictions, explanations = classifier.predict_and_explain(X_test)
            >>> print(f"Predictions: {predictions}")
            >>> print(f"Explanations: {explanations}")
        """
        if hasattr(self.classifier, 'predict_and_explain'):
            return self.classifier.predict_and_explain(X, **kwargs)
        else:
            raise NotImplementedError(f"Explanation not supported for {type(self.classifier).__name__}")

    def to_json(self, filepath: str) -> None:
        """Save classifier configuration to JSON file.

        This method serializes the classifier configuration to a JSON
        file. Note: This only saves configuration, not trained model weights.
        Custom classifier wrappers should implement a class method `get_wrapper_class_info()`
        that returns a dict with 'module' and 'class_name' keys for proper reconstruction.

        Args:
            filepath: Path where to save the JSON configuration file

        Example:
            >>> classifier.to_json('my_classifier_config.json')
        """
        with open(filepath, "w") as f:
            data = {
                "config": self.config.to_dict(),
            }

            # Try to get wrapper class info for reconstruction
            if hasattr(self.classifier.__class__, 'get_wrapper_class_info'):
                data["wrapper_class_info"] = self.classifier.__class__.get_wrapper_class_info()
            else:
                # Fallback: store module and class name
                data["wrapper_class_info"] = {
                    "module": self.classifier.__class__.__module__,
                    "class_name": self.classifier.__class__.__name__
                }

            json.dump(data, f, cls=NumpyJSONEncoder, indent=4)

    @classmethod
    def from_json(cls, filepath: str, wrapper_class: Optional[Type[BaseClassifierWrapper]] = None) -> "torchTextClassifiers":
        """Load classifier configuration from JSON file.

        This method creates a new classifier instance from a previously saved
        configuration file. The classifier will need to be built and trained again.

        Args:
            filepath: Path to the JSON configuration file
            wrapper_class: Optional wrapper class to use. If not provided, will try to
                reconstruct from saved wrapper_class_info

        Returns:
            torchTextClassifiers: New classifier instance with loaded configuration

        Raises:
            ImportError: If the wrapper class cannot be imported
            FileNotFoundError: If the configuration file doesn't exist

        Example:
            >>> # Using saved wrapper class info
            >>> classifier = torchTextClassifiers.from_json('my_classifier_config.json')
            >>>
            >>> # Or providing wrapper class explicitly
            >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
            >>> classifier = torchTextClassifiers.from_json('config.json', FastTextWrapper)
        """
        with open(filepath, "r") as f:
            data = json.load(f)

        if wrapper_class is None:
            # Try to reconstruct wrapper class from saved info
            if "wrapper_class_info" not in data:
                raise ValueError("No wrapper_class_info found in config file and no wrapper_class provided")

            wrapper_info = data["wrapper_class_info"]
            module_name = wrapper_info["module"]
            class_name = wrapper_info["class_name"]

            # Dynamically import the wrapper class
            import importlib
            module = importlib.import_module(module_name)
            wrapper_class = getattr(module, class_name)

        # Reconstruct config using wrapper class's config class
        config_class = wrapper_class.get_config_class()
        config = config_class.from_dict(data["config"])

        # Create wrapper instance
        wrapper = wrapper_class(config)

        return cls(wrapper)

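Editor's note: the public API above (build, train, predict, to_json/from_json) can be exercised end to end. A minimal sketch, assuming the FastTextConfig/FastTextWrapper classes shipped in this wheel (import paths and constructor arguments taken from the docstrings above; the toy data, epoch counts, and file name are illustrative only):

import numpy as np
from torchTextClassifiers import torchTextClassifiers
from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig

# Toy data: a single text column, so no categorical variables are passed.
X_train = np.array(["good product", "bad service", "great quality", "terrible support"])
y_train = np.array([1, 0, 1, 0])
X_val, y_val = X_train.copy(), y_train.copy()

config = FastTextConfig(
    embedding_dim=50, num_tokens=5000, min_count=1,
    min_n=3, max_n=6, len_word_ngrams=2, num_classes=2,
)
clf = torchTextClassifiers(FastTextWrapper(config))
clf.build(X_train, y_train)
clf.train(X_train, y_train, X_val, y_val,
          num_epochs=2, batch_size=2, cpu_run=True, num_workers=0)
print(clf.predict(X_val))

# Only the configuration (not the trained weights) survives the JSON round trip;
# the wrapper class is restored from the saved wrapper_class_info.
clf.to_json("config.json")
clf_again = torchTextClassifiers.from_json("config.json")
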
torchTextClassifiers/utilities/checkers.py
@@ -0,0 +1,108 @@
import logging
import json
from typing import Optional, Union, Type, List

import numpy as np

logger = logging.getLogger(__name__)


def check_X(X):
    assert isinstance(X, np.ndarray), (
        "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables."
    )

    try:
        if X.ndim > 1:
            text = X[:, 0].astype(str)
        else:
            text = X[:].astype(str)
    except ValueError:
        logger.error("The first column of X must be castable in string format.")

    if len(X.shape) == 1 or (len(X.shape) == 2 and X.shape[1] == 1):
        no_cat_var = True
    else:
        no_cat_var = False

    if not no_cat_var:
        try:
            categorical_variables = X[:, 1:].astype(int)
        except ValueError:
            logger.error(
                f"Columns {1} to {X.shape[1] - 1} of X_train must be castable in integer format."
            )
    else:
        categorical_variables = None

    return text, categorical_variables, no_cat_var


def check_Y(Y):
    assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
    assert len(Y.shape) == 1 or (len(Y.shape) == 2 and Y.shape[1] == 1), (
        "Y must be a numpy array of shape (N,) or (N,1)."
    )

    try:
        Y = Y.astype(int)
    except ValueError:
        logger.error("Y must be castable in integer format.")

    return Y


def validate_categorical_inputs(
    categorical_vocabulary_sizes: List[int],
    categorical_embedding_dims: Union[List[int], int],
    num_categorical_features: int = None,
):
    if categorical_vocabulary_sizes is None:
        logger.warning("No categorical_vocabulary_sizes. It will be inferred later.")
        return None, None, None

    else:
        if not isinstance(categorical_vocabulary_sizes, list):
            raise TypeError("categorical_vocabulary_sizes must be a list of int")

        if isinstance(categorical_embedding_dims, list):
            if len(categorical_vocabulary_sizes) != len(categorical_embedding_dims):
                raise ValueError(
                    "Categorical vocabulary sizes and their embedding dimensions must have the same length"
                )

        if num_categorical_features is not None:
            if len(categorical_vocabulary_sizes) != num_categorical_features:
                raise ValueError(
                    "len(categorical_vocabulary_sizes) must be equal to num_categorical_features"
                )
        else:
            num_categorical_features = len(categorical_vocabulary_sizes)

    assert num_categorical_features is not None, (
        "num_categorical_features should be inferred at this point."
    )

    # "Transform" embedding dims into a suitable list, or stay None
    if categorical_embedding_dims is not None:
        if isinstance(categorical_embedding_dims, int):
            categorical_embedding_dims = [categorical_embedding_dims] * num_categorical_features
        elif not isinstance(categorical_embedding_dims, list):
            raise TypeError("categorical_embedding_dims must be an int or a list of int")

    assert isinstance(categorical_embedding_dims, list) or categorical_embedding_dims is None, (
        "categorical_embedding_dims must be a list of int at this point"
    )

    return categorical_vocabulary_sizes, categorical_embedding_dims, num_categorical_features


class NumpyJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

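Editor's note: the input contract enforced by check_X and check_Y above can be illustrated with a small sketch (array contents chosen for illustration):

import numpy as np
from torchTextClassifiers.utilities.checkers import check_X, check_Y

# Text plus two integer-coded categorical columns -> shape (N, 3), object dtype.
X = np.array([
    ["cheap red shoes", 0, 2],
    ["organic coffee beans", 1, 0],
], dtype=object)
text, cats, no_cat_var = check_X(X)
# text    -> array of the first column cast to str
# cats    -> array([[0, 2], [1, 0]]) cast to int
# no_cat_var -> False (more than one column)

y = check_Y(np.array([[1], [0]]))  # an (N, 1) column is accepted and cast to int
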
torchTextClassifiers/utilities/preprocess.py
@@ -0,0 +1,82 @@
"""
Text processing functions.
"""

import string

import numpy as np

try:
    import nltk
    from nltk.corpus import stopwords as ntlk_stopwords
    from nltk.stem.snowball import SnowballStemmer

    HAS_NLTK = True
except ImportError:
    HAS_NLTK = False

try:
    import unidecode

    HAS_UNIDECODE = True
except ImportError:
    HAS_UNIDECODE = False


def clean_text_feature(text: list[str], remove_stop_words=True):
    """
    Cleans a text feature.

    Args:
        text (list[str]): List of text descriptions.
        remove_stop_words (bool): If True, remove stopwords.

    Returns:
        list[str]: List of cleaned text descriptions.

    """
    if not HAS_NLTK:
        raise ImportError(
            "nltk is not installed and is required for preprocessing. Run 'pip install torchFastText[preprocess]'."
        )
    if not HAS_UNIDECODE:
        raise ImportError(
            "unidecode is not installed and is required for preprocessing. Run 'pip install torchFastText[preprocess]'."
        )

    # Define stopwords and stemmer
    nltk.download("stopwords", quiet=True)
    stopwords = tuple(ntlk_stopwords.words("french")) + tuple(string.ascii_lowercase)
    stemmer = SnowballStemmer(language="french")

    # Remove accented characters
    text = np.vectorize(unidecode.unidecode)(np.array(text))

    # To lowercase
    text = np.char.lower(text)

    # Remove one-letter words
    def mylambda(x):
        return " ".join([w for w in x.split() if len(w) > 1])

    text = np.vectorize(mylambda)(text)

    # Remove duplicate words and stopwords in texts
    # Stem words
    libs_token = [lib.split() for lib in text.tolist()]
    libs_token = [
        sorted(set(libs_token[i]), key=libs_token[i].index) for i in range(len(libs_token))
    ]
    if remove_stop_words:
        text = [
            " ".join([stemmer.stem(word) for word in libs_token[i] if word not in stopwords])
            for i in range(len(libs_token))
        ]
    else:
        text = [
            " ".join([stemmer.stem(word) for word in libs_token[i]]) for i in range(len(libs_token))
        ]

    # Return the cleaned text
    return text
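
Editor's note: a hedged usage sketch of clean_text_feature (requires the optional nltk and unidecode dependencies; the output shown is indicative only, since it depends on the NLTK French stopword list and stemmer):

from torchTextClassifiers.utilities.preprocess import clean_text_feature

texts = ["Vente à distance de café", "Boulangerie et pâtisserie artisanale"]
cleaned = clean_text_feature(texts, remove_stop_words=True)
# Accents are stripped, text is lowercased, one-letter words and stopwords are removed,
# duplicate words are dropped, and the rest is stemmed with the French SnowballStemmer,
# yielding something like ['vent distanc caf', 'boulangeri patisseri artisanal'].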