torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. torchTextClassifiers/__init__.py +12 -48
  2. torchTextClassifiers/dataset/__init__.py +1 -0
  3. torchTextClassifiers/dataset/dataset.py +114 -0
  4. torchTextClassifiers/model/__init__.py +2 -0
  5. torchTextClassifiers/model/components/__init__.py +12 -0
  6. torchTextClassifiers/model/components/attention.py +126 -0
  7. torchTextClassifiers/model/components/categorical_var_net.py +128 -0
  8. torchTextClassifiers/model/components/classification_head.py +43 -0
  9. torchTextClassifiers/model/components/text_embedder.py +220 -0
  10. torchTextClassifiers/model/lightning.py +166 -0
  11. torchTextClassifiers/model/model.py +151 -0
  12. torchTextClassifiers/tokenizers/WordPiece.py +92 -0
  13. torchTextClassifiers/tokenizers/__init__.py +10 -0
  14. torchTextClassifiers/tokenizers/base.py +205 -0
  15. torchTextClassifiers/tokenizers/ngram.py +472 -0
  16. torchTextClassifiers/torchTextClassifiers.py +463 -405
  17. torchTextClassifiers/utilities/__init__.py +0 -3
  18. torchTextClassifiers/utilities/plot_explainability.py +184 -0
  19. torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
  20. torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
  21. {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
  22. torchTextClassifiers/classifiers/base.py +0 -83
  23. torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
  24. torchTextClassifiers/classifiers/fasttext/core.py +0 -269
  25. torchTextClassifiers/classifiers/fasttext/model.py +0 -752
  26. torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
  27. torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
  28. torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
  29. torchTextClassifiers/factories.py +0 -34
  30. torchTextClassifiers/utilities/checkers.py +0 -108
  31. torchTextClassifiers/utilities/preprocess.py +0 -82
  32. torchTextClassifiers/utilities/utils.py +0 -346
  33. torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
  34. torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
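
The docstring examples in the diff below outline the new 0.1.0 workflow: a trained tokenizer and a ModelConfig are passed to torchTextClassifiers, and training is driven by a TrainingConfig. The following is a minimal sketch assembled from those examples; the `tokenizer` object is assumed to be an already-trained BaseTokenizer (for instance one of the tokenizers added under torchTextClassifiers/tokenizers/), and its construction is not shown in this diff.

    import numpy as np
    from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers

    # First column: text; remaining columns: integer-coded categorical variables.
    X = np.array([["first document", 0, 1], ["second document", 2, 3]], dtype=object)
    Y = np.array([0, 1])

    model_config = ModelConfig(
        embedding_dim=10,
        categorical_vocabulary_sizes=[30, 25],
        categorical_embedding_dims=[10, 5],
        num_classes=10,
    )
    training_config = TrainingConfig(num_epochs=1, batch_size=4, lr=1e-3)

    # tokenizer: a trained BaseTokenizer instance (assumed; not built in this snippet)
    ttc = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config)
    ttc.train(X_train=X, y_train=Y, X_val=X, y_val=Y, training_config=training_config)
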
@@ -1,7 +1,15 @@
  import logging
  import time
- import json
- from typing import Optional, Union, Type, List, Dict, Any
+ from dataclasses import asdict, dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+ try:
+     from captum.attr import LayerIntegratedGradients
+
+     HAS_CAPTUM = True
+ except ImportError:
+     HAS_CAPTUM = False
+

  import numpy as np
  import pytorch_lightning as pl
@@ -12,9 +20,17 @@ from pytorch_lightning.callbacks import (
      ModelCheckpoint,
  )

- from .utilities.checkers import check_X, check_Y, NumpyJSONEncoder
- from .classifiers.base import BaseClassifierConfig, BaseClassifierWrapper
-
+ from torchTextClassifiers.dataset import TextClassificationDataset
+ from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
+ from torchTextClassifiers.model.components import (
+     AttentionConfig,
+     CategoricalForwardType,
+     CategoricalVariableNet,
+     ClassificationHead,
+     TextEmbedder,
+     TextEmbedderConfig,
+ )
+ from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput

  logger = logging.getLogger(__name__)

@@ -26,296 +42,256 @@ logging.basicConfig(
  )


+ @dataclass
+ class ModelConfig:
+     """Base configuration class for text classifiers."""
+
+     embedding_dim: int
+     categorical_vocabulary_sizes: Optional[List[int]] = None
+     categorical_embedding_dims: Optional[Union[List[int], int]] = None
+     num_classes: Optional[int] = None
+     attention_config: Optional[AttentionConfig] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
+         return cls(**data)
+
+
+ @dataclass
+ class TrainingConfig:
+     num_epochs: int
+     batch_size: int
+     lr: float
+     loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss())
+     optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+     scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
+     accelerator: str = "auto"
+     num_workers: int = 12
+     patience_early_stopping: int = 3
+     dataloader_params: Optional[dict] = None
+     trainer_params: Optional[dict] = None
+     optimizer_params: Optional[dict] = None
+     scheduler_params: Optional[dict] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         data = asdict(self)
+         # Serialize loss and scheduler as their class names
+         data["loss"] = self.loss.__class__.__name__
+         if self.scheduler is not None:
+             data["scheduler"] = self.scheduler.__name__
+         return data


  class torchTextClassifiers:
      """Generic text classifier framework supporting multiple architectures.
-
-     This is the main class that provides a unified interface for different types
-     of text classifiers. It acts as a high-level wrapper that delegates operations
-     to specific classifier implementations while providing a consistent API.
-
-     The class supports the full machine learning workflow including:
-     - Building tokenizers from training data
-     - Model training with validation
-     - Prediction and evaluation
-     - Model serialization and loading
-
-     Attributes:
-         config: Configuration object specific to the classifier type
-         classifier: The underlying classifier implementation
-
-     Example:
-         >>> from torchTextClassifiers import torchTextClassifiers
-         >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-         >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-         >>>
-         >>> # Create configuration
-         >>> config = FastTextConfig(
-         ...     embedding_dim=100,
-         ...     num_tokens=10000,
-         ...     min_count=1,
-         ...     min_n=3,
-         ...     max_n=6,
-         ...     len_word_ngrams=2,
-         ...     num_classes=2
-         ... )
-         >>>
-         >>> # Initialize classifier with wrapper
-         >>> wrapper = FastTextWrapper(config)
-         >>> classifier = torchTextClassifiers(wrapper)
-         >>>
-         >>> # Build and train
-         >>> classifier.build(X_train, y_train)
-         >>> classifier.train(X_train, y_train, X_val, y_val, num_epochs=10, batch_size=32)
-         >>>
-         >>> # Predict
-         >>> predictions = classifier.predict(X_test)
+
+     Given a tokenizer and model configuration, this class initializes:
+     - Text embedding layer (if needed)
+     - Categorical variable embedding network (if categorical variables are provided)
+     - Classification head
+     The resulting model can be trained using PyTorch Lightning and used for predictions.
+
      """
-
-     def __init__(self, classifier: BaseClassifierWrapper):
-         """Initialize the torchTextClassifiers instance.
-
-         Args:
-             classifier: An instance of a classifier wrapper that implements BaseClassifierWrapper
-
-         Example:
-             >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-             >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-             >>> config = FastTextConfig(embedding_dim=50, num_tokens=5000)
-             >>> wrapper = FastTextWrapper(config)
-             >>> classifier = torchTextClassifiers(wrapper)
-         """
-         self.classifier = classifier
-         self.config = classifier.config
-
-
-     def build_tokenizer(self, training_text: np.ndarray) -> None:
-         """Build tokenizer from training text data.
-
-         This method is kept for backward compatibility. It delegates to
-         prepare_text_features which handles the actual text preprocessing.
-
-         Args:
-             training_text: Array of text strings to build the tokenizer from
-
-         Example:
-             >>> import numpy as np
-             >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-             >>> classifier.build_tokenizer(texts)
-         """
-         self.classifier.prepare_text_features(training_text)
-
-     def prepare_text_features(self, training_text: np.ndarray) -> None:
-         """Prepare text features for the classifier.
-
-         This method handles text preprocessing which could involve tokenization,
-         vectorization, or other approaches depending on the classifier type.
-
-         Args:
-             training_text: Array of text strings to prepare features from
-
-         Example:
-             >>> import numpy as np
-             >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-             >>> classifier.prepare_text_features(texts)
-         """
-         self.classifier.prepare_text_features(training_text)
-
-     def build(
+
+     def __init__(
          self,
-         X_train: np.ndarray,
-         y_train: np.ndarray = None,
-         lightning=True,
-         **kwargs
-     ) -> None:
-         """Build the complete classifier from training data.
-
-         This method handles the full model building process including:
-         - Input validation and preprocessing
-         - Tokenizer creation from training text
-         - Model architecture initialization
-         - Lightning module setup (if enabled)
-
+         tokenizer: BaseTokenizer,
+         model_config: ModelConfig,
+     ):
+         """Initialize the torchTextClassifiers instance.
+
          Args:
-             X_train: Training input data (text and optional categorical features)
-             y_train: Training labels (optional, can be inferred if num_classes is set)
-             lightning: Whether to initialize PyTorch Lightning components
-             **kwargs: Additional arguments passed to Lightning initialization
-
-         Raises:
-             ValueError: If y_train is None and num_classes is not set in config
-             ValueError: If label values are outside expected range
-
+             tokenizer: A tokenizer instance for text preprocessing
+             model_config: Configuration parameters for the text classification model
+
          Example:
-             >>> X_train = np.array(["text sample 1", "text sample 2"])
-             >>> y_train = np.array([0, 1])
-             >>> classifier.build(X_train, y_train)
+             >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
+             >>> # Assume tokenizer is a trained BaseTokenizer instance
+             >>> model_config = ModelConfig(
+             ...     embedding_dim=10,
+             ...     categorical_vocabulary_sizes=[30, 25],
+             ...     categorical_embedding_dims=[10, 5],
+             ...     num_classes=10,
+             ... )
+             >>> ttc = torchTextClassifiers(
+             ...     tokenizer=tokenizer,
+             ...     model_config=model_config,
+             ... )
          """
-         training_text, categorical_variables, no_cat_var = check_X(X_train)
-
-         if y_train is not None:
-             if self.config.num_classes is not None:
-                 if self.config.num_classes != len(np.unique(y_train)):
-                     logger.warning(
-                         f"Updating num_classes from {self.config.num_classes} to {len(np.unique(y_train))}"
-                     )
-
-             y_train = check_Y(y_train)
-             self.config.num_classes = len(np.unique(y_train))
-
-             if np.max(y_train) >= self.config.num_classes:
-                 raise ValueError(
-                     "y_train must contain values between 0 and num_classes-1"
+
+         self.model_config = model_config
+         self.tokenizer = tokenizer
+
+         if hasattr(self.tokenizer, "trained"):
+             if not self.tokenizer.trained:
+                 raise RuntimeError(
+                     f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier."
                  )
+
+         self.vocab_size = tokenizer.vocab_size
+         self.embedding_dim = model_config.embedding_dim
+         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
+         self.num_classes = model_config.num_classes
+
+         if self.tokenizer.output_vectorized:
+             self.text_embedder = None
+             logger.info(
+                 "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization."
+             )
+             self.embedding_dim = self.tokenizer.output_dim
          else:
-             if self.config.num_classes is None:
-                 raise ValueError(
-                     "Either num_classes must be provided at init or y_train must be provided here."
-                 )
-
-         # Handle categorical variables
-         if not no_cat_var:
-             if hasattr(self.config, 'num_categorical_features') and self.config.num_categorical_features is not None:
-                 if self.config.num_categorical_features != categorical_variables.shape[1]:
-                     logger.warning(
-                         f"Updating num_categorical_features from {self.config.num_categorical_features} to {categorical_variables.shape[1]}"
-                     )
-
-             if hasattr(self.config, 'num_categorical_features'):
-                 self.config.num_categorical_features = categorical_variables.shape[1]
-
-             categorical_vocabulary_sizes = np.max(categorical_variables, axis=0) + 1
-
-             if hasattr(self.config, 'categorical_vocabulary_sizes') and self.config.categorical_vocabulary_sizes is not None:
-                 if self.config.categorical_vocabulary_sizes != list(categorical_vocabulary_sizes):
-                     logger.warning(
-                         "Overwriting categorical_vocabulary_sizes with values from training data."
-                     )
-             if hasattr(self.config, 'categorical_vocabulary_sizes'):
-                 self.config.categorical_vocabulary_sizes = list(categorical_vocabulary_sizes)
-
-         self.classifier.prepare_text_features(training_text)
-         self.classifier._build_pytorch_model()
-
-         if lightning:
-             self.classifier._check_and_init_lightning(**kwargs)
-
+             text_embedder_config = TextEmbedderConfig(
+                 vocab_size=self.vocab_size,
+                 embedding_dim=self.embedding_dim,
+                 padding_idx=tokenizer.padding_idx,
+                 attention_config=model_config.attention_config,
+             )
+             self.text_embedder = TextEmbedder(
+                 text_embedder_config=text_embedder_config,
+             )
+
+         classif_head_input_dim = self.embedding_dim
+         if self.categorical_vocabulary_sizes:
+             self.categorical_var_net = CategoricalVariableNet(
+                 categorical_vocabulary_sizes=self.categorical_vocabulary_sizes,
+                 categorical_embedding_dims=model_config.categorical_embedding_dims,
+                 text_embedding_dim=self.embedding_dim,
+             )
+
+             if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT:
+                 classif_head_input_dim += self.categorical_var_net.output_dim
+
+         else:
+             self.categorical_var_net = None
+
+         self.classification_head = ClassificationHead(
+             input_dim=classif_head_input_dim,
+             num_classes=model_config.num_classes,
+         )
+
+         self.pytorch_model = TextClassificationModel(
+             text_embedder=self.text_embedder,
+             categorical_variable_net=self.categorical_var_net,
+             classification_head=self.classification_head,
+         )
+
      def train(
          self,
          X_train: np.ndarray,
          y_train: np.ndarray,
          X_val: np.ndarray,
          y_val: np.ndarray,
-         num_epochs: int,
-         batch_size: int,
-         cpu_run: bool = False,
-         num_workers: int = 12,
-         patience_train: int = 3,
+         training_config: TrainingConfig,
          verbose: bool = False,
-         trainer_params: Optional[dict] = None,
-         **kwargs
      ) -> None:
          """Train the classifier using PyTorch Lightning.
-
+
          This method handles the complete training process including:
          - Data validation and preprocessing
          - Dataset and DataLoader creation
          - PyTorch Lightning trainer setup with callbacks
          - Model training with early stopping
          - Best model loading after training
-
+
          Args:
              X_train: Training input data
              y_train: Training labels
             X_val: Validation input data
             y_val: Validation labels
-             num_epochs: Maximum number of training epochs
-             batch_size: Batch size for training and validation
-             cpu_run: If True, force training on CPU instead of GPU
-             num_workers: Number of worker processes for data loading
-             patience_train: Number of epochs to wait for improvement before early stopping
-             verbose: If True, print detailed training progress
-             trainer_params: Additional parameters to pass to PyTorch Lightning Trainer
-             **kwargs: Additional arguments passed to the build method
-
+             training_config: Configuration parameters for training
+             verbose: Whether to print training progress information
+
+
          Example:
-             >>> classifier.train(
-             ...     X_train, y_train, X_val, y_val,
-             ...     num_epochs=50,
-             ...     batch_size=32,
-             ...     patience_train=5,
-             ...     verbose=True
-             ... )
+
+             >>> training_config = TrainingConfig(
+             ...     lr=1e-3,
+             ...     batch_size=4,
+             ...     num_epochs=1,
+             ... )
+             >>> ttc.train(
+             ...     X_train=X,
+             ...     y_train=Y,
+             ...     X_val=X,
+             ...     y_val=Y,
+             ...     training_config=training_config,
+             ... )
          """
          # Input validation
-         training_text, train_categorical_variables, train_no_cat_var = check_X(X_train)
-         val_text, val_categorical_variables, val_no_cat_var = check_X(X_val)
-         y_train = check_Y(y_train)
-         y_val = check_Y(y_val)
-
-         # Consistency checks
-         assert train_no_cat_var == val_no_cat_var, (
-             "X_train and X_val must have the same number of categorical variables."
-         )
-         assert X_train.shape[0] == y_train.shape[0], (
-             "X_train and y_train must have the same number of observations."
-         )
-         assert X_train.ndim > 1 and X_train.shape[1] == X_val.shape[1] or X_val.ndim == 1, (
-             "X_train and X_val must have the same number of columns."
-         )
-
+         X_train, y_train = self._check_XY(X_train, y_train)
+         X_val, y_val = self._check_XY(X_val, y_val)
+
+         if (
+             X_train["categorical_variables"] is not None
+             and X_val["categorical_variables"] is not None
+         ):
+             assert (
+                 X_train["categorical_variables"].ndim > 1
+                 and X_train["categorical_variables"].shape[1]
+                 == X_val["categorical_variables"].shape[1]
+                 or X_val["categorical_variables"].ndim == 1
+             ), "X_train and X_val must have the same number of columns."
+
          if verbose:
              logger.info("Starting training process...")
-
-         # Device setup
-         if cpu_run:
-             device = torch.device("cpu")
-         else:
+
+         if training_config.accelerator == "auto":
              device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-         self.classifier.device = device
-
+         else:
+             device = torch.device(training_config.accelerator)
+
+         self.device = device
+
+         optimizer_params = {"lr": training_config.lr}
+         if training_config.optimizer_params is not None:
+             optimizer_params.update(training_config.optimizer_params)
+
+         self.lightning_module = TextClassificationModule(
+             model=self.pytorch_model,
+             loss=training_config.loss,
+             optimizer=training_config.optimizer,
+             optimizer_params=optimizer_params,
+             scheduler=training_config.scheduler,
+             scheduler_params=training_config.scheduler_params
+             if training_config.scheduler_params
+             else {},
+             scheduler_interval="epoch",
+         )
+
+         self.pytorch_model.to(self.device)
+
          if verbose:
              logger.info(f"Running on: {device}")
-
-         # Build model if not already built
-         if self.classifier.pytorch_model is None:
-             if verbose:
-                 start = time.time()
-                 logger.info("Building the model...")
-             self.build(X_train, y_train, **kwargs)
-             if verbose:
-                 end = time.time()
-                 logger.info(f"Model built in {end - start:.2f} seconds.")
-
-         self.classifier.pytorch_model = self.classifier.pytorch_model.to(device)
-
-         # Create datasets and dataloaders using wrapper methods
-         train_dataset = self.classifier.create_dataset(
-             texts=training_text,
+
+         train_dataset = TextClassificationDataset(
+             texts=X_train["text"],
+             categorical_variables=X_train["categorical_variables"],  # None if no cat vars
+             tokenizer=self.tokenizer,
              labels=y_train,
-             categorical_variables=train_categorical_variables,
          )
-         val_dataset = self.classifier.create_dataset(
-             texts=val_text,
+         val_dataset = TextClassificationDataset(
+             texts=X_val["text"],
+             categorical_variables=X_val["categorical_variables"],  # None if no cat vars
+             tokenizer=self.tokenizer,
              labels=y_val,
-             categorical_variables=val_categorical_variables,
          )
-
-         train_dataloader = self.classifier.create_dataloader(
-             dataset=train_dataset,
-             batch_size=batch_size,
-             num_workers=num_workers,
-             shuffle=True
+
+         train_dataloader = train_dataset.create_dataloader(
+             batch_size=training_config.batch_size,
+             num_workers=training_config.num_workers,
+             shuffle=True,
+             **training_config.dataloader_params if training_config.dataloader_params else {},
          )
-         val_dataloader = self.classifier.create_dataloader(
-             dataset=val_dataset,
-             batch_size=batch_size,
-             num_workers=num_workers,
-             shuffle=False
+         val_dataloader = val_dataset.create_dataloader(
+             batch_size=training_config.batch_size,
+             num_workers=training_config.num_workers,
+             shuffle=False,
+             **training_config.dataloader_params if training_config.dataloader_params else {},
          )
-
+
          # Setup trainer
          callbacks = [
              ModelCheckpoint(
326
302
  ),
327
303
  EarlyStopping(
328
304
  monitor="val_loss",
329
- patience=patience_train,
305
+ patience=training_config.patience_early_stopping,
330
306
  mode="min",
331
307
  ),
332
308
  LearningRateMonitor(logging_interval="step"),
333
309
  ]
334
-
335
- train_params = {
310
+
311
+ trainer_params = {
312
+ "accelerator": training_config.accelerator,
336
313
  "callbacks": callbacks,
337
- "max_epochs": num_epochs,
314
+ "max_epochs": training_config.num_epochs,
338
315
  "num_sanity_val_steps": 2,
339
316
  "strategy": "auto",
340
317
  "log_every_n_steps": 1,
341
318
  "enable_progress_bar": True,
342
319
  }
343
-
344
- if trainer_params is not None:
345
- train_params.update(trainer_params)
346
-
347
- trainer = pl.Trainer(**train_params)
348
-
320
+
321
+ if training_config.trainer_params is not None:
322
+ trainer_params.update(training_config.trainer_params)
323
+
324
+ trainer = pl.Trainer(**trainer_params)
325
+
349
326
  torch.cuda.empty_cache()
350
327
  torch.set_float32_matmul_precision("medium")
351
-
328
+
352
329
  if verbose:
353
330
  logger.info("Launching training...")
354
331
  start = time.time()
355
-
356
- trainer.fit(self.classifier.lightning_module, train_dataloader, val_dataloader)
357
-
332
+
333
+ trainer.fit(self.lightning_module, train_dataloader, val_dataloader)
334
+
358
335
  if verbose:
359
336
  end = time.time()
360
337
  logger.info(f"Training completed in {end - start:.2f} seconds.")
361
-
362
- # Load best model using wrapper method
338
+
363
339
  best_model_path = trainer.checkpoint_callback.best_model_path
364
- self.classifier.load_best_model(best_model_path)
365
-
366
- def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
367
- """Make predictions on input data.
368
-
369
- Args:
370
- X: Input data for prediction (text and optional categorical features)
371
- **kwargs: Additional arguments passed to the underlying predictor
372
-
373
- Returns:
374
- np.ndarray: Predicted class labels
375
-
376
- Example:
377
- >>> X_test = np.array(["new text sample", "another sample"])
378
- >>> predictions = classifier.predict(X_test)
379
- >>> print(predictions) # [0, 1]
380
- """
381
- return self.classifier.predict(X, **kwargs)
382
-
383
- def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
384
- """Validate the model on test data.
385
-
386
- Args:
387
- X: Input data for validation
388
- Y: True labels for validation
389
- **kwargs: Additional arguments passed to the validator
390
-
391
- Returns:
392
- float: Validation accuracy score
393
-
394
- Example:
395
- >>> accuracy = classifier.validate(X_test, y_test)
396
- >>> print(f"Accuracy: {accuracy:.3f}")
397
- """
398
- return self.classifier.validate(X, Y, **kwargs)
399
-
400
- def predict_and_explain(self, X: np.ndarray, **kwargs):
401
- """Make predictions with explanations (if supported).
402
-
403
- This method provides both predictions and explanations for the model's
404
- decisions. Availability depends on the specific classifier implementation.
405
-
340
+
341
+ self.lightning_module = TextClassificationModule.load_from_checkpoint(
342
+ best_model_path,
343
+ model=self.pytorch_model,
344
+ loss=training_config.loss,
345
+ )
346
+
347
+ self.pytorch_model = self.lightning_module.model.to(self.device)
348
+
349
+ self.lightning_module.eval()
350
+
351
+ def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
352
+ X = self._check_X(X)
353
+ Y = self._check_Y(Y)
354
+
355
+ if X["text"].shape[0] != Y.shape[0]:
356
+ raise ValueError("X_train and y_train must have the same number of observations.")
357
+
358
+ return X, Y
359
+
360
+ @staticmethod
361
+ def _check_text_col(X):
362
+ assert isinstance(
363
+ X, np.ndarray
364
+ ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables."
365
+
366
+ try:
367
+ if X.ndim > 1:
368
+ text = X[:, 0].astype(str)
369
+ else:
370
+ text = X[:].astype(str)
371
+ except ValueError:
372
+ logger.error("The first column of X must be castable in string format.")
373
+
374
+ return text
375
+
376
+ def _check_categorical_variables(self, X: np.ndarray) -> None:
377
+ """Check if categorical variables in X match training configuration.
378
+
406
379
  Args:
407
- X: Input data for prediction
408
- **kwargs: Additional arguments passed to the explainer
409
-
410
- Returns:
411
- tuple: (predictions, explanations) where explanations format depends
412
- on the classifier type
413
-
380
+ X: Input data to check
381
+
414
382
  Raises:
415
- NotImplementedError: If the classifier doesn't support explanations
416
-
417
- Example:
418
- >>> predictions, explanations = classifier.predict_and_explain(X_test)
419
- >>> print(f"Predictions: {predictions}")
420
- >>> print(f"Explanations: {explanations}")
383
+ ValueError: If the number of categorical variables does not match
384
+ the training configuration
421
385
  """
422
- if hasattr(self.classifier, 'predict_and_explain'):
423
- return self.classifier.predict_and_explain(X, **kwargs)
386
+
387
+ assert self.categorical_var_net is not None
388
+
389
+ if X.ndim > 1:
390
+ num_cat_vars = X.shape[1] - 1
424
391
  else:
425
- raise NotImplementedError(f"Explanation not supported for {type(self.classifier).__name__}")
426
-
427
- def to_json(self, filepath: str) -> None:
428
- """Save classifier configuration to JSON file.
429
-
430
- This method serializes the classifier configuration to a JSON
431
- file. Note: This only saves configuration, not trained model weights.
432
- Custom classifier wrappers should implement a class method `get_wrapper_class_info()`
433
- that returns a dict with 'module' and 'class_name' keys for proper reconstruction.
434
-
435
- Args:
436
- filepath: Path where to save the JSON configuration file
437
-
438
- Example:
439
- >>> classifier.to_json('my_classifier_config.json')
392
+ num_cat_vars = 0
393
+
394
+ if num_cat_vars != self.categorical_var_net.num_categorical_features:
395
+ raise ValueError(
396
+ f"X must have the same number of categorical variables as the number of embedding layers in the categorical net: ({self.categorical_var_net.num_categorical_features})."
397
+ )
398
+
399
+ try:
400
+ categorical_variables = X[:, 1:].astype(int)
401
+ except ValueError:
402
+ logger.error(
403
+ f"Columns {1} to {X.shape[1] - 1} of X_train must be castable in integer format."
404
+ )
405
+
406
+ for j in range(X.shape[1] - 1):
407
+ max_cat_value = categorical_variables[:, j].max()
408
+ if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]:
409
+ raise ValueError(
410
+ f"Categorical variable at index {j} has value {max_cat_value} which exceeds the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}."
411
+ )
412
+
413
+ return categorical_variables
414
+
415
+ def _check_X(self, X: np.ndarray) -> np.ndarray:
416
+ text = self._check_text_col(X)
417
+
418
+ categorical_variables = None
419
+ if self.categorical_var_net is not None:
420
+ categorical_variables = self._check_categorical_variables(X)
421
+
422
+ return {"text": text, "categorical_variables": categorical_variables}
423
+
424
+ def _check_Y(self, Y):
425
+ assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
426
+ assert len(Y.shape) == 1 or (
427
+ len(Y.shape) == 2 and Y.shape[1] == 1
428
+ ), "Y must be a numpy array of shape (N,) or (N,1)."
429
+
430
+ try:
431
+ Y = Y.astype(int)
432
+ except ValueError:
433
+ logger.error("Y must be castable in integer format.")
434
+
435
+ if Y.max() >= self.num_classes or Y.min() < 0:
436
+ raise ValueError(
437
+ f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
438
+ )
439
+
440
+ return Y
441
+
442
+ def predict(
443
+ self,
444
+ X_test: np.ndarray,
445
+ top_k=1,
446
+ explain=False,
447
+ ):
440
448
  """
441
- with open(filepath, "w") as f:
442
- data = {
443
- "config": self.config.to_dict(),
444
- }
445
-
446
- # Try to get wrapper class info for reconstruction
447
- if hasattr(self.classifier.__class__, 'get_wrapper_class_info'):
448
- data["wrapper_class_info"] = self.classifier.__class__.get_wrapper_class_info()
449
- else:
450
- # Fallback: store module and class name
451
- data["wrapper_class_info"] = {
452
- "module": self.classifier.__class__.__module__,
453
- "class_name": self.classifier.__class__.__name__
454
- }
455
-
456
- json.dump(data, f, cls=NumpyJSONEncoder, indent=4)
457
-
458
- @classmethod
459
- def from_json(cls, filepath: str, wrapper_class: Optional[Type[BaseClassifierWrapper]] = None) -> "torchTextClassifiers":
460
- """Load classifier configuration from JSON file.
461
-
462
- This method creates a new classifier instance from a previously saved
463
- configuration file. The classifier will need to be built and trained again.
464
-
465
449
  Args:
466
- filepath: Path to the JSON configuration file
467
- wrapper_class: Optional wrapper class to use. If not provided, will try to
468
- reconstruct from saved wrapper_class_info
469
-
470
- Returns:
471
- torchTextClassifiers: New classifier instance with loaded configuration
472
-
473
- Raises:
474
- ImportError: If the wrapper class cannot be imported
475
- FileNotFoundError: If the configuration file doesn't exist
476
-
477
- Example:
478
- >>> # Using saved wrapper class info
479
- >>> classifier = torchTextClassifiers.from_json('my_classifier_config.json')
480
- >>>
481
- >>> # Or providing wrapper class explicitly
482
- >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
483
- >>> classifier = torchTextClassifiers.from_json('config.json', FastTextWrapper)
450
+ X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
451
+ top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
452
+ explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
453
+
454
+ Returns: A dictionary containing the following fields:
455
+ - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
456
+ - confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores.
457
+ - if explain is True:
458
+ - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
484
459
  """
485
- with open(filepath, "r") as f:
486
- data = json.load(f)
487
-
488
- if wrapper_class is None:
489
- # Try to reconstruct wrapper class from saved info
490
- if "wrapper_class_info" not in data:
491
- raise ValueError("No wrapper_class_info found in config file and no wrapper_class provided")
492
-
493
- wrapper_info = data["wrapper_class_info"]
494
- module_name = wrapper_info["module"]
495
- class_name = wrapper_info["class_name"]
496
-
497
- # Dynamically import the wrapper class
498
- import importlib
499
- module = importlib.import_module(module_name)
500
- wrapper_class = getattr(module, class_name)
501
-
502
- # Reconstruct config using wrapper class's config class
503
- config_class = wrapper_class.get_config_class()
504
- config = config_class.from_dict(data["config"])
505
-
506
- # Create wrapper instance
507
- wrapper = wrapper_class(config)
508
-
509
- return cls(wrapper)
460
+
461
+ if explain:
462
+ return_offsets_mapping = True # to be passed to the tokenizer
463
+ return_word_ids = True
464
+ if self.pytorch_model.text_embedder is None:
465
+ raise RuntimeError(
466
+ "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
467
+ )
468
+ else:
469
+ if not HAS_CAPTUM:
470
+ raise ImportError(
471
+ "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
472
+ )
473
+ lig = LayerIntegratedGradients(
474
+ self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
475
+ ) # initialize a Captum layer gradient integrator
476
+ else:
477
+ return_offsets_mapping = False
478
+ return_word_ids = False
479
+
480
+ X_test = self._check_X(X_test)
481
+ text = X_test["text"]
482
+ categorical_variables = X_test["categorical_variables"]
483
+
484
+ self.pytorch_model.eval().cpu()
485
+
486
+ tokenize_output = self.tokenizer.tokenize(
487
+ text.tolist(),
488
+ return_offsets_mapping=return_offsets_mapping,
489
+ return_word_ids=return_word_ids,
490
+ )
491
+
492
+ if not isinstance(tokenize_output, TokenizerOutput):
493
+ raise TypeError(
494
+ f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method."
495
+ )
496
+
497
+ encoded_text = tokenize_output.input_ids # (batch_size, seq_len)
498
+ attention_mask = tokenize_output.attention_mask # (batch_size, seq_len)
499
+
500
+ if categorical_variables is not None:
501
+ categorical_vars = torch.tensor(
502
+ categorical_variables, dtype=torch.float32
503
+ ) # (batch_size, num_categorical_features)
504
+ else:
505
+ categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32)
506
+
507
+ pred = self.pytorch_model(
508
+ encoded_text, attention_mask, categorical_vars
509
+ ) # forward pass, contains the prediction scores (len(text), num_classes)
510
+
511
+ label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities
512
+
513
+ label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
514
+
515
+ predictions = label_scores_topk.indices # get the top_k most likely predictions
516
+ confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores
517
+
518
+ if explain:
519
+ all_attributions = []
520
+ for k in range(top_k):
521
+ attributions = lig.attribute(
522
+ (encoded_text, attention_mask, categorical_vars),
523
+ target=torch.Tensor(predictions[:, k]).long(),
524
+ ) # (batch_size, seq_len)
525
+ attributions = attributions.sum(dim=-1)
526
+ all_attributions.append(attributions.detach().cpu())
527
+
528
+ all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len)
529
+
530
+ return {
531
+ "prediction": predictions,
532
+ "confidence": confidence,
533
+ "attributions": all_attributions,
534
+ "offset_mapping": tokenize_output.offset_mapping,
535
+ "word_ids": tokenize_output.word_ids,
536
+ }
537
+ else:
538
+ return {
539
+ "prediction": predictions,
540
+ "confidence": confidence,
541
+ }
542
+
543
+ def __repr__(self):
544
+ model_type = (
545
+ self.lightning_module.__repr__()
546
+ if hasattr(self, "lightning_module")
547
+ else self.pytorch_model.__repr__()
548
+ )
549
+
550
+ tokenizer_info = self.tokenizer.__repr__()
551
+
552
+ cat_forward_type = (
553
+ self.categorical_var_net.forward_type.name
554
+ if self.categorical_var_net is not None
555
+ else "None"
556
+ )
557
+
558
+ lines = [
559
+ "torchTextClassifiers(",
560
+ f" tokenizer = {tokenizer_info},",
561
+ f" model = {model_type},",
562
+ f" categorical_forward_type = {cat_forward_type},",
563
+ f" num_classes = {self.model_config.num_classes},",
564
+ f" embedding_dim = {self.embedding_dim},",
565
+ ")",
566
+ ]
567
+ return "\n".join(lines)
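
For inference, the predict method added above returns a dictionary rather than a bare array. A short sketch continuing the earlier example; it assumes captum is installed when explain=True, and the key names follow the return dict in the code above.

    out = ttc.predict(X_test=X, top_k=2, explain=True)
    out["prediction"]    # (N, top_k) tensor of predicted class indices
    out["confidence"]    # (N, top_k) softmax scores, rounded to 2 decimals
    out["attributions"]  # (N, top_k, seq_len) layer-integrated-gradients attributions per token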