torchtextclassifiers 0.0.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. torchTextClassifiers/__init__.py +12 -48
  2. torchTextClassifiers/dataset/__init__.py +1 -0
  3. torchTextClassifiers/dataset/dataset.py +152 -0
  4. torchTextClassifiers/model/__init__.py +2 -0
  5. torchTextClassifiers/model/components/__init__.py +12 -0
  6. torchTextClassifiers/model/components/attention.py +126 -0
  7. torchTextClassifiers/model/components/categorical_var_net.py +128 -0
  8. torchTextClassifiers/model/components/classification_head.py +61 -0
  9. torchTextClassifiers/model/components/text_embedder.py +220 -0
  10. torchTextClassifiers/model/lightning.py +170 -0
  11. torchTextClassifiers/model/model.py +151 -0
  12. torchTextClassifiers/tokenizers/WordPiece.py +92 -0
  13. torchTextClassifiers/tokenizers/__init__.py +10 -0
  14. torchTextClassifiers/tokenizers/base.py +205 -0
  15. torchTextClassifiers/tokenizers/ngram.py +472 -0
  16. torchTextClassifiers/torchTextClassifiers.py +500 -413
  17. torchTextClassifiers/utilities/__init__.py +0 -3
  18. torchTextClassifiers/utilities/plot_explainability.py +184 -0
  19. torchtextclassifiers-1.0.0.dist-info/METADATA +87 -0
  20. torchtextclassifiers-1.0.0.dist-info/RECORD +21 -0
  21. {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-1.0.0.dist-info}/WHEEL +1 -1
  22. torchTextClassifiers/classifiers/base.py +0 -83
  23. torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
  24. torchTextClassifiers/classifiers/fasttext/core.py +0 -269
  25. torchTextClassifiers/classifiers/fasttext/model.py +0 -752
  26. torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
  27. torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
  28. torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
  29. torchTextClassifiers/factories.py +0 -34
  30. torchTextClassifiers/utilities/checkers.py +0 -108
  31. torchTextClassifiers/utilities/preprocess.py +0 -82
  32. torchTextClassifiers/utilities/utils.py +0 -346
  33. torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
  34. torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
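
Note: the file list above shows the 0.0.1 wrapper-based layout (classifiers/, factories.py, utilities/checkers.py, utilities/preprocess.py, utilities/utils.py) being removed in favour of dataset/, model/, and tokenizers/ subpackages. As a rough sketch of what that means for imports, based only on the paths listed above and the import statements and docstring examples visible in the diff below (the full public exports of 1.0.0 are not shown in this diff):

# 0.0.1 layout (modules removed in 1.0.0)
from torchTextClassifiers import torchTextClassifiers
from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper

# 1.0.0 layout (modules added in 1.0.0)
from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
from torchTextClassifiers.dataset import TextClassificationDataset
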
torchTextClassifiers/torchTextClassifiers.py
@@ -1,7 +1,15 @@
  import logging
  import time
- import json
- from typing import Optional, Union, Type, List, Dict, Any
+ from dataclasses import asdict, dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+ try:
+     from captum.attr import LayerIntegratedGradients
+
+     HAS_CAPTUM = True
+ except ImportError:
+     HAS_CAPTUM = False
+

  import numpy as np
  import pytorch_lightning as pl
@@ -12,9 +20,17 @@ from pytorch_lightning.callbacks import (
      ModelCheckpoint,
  )

- from .utilities.checkers import check_X, check_Y, NumpyJSONEncoder
- from .classifiers.base import BaseClassifierConfig, BaseClassifierWrapper
-
+ from torchTextClassifiers.dataset import TextClassificationDataset
+ from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
+ from torchTextClassifiers.model.components import (
+     AttentionConfig,
+     CategoricalForwardType,
+     CategoricalVariableNet,
+     ClassificationHead,
+     TextEmbedder,
+     TextEmbedderConfig,
+ )
+ from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput

  logger = logging.getLogger(__name__)

@@ -26,484 +42,555 @@ logging.basicConfig(
  )


+ @dataclass
+ class ModelConfig:
+     """Base configuration class for text classifiers."""
+
+     embedding_dim: int
+     categorical_vocabulary_sizes: Optional[List[int]] = None
+     categorical_embedding_dims: Optional[Union[List[int], int]] = None
+     num_classes: Optional[int] = None
+     attention_config: Optional[AttentionConfig] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     @classmethod
+     def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
+         return cls(**data)
+
+
+ @dataclass
+ class TrainingConfig:
+     num_epochs: int
+     batch_size: int
+     lr: float
+     loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss())
+     optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+     scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
+     accelerator: str = "auto"
+     num_workers: int = 12
+     patience_early_stopping: int = 3
+     dataloader_params: Optional[dict] = None
+     trainer_params: Optional[dict] = None
+     optimizer_params: Optional[dict] = None
+     scheduler_params: Optional[dict] = None
+
+     def to_dict(self) -> Dict[str, Any]:
+         data = asdict(self)
+         # Serialize loss and scheduler as their class names
+         data["loss"] = self.loss.__class__.__name__
+         if self.scheduler is not None:
+             data["scheduler"] = self.scheduler.__name__
+         return data

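
Note: ModelConfig and TrainingConfig above replace the per-classifier config objects of 0.0.1. A minimal sketch of how they might be filled in, with illustrative values; the top-level import path follows the docstring example further down in this diff and is otherwise an assumption:

from torchTextClassifiers import ModelConfig, TrainingConfig

model_config = ModelConfig(
    embedding_dim=64,                       # size of the text embedding
    categorical_vocabulary_sizes=[30, 25],  # one vocabulary size per categorical column
    categorical_embedding_dims=[10, 5],     # per-column embedding sizes (or a single int)
    num_classes=10,
)

training_config = TrainingConfig(
    num_epochs=5,
    batch_size=32,
    lr=1e-3,                                # remaining fields keep their defaults (Adam, CrossEntropyLoss, ...)
)

print(training_config.to_dict()["loss"])    # "CrossEntropyLoss": loss is serialized by class name

Note that loss is stored as an instance (via default_factory) while optimizer and scheduler are stored as classes, which is why to_dict() serializes them differently.
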
  class torchTextClassifiers:
      """Generic text classifier framework supporting multiple architectures.
-
-     This is the main class that provides a unified interface for different types
-     of text classifiers. It acts as a high-level wrapper that delegates operations
-     to specific classifier implementations while providing a consistent API.
-
-     The class supports the full machine learning workflow including:
-     - Building tokenizers from training data
-     - Model training with validation
-     - Prediction and evaluation
-     - Model serialization and loading
-
-     Attributes:
-         config: Configuration object specific to the classifier type
-         classifier: The underlying classifier implementation
-
-     Example:
-         >>> from torchTextClassifiers import torchTextClassifiers
-         >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-         >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-         >>>
-         >>> # Create configuration
-         >>> config = FastTextConfig(
-         ...     embedding_dim=100,
-         ...     num_tokens=10000,
-         ...     min_count=1,
-         ...     min_n=3,
-         ...     max_n=6,
-         ...     len_word_ngrams=2,
-         ...     num_classes=2
-         ... )
-         >>>
-         >>> # Initialize classifier with wrapper
-         >>> wrapper = FastTextWrapper(config)
-         >>> classifier = torchTextClassifiers(wrapper)
-         >>>
-         >>> # Build and train
-         >>> classifier.build(X_train, y_train)
-         >>> classifier.train(X_train, y_train, X_val, y_val, num_epochs=10, batch_size=32)
-         >>>
-         >>> # Predict
-         >>> predictions = classifier.predict(X_test)
+
+     Given a tokenizer and model configuration, this class initializes:
+     - Text embedding layer (if needed)
+     - Categorical variable embedding network (if categorical variables are provided)
+     - Classification head
+     The resulting model can be trained using PyTorch Lightning and used for predictions.
+
      """
-
-     def __init__(self, classifier: BaseClassifierWrapper):
-         """Initialize the torchTextClassifiers instance.
-
-         Args:
-             classifier: An instance of a classifier wrapper that implements BaseClassifierWrapper
-
-         Example:
-             >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-             >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-             >>> config = FastTextConfig(embedding_dim=50, num_tokens=5000)
-             >>> wrapper = FastTextWrapper(config)
-             >>> classifier = torchTextClassifiers(wrapper)
-         """
-         self.classifier = classifier
-         self.config = classifier.config
-
-
-     def build_tokenizer(self, training_text: np.ndarray) -> None:
-         """Build tokenizer from training text data.
-
-         This method is kept for backward compatibility. It delegates to
-         prepare_text_features which handles the actual text preprocessing.
-
-         Args:
-             training_text: Array of text strings to build the tokenizer from
-
-         Example:
-             >>> import numpy as np
-             >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-             >>> classifier.build_tokenizer(texts)
-         """
-         self.classifier.prepare_text_features(training_text)
-
-     def prepare_text_features(self, training_text: np.ndarray) -> None:
-         """Prepare text features for the classifier.
-
-         This method handles text preprocessing which could involve tokenization,
-         vectorization, or other approaches depending on the classifier type.
-
-         Args:
-             training_text: Array of text strings to prepare features from
-
-         Example:
-             >>> import numpy as np
-             >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-             >>> classifier.prepare_text_features(texts)
-         """
-         self.classifier.prepare_text_features(training_text)
-
-     def build(
+
+     def __init__(
          self,
-         X_train: np.ndarray,
-         y_train: np.ndarray = None,
-         lightning=True,
-         **kwargs
-     ) -> None:
-         """Build the complete classifier from training data.
-
-         This method handles the full model building process including:
-         - Input validation and preprocessing
-         - Tokenizer creation from training text
-         - Model architecture initialization
-         - Lightning module setup (if enabled)
-
+         tokenizer: BaseTokenizer,
+         model_config: ModelConfig,
+         ragged_multilabel: bool = False,
+     ):
+         """Initialize the torchTextClassifiers instance.
+
          Args:
-             X_train: Training input data (text and optional categorical features)
-             y_train: Training labels (optional, can be inferred if num_classes is set)
-             lightning: Whether to initialize PyTorch Lightning components
-             **kwargs: Additional arguments passed to Lightning initialization
-
-         Raises:
-             ValueError: If y_train is None and num_classes is not set in config
-             ValueError: If label values are outside expected range
-
+             tokenizer: A tokenizer instance for text preprocessing
+             model_config: Configuration parameters for the text classification model
+
          Example:
-             >>> X_train = np.array(["text sample 1", "text sample 2"])
-             >>> y_train = np.array([0, 1])
-             >>> classifier.build(X_train, y_train)
+             >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
+             >>> # Assume tokenizer is a trained BaseTokenizer instance
+             >>> model_config = ModelConfig(
+             ...     embedding_dim=10,
+             ...     categorical_vocabulary_sizes=[30, 25],
+             ...     categorical_embedding_dims=[10, 5],
+             ...     num_classes=10,
+             ... )
+             >>> ttc = torchTextClassifiers(
+             ...     tokenizer=tokenizer,
+             ...     model_config=model_config,
+             ... )
          """
-         training_text, categorical_variables, no_cat_var = check_X(X_train)
-
-         if y_train is not None:
-             if self.config.num_classes is not None:
-                 if self.config.num_classes != len(np.unique(y_train)):
-                     logger.warning(
-                         f"Updating num_classes from {self.config.num_classes} to {len(np.unique(y_train))}"
-                     )
-
-             y_train = check_Y(y_train)
-             self.config.num_classes = len(np.unique(y_train))
-
-             if np.max(y_train) >= self.config.num_classes:
-                 raise ValueError(
-                     "y_train must contain values between 0 and num_classes-1"
+
+         self.model_config = model_config
+         self.tokenizer = tokenizer
+         self.ragged_multilabel = ragged_multilabel
+
+         if hasattr(self.tokenizer, "trained"):
+             if not self.tokenizer.trained:
+                 raise RuntimeError(
+                     f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier."
                  )
+
+         self.vocab_size = tokenizer.vocab_size
+         self.embedding_dim = model_config.embedding_dim
+         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
+         self.num_classes = model_config.num_classes
+
+         if self.tokenizer.output_vectorized:
+             self.text_embedder = None
+             logger.info(
+                 "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization."
+             )
+             self.embedding_dim = self.tokenizer.output_dim
          else:
-             if self.config.num_classes is None:
-                 raise ValueError(
-                     "Either num_classes must be provided at init or y_train must be provided here."
-                 )
-
-         # Handle categorical variables
-         if not no_cat_var:
-             if hasattr(self.config, 'num_categorical_features') and self.config.num_categorical_features is not None:
-                 if self.config.num_categorical_features != categorical_variables.shape[1]:
-                     logger.warning(
-                         f"Updating num_categorical_features from {self.config.num_categorical_features} to {categorical_variables.shape[1]}"
-                     )
-
-             if hasattr(self.config, 'num_categorical_features'):
-                 self.config.num_categorical_features = categorical_variables.shape[1]
-
-             categorical_vocabulary_sizes = np.max(categorical_variables, axis=0) + 1
-
-             if hasattr(self.config, 'categorical_vocabulary_sizes') and self.config.categorical_vocabulary_sizes is not None:
-                 if self.config.categorical_vocabulary_sizes != list(categorical_vocabulary_sizes):
-                     logger.warning(
-                         "Overwriting categorical_vocabulary_sizes with values from training data."
-                     )
-             if hasattr(self.config, 'categorical_vocabulary_sizes'):
-                 self.config.categorical_vocabulary_sizes = list(categorical_vocabulary_sizes)
-
-         self.classifier.prepare_text_features(training_text)
-         self.classifier._build_pytorch_model()
-
-         if lightning:
-             self.classifier._check_and_init_lightning(**kwargs)
-
+             text_embedder_config = TextEmbedderConfig(
+                 vocab_size=self.vocab_size,
+                 embedding_dim=self.embedding_dim,
+                 padding_idx=tokenizer.padding_idx,
+                 attention_config=model_config.attention_config,
+             )
+             self.text_embedder = TextEmbedder(
+                 text_embedder_config=text_embedder_config,
+             )
+
+         classif_head_input_dim = self.embedding_dim
+         if self.categorical_vocabulary_sizes:
+             self.categorical_var_net = CategoricalVariableNet(
+                 categorical_vocabulary_sizes=self.categorical_vocabulary_sizes,
+                 categorical_embedding_dims=model_config.categorical_embedding_dims,
+                 text_embedding_dim=self.embedding_dim,
+             )
+
+             if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT:
+                 classif_head_input_dim += self.categorical_var_net.output_dim
+
+         else:
+             self.categorical_var_net = None
+
+         self.classification_head = ClassificationHead(
+             input_dim=classif_head_input_dim,
+             num_classes=model_config.num_classes,
+         )
+
+         self.pytorch_model = TextClassificationModel(
+             text_embedder=self.text_embedder,
+             categorical_variable_net=self.categorical_var_net,
+             classification_head=self.classification_head,
+         )
+
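
Note: a small worked sketch of the sizing rules in __init__ above (the numbers are illustrative, not package defaults):

embedding_dim = 64            # ModelConfig.embedding_dim, or tokenizer.output_dim when the tokenizer is already vectorized
categorical_output_dim = 15   # CategoricalVariableNet.output_dim

# forward_type == SUM_TO_TEXT: categorical embeddings are summed into the text embedding,
# so the classification head keeps input_dim = 64.
# Any other forward type: the categorical output is concatenated,
# so the classification head gets input_dim = 64 + 15 = 79.
classif_head_input_dim = embedding_dim + categorical_output_dim  # concatenation case
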
      def train(
          self,
          X_train: np.ndarray,
          y_train: np.ndarray,
-         X_val: np.ndarray,
-         y_val: np.ndarray,
-         num_epochs: int,
-         batch_size: int,
-         cpu_run: bool = False,
-         num_workers: int = 12,
-         patience_train: int = 3,
+         training_config: TrainingConfig,
+         X_val: Optional[np.ndarray] = None,
+         y_val: Optional[np.ndarray] = None,
          verbose: bool = False,
-         trainer_params: Optional[dict] = None,
-         **kwargs
      ) -> None:
          """Train the classifier using PyTorch Lightning.
-
+
          This method handles the complete training process including:
          - Data validation and preprocessing
          - Dataset and DataLoader creation
          - PyTorch Lightning trainer setup with callbacks
          - Model training with early stopping
          - Best model loading after training
-
+
          Args:
              X_train: Training input data
              y_train: Training labels
              X_val: Validation input data
              y_val: Validation labels
-             num_epochs: Maximum number of training epochs
-             batch_size: Batch size for training and validation
-             cpu_run: If True, force training on CPU instead of GPU
-             num_workers: Number of worker processes for data loading
-             patience_train: Number of epochs to wait for improvement before early stopping
-             verbose: If True, print detailed training progress
-             trainer_params: Additional parameters to pass to PyTorch Lightning Trainer
-             **kwargs: Additional arguments passed to the build method
-
+             training_config: Configuration parameters for training
+             verbose: Whether to print training progress information
+
+
          Example:
-             >>> classifier.train(
-             ...     X_train, y_train, X_val, y_val,
-             ...     num_epochs=50,
-             ...     batch_size=32,
-             ...     patience_train=5,
-             ...     verbose=True
-             ... )
+
+             >>> training_config = TrainingConfig(
+             ...     lr=1e-3,
+             ...     batch_size=4,
+             ...     num_epochs=1,
+             ... )
+             >>> ttc.train(
+             ...     X_train=X,
+             ...     y_train=Y,
+             ...     X_val=X,
+             ...     y_val=Y,
+             ...     training_config=training_config,
+             ... )
          """
          # Input validation
-         training_text, train_categorical_variables, train_no_cat_var = check_X(X_train)
-         val_text, val_categorical_variables, val_no_cat_var = check_X(X_val)
-         y_train = check_Y(y_train)
-         y_val = check_Y(y_val)
-
-         # Consistency checks
-         assert train_no_cat_var == val_no_cat_var, (
-             "X_train and X_val must have the same number of categorical variables."
-         )
-         assert X_train.shape[0] == y_train.shape[0], (
-             "X_train and y_train must have the same number of observations."
-         )
-         assert X_train.ndim > 1 and X_train.shape[1] == X_val.shape[1] or X_val.ndim == 1, (
-             "X_train and X_val must have the same number of columns."
-         )
-
+         X_train, y_train = self._check_XY(X_train, y_train)
+
+         if X_val is not None:
+             assert y_val is not None, "y_val must be provided if X_val is provided."
+         if y_val is not None:
+             assert X_val is not None, "X_val must be provided if y_val is provided."
+
+         if X_val is not None and y_val is not None:
+             X_val, y_val = self._check_XY(X_val, y_val)
+
+             if (
+                 X_train["categorical_variables"] is not None
+                 and X_val["categorical_variables"] is not None
+             ):
+                 assert (
+                     X_train["categorical_variables"].ndim > 1
+                     and X_train["categorical_variables"].shape[1]
+                     == X_val["categorical_variables"].shape[1]
+                     or X_val["categorical_variables"].ndim == 1
+                 ), "X_train and X_val must have the same number of columns."
+
          if verbose:
              logger.info("Starting training process...")
-
-         # Device setup
-         if cpu_run:
-             device = torch.device("cpu")
-         else:
+
+         if training_config.accelerator == "auto":
              device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-         self.classifier.device = device
-
+         else:
+             device = torch.device(training_config.accelerator)
+
+         self.device = device
+
+         optimizer_params = {"lr": training_config.lr}
+         if training_config.optimizer_params is not None:
+             optimizer_params.update(training_config.optimizer_params)
+
+         if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel:
+             logger.warning(
+                 "⚠️ You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks."
+             )
+
+         self.lightning_module = TextClassificationModule(
+             model=self.pytorch_model,
+             loss=training_config.loss,
+             optimizer=training_config.optimizer,
+             optimizer_params=optimizer_params,
+             scheduler=training_config.scheduler,
+             scheduler_params=training_config.scheduler_params
+             if training_config.scheduler_params
+             else {},
+             scheduler_interval="epoch",
+         )
+
+         self.pytorch_model.to(self.device)
+
          if verbose:
              logger.info(f"Running on: {device}")
-
-         # Build model if not already built
-         if self.classifier.pytorch_model is None:
-             if verbose:
-                 start = time.time()
-                 logger.info("Building the model...")
-             self.build(X_train, y_train, **kwargs)
-             if verbose:
-                 end = time.time()
-                 logger.info(f"Model built in {end - start:.2f} seconds.")
-
-         self.classifier.pytorch_model = self.classifier.pytorch_model.to(device)
-
-         # Create datasets and dataloaders using wrapper methods
-         train_dataset = self.classifier.create_dataset(
-             texts=training_text,
-             labels=y_train,
-             categorical_variables=train_categorical_variables,
-         )
-         val_dataset = self.classifier.create_dataset(
-             texts=val_text,
-             labels=y_val,
-             categorical_variables=val_categorical_variables,
-         )
-
-         train_dataloader = self.classifier.create_dataloader(
-             dataset=train_dataset,
-             batch_size=batch_size,
-             num_workers=num_workers,
-             shuffle=True
+
+         train_dataset = TextClassificationDataset(
+             texts=X_train["text"],
+             categorical_variables=X_train["categorical_variables"], # None if no cat vars
+             tokenizer=self.tokenizer,
+             labels=y_train.tolist(),
+             ragged_multilabel=self.ragged_multilabel,
          )
-         val_dataloader = self.classifier.create_dataloader(
-             dataset=val_dataset,
-             batch_size=batch_size,
-             num_workers=num_workers,
-             shuffle=False
+         train_dataloader = train_dataset.create_dataloader(
+             batch_size=training_config.batch_size,
+             num_workers=training_config.num_workers,
+             shuffle=True,
+             **training_config.dataloader_params if training_config.dataloader_params else {},
          )
-
+
+         if X_val is not None and y_val is not None:
+             val_dataset = TextClassificationDataset(
+                 texts=X_val["text"],
+                 categorical_variables=X_val["categorical_variables"], # None if no cat vars
+                 tokenizer=self.tokenizer,
+                 labels=y_val,
+                 ragged_multilabel=self.ragged_multilabel,
+             )
+             val_dataloader = val_dataset.create_dataloader(
+                 batch_size=training_config.batch_size,
+                 num_workers=training_config.num_workers,
+                 shuffle=False,
+                 **training_config.dataloader_params if training_config.dataloader_params else {},
+             )
+         else:
+             val_dataloader = None
+
          # Setup trainer
          callbacks = [
              ModelCheckpoint(
-                 monitor="val_loss",
+                 monitor="val_loss" if val_dataloader is not None else "train_loss",
                  save_top_k=1,
                  save_last=False,
                  mode="min",
              ),
              EarlyStopping(
-                 monitor="val_loss",
-                 patience=patience_train,
+                 monitor="val_loss" if val_dataloader is not None else "train_loss",
+                 patience=training_config.patience_early_stopping,
                  mode="min",
              ),
              LearningRateMonitor(logging_interval="step"),
          ]
-
-         train_params = {
+
+         trainer_params = {
+             "accelerator": training_config.accelerator,
              "callbacks": callbacks,
-             "max_epochs": num_epochs,
+             "max_epochs": training_config.num_epochs,
              "num_sanity_val_steps": 2,
              "strategy": "auto",
              "log_every_n_steps": 1,
              "enable_progress_bar": True,
          }
-
-         if trainer_params is not None:
-             train_params.update(trainer_params)
-
-         trainer = pl.Trainer(**train_params)
-
+
+         if training_config.trainer_params is not None:
+             trainer_params.update(training_config.trainer_params)
+
+         trainer = pl.Trainer(**trainer_params)
+
          torch.cuda.empty_cache()
          torch.set_float32_matmul_precision("medium")
-
+
          if verbose:
              logger.info("Launching training...")
              start = time.time()
-
-         trainer.fit(self.classifier.lightning_module, train_dataloader, val_dataloader)
-
+
+         trainer.fit(self.lightning_module, train_dataloader, val_dataloader)
+
          if verbose:
              end = time.time()
              logger.info(f"Training completed in {end - start:.2f} seconds.")
-
-         # Load best model using wrapper method
+
          best_model_path = trainer.checkpoint_callback.best_model_path
-         self.classifier.load_best_model(best_model_path)
-
-     def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
-         """Make predictions on input data.
-
-         Args:
-             X: Input data for prediction (text and optional categorical features)
-             **kwargs: Additional arguments passed to the underlying predictor
-
-         Returns:
-             np.ndarray: Predicted class labels
-
-         Example:
-             >>> X_test = np.array(["new text sample", "another sample"])
-             >>> predictions = classifier.predict(X_test)
-             >>> print(predictions) # [0, 1]
-         """
-         return self.classifier.predict(X, **kwargs)
-
-     def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-         """Validate the model on test data.
-
-         Args:
-             X: Input data for validation
-             Y: True labels for validation
-             **kwargs: Additional arguments passed to the validator
-
-         Returns:
-             float: Validation accuracy score
-
-         Example:
-             >>> accuracy = classifier.validate(X_test, y_test)
-             >>> print(f"Accuracy: {accuracy:.3f}")
-         """
-         return self.classifier.validate(X, Y, **kwargs)
-
-     def predict_and_explain(self, X: np.ndarray, **kwargs):
-         """Make predictions with explanations (if supported).
-
-         This method provides both predictions and explanations for the model's
-         decisions. Availability depends on the specific classifier implementation.
-
+
+         self.lightning_module = TextClassificationModule.load_from_checkpoint(
+             best_model_path,
+             model=self.pytorch_model,
+             loss=training_config.loss,
+         )
+
+         self.pytorch_model = self.lightning_module.model.to(self.device)
+
+         self.lightning_module.eval()
+
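
Note: a hedged end-to-end sketch of calling train() with the (N, d) input convention enforced by the _check_* helpers below: first column raw text, remaining columns integer-coded categorical variables. The example values and the top-level TrainingConfig import are illustrative assumptions; ttc stands for an already-constructed torchTextClassifiers instance.

import numpy as np

from torchTextClassifiers import TrainingConfig

# First column: raw text. Remaining columns: integer-coded categorical variables,
# each value strictly below the corresponding categorical_vocabulary_sizes entry.
X = np.array(
    [
        ["cheap cotton t-shirt", 0, 3],
        ["stainless steel water bottle", 1, 7],
        ["wireless noise-cancelling headphones", 2, 1],
    ],
    dtype=object,
)
Y = np.array([0, 1, 1])  # integer labels in [0, num_classes - 1]

training_config = TrainingConfig(num_epochs=3, batch_size=2, lr=1e-3)

ttc.train(
    X_train=X,
    y_train=Y,
    X_val=X,  # optional in 1.0.0; without it, checkpointing and early stopping monitor train_loss
    y_val=Y,
    training_config=training_config,
)
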
+     def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         X = self._check_X(X)
+         Y = self._check_Y(Y)
+
+         if X["text"].shape[0] != len(Y):
+             raise ValueError("X_train and y_train must have the same number of observations.")
+
+         return X, Y
+
+     @staticmethod
+     def _check_text_col(X):
+         assert isinstance(
+             X, np.ndarray
+         ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables."
+
+         try:
+             if X.ndim > 1:
+                 text = X[:, 0].astype(str)
+             else:
+                 text = X[:].astype(str)
+         except ValueError:
+             logger.error("The first column of X must be castable in string format.")
+
+         return text
+
+     def _check_categorical_variables(self, X: np.ndarray) -> None:
+         """Check if categorical variables in X match training configuration.
+
          Args:
-             X: Input data for prediction
-             **kwargs: Additional arguments passed to the explainer
-
-         Returns:
-             tuple: (predictions, explanations) where explanations format depends
-                 on the classifier type
-
+             X: Input data to check
+
          Raises:
-             NotImplementedError: If the classifier doesn't support explanations
-
-         Example:
-             >>> predictions, explanations = classifier.predict_and_explain(X_test)
-             >>> print(f"Predictions: {predictions}")
-             >>> print(f"Explanations: {explanations}")
+             ValueError: If the number of categorical variables does not match
+                 the training configuration
          """
-         if hasattr(self.classifier, 'predict_and_explain'):
-             return self.classifier.predict_and_explain(X, **kwargs)
+
+         assert self.categorical_var_net is not None
+
+         if X.ndim > 1:
+             num_cat_vars = X.shape[1] - 1
          else:
-             raise NotImplementedError(f"Explanation not supported for {type(self.classifier).__name__}")
-
-     def to_json(self, filepath: str) -> None:
-         """Save classifier configuration to JSON file.
-
-         This method serializes the classifier configuration to a JSON
-         file. Note: This only saves configuration, not trained model weights.
-         Custom classifier wrappers should implement a class method `get_wrapper_class_info()`
-         that returns a dict with 'module' and 'class_name' keys for proper reconstruction.
-
-         Args:
-             filepath: Path where to save the JSON configuration file
-
-         Example:
-             >>> classifier.to_json('my_classifier_config.json')
+             num_cat_vars = 0
+
+         if num_cat_vars != self.categorical_var_net.num_categorical_features:
+             raise ValueError(
+                 f"X must have the same number of categorical variables as the number of embedding layers in the categorical net: ({self.categorical_var_net.num_categorical_features})."
+             )
+
+         try:
+             categorical_variables = X[:, 1:].astype(int)
+         except ValueError:
+             logger.error(
+                 f"Columns {1} to {X.shape[1] - 1} of X_train must be castable in integer format."
+             )
+
+         for j in range(X.shape[1] - 1):
+             max_cat_value = categorical_variables[:, j].max()
+             if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]:
+                 raise ValueError(
+                     f"Categorical variable at index {j} has value {max_cat_value} which exceeds the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}."
+                 )
+
+         return categorical_variables
+
+     def _check_X(self, X: np.ndarray) -> np.ndarray:
+         text = self._check_text_col(X)
+
+         categorical_variables = None
+         if self.categorical_var_net is not None:
+             categorical_variables = self._check_categorical_variables(X)
+
+         return {"text": text, "categorical_variables": categorical_variables}
+
+     def _check_Y(self, Y):
+         if self.ragged_multilabel:
+             assert isinstance(
+                 Y, list
+             ), "Y must be a list of lists for ragged multilabel classification."
+             for row in Y:
+                 assert isinstance(row, list), "Each element of Y must be a list of labels."
+
+             return Y
+
+         else:
+             assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
+             assert (
+                 len(Y.shape) == 1 or len(Y.shape) == 2
+             ), "Y must be a numpy array of shape (N,) or (N, num_labels)."
+
+             try:
+                 Y = Y.astype(int)
+             except ValueError:
+                 logger.error("Y must be castable in integer format.")
+
+             if Y.max() >= self.num_classes or Y.min() < 0:
+                 raise ValueError(
+                     f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
+                 )
+
+             return Y
+
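
Note: _check_Y above accepts two label formats depending on ragged_multilabel; a short sketch of both (shapes are illustrative):

import numpy as np

# ragged_multilabel=False (default): a numpy array of integer class labels in [0, num_classes - 1]
y_single = np.array([0, 4, 2])

# ragged_multilabel=True: a plain Python list of lists, one variable-length label list per sample
y_ragged = [[0, 3], [1], [2, 5, 7]]
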
+     def predict(
+         self,
+         X_test: np.ndarray,
+         top_k=1,
+         explain=False,
+     ):
          """
-         with open(filepath, "w") as f:
-             data = {
-                 "config": self.config.to_dict(),
-             }
-
-             # Try to get wrapper class info for reconstruction
-             if hasattr(self.classifier.__class__, 'get_wrapper_class_info'):
-                 data["wrapper_class_info"] = self.classifier.__class__.get_wrapper_class_info()
-             else:
-                 # Fallback: store module and class name
-                 data["wrapper_class_info"] = {
-                     "module": self.classifier.__class__.__module__,
-                     "class_name": self.classifier.__class__.__name__
-                 }
-
-             json.dump(data, f, cls=NumpyJSONEncoder, indent=4)
-
-     @classmethod
-     def from_json(cls, filepath: str, wrapper_class: Optional[Type[BaseClassifierWrapper]] = None) -> "torchTextClassifiers":
-         """Load classifier configuration from JSON file.
-
-         This method creates a new classifier instance from a previously saved
-         configuration file. The classifier will need to be built and trained again.
-
          Args:
-             filepath: Path to the JSON configuration file
-             wrapper_class: Optional wrapper class to use. If not provided, will try to
-                 reconstruct from saved wrapper_class_info
-
-         Returns:
-             torchTextClassifiers: New classifier instance with loaded configuration
-
-         Raises:
-             ImportError: If the wrapper class cannot be imported
-             FileNotFoundError: If the configuration file doesn't exist
-
-         Example:
-             >>> # Using saved wrapper class info
-             >>> classifier = torchTextClassifiers.from_json('my_classifier_config.json')
-             >>>
-             >>> # Or providing wrapper class explicitly
-             >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-             >>> classifier = torchTextClassifiers.from_json('config.json', FastTextWrapper)
+             X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
+             top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
+             explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
+
+         Returns: A dictionary containing the following fields:
+             - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
+             - confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores.
+             - if explain is True:
+                 - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
          """
-         with open(filepath, "r") as f:
-             data = json.load(f)
-
-         if wrapper_class is None:
-             # Try to reconstruct wrapper class from saved info
-             if "wrapper_class_info" not in data:
-                 raise ValueError("No wrapper_class_info found in config file and no wrapper_class provided")
-
-             wrapper_info = data["wrapper_class_info"]
-             module_name = wrapper_info["module"]
-             class_name = wrapper_info["class_name"]
-
-             # Dynamically import the wrapper class
-             import importlib
-             module = importlib.import_module(module_name)
-             wrapper_class = getattr(module, class_name)
-
-         # Reconstruct config using wrapper class's config class
-         config_class = wrapper_class.get_config_class()
-         config = config_class.from_dict(data["config"])
-
-         # Create wrapper instance
-         wrapper = wrapper_class(config)
-
-         return cls(wrapper)
+
+         if explain:
+             return_offsets_mapping = True # to be passed to the tokenizer
+             return_word_ids = True
+             if self.pytorch_model.text_embedder is None:
+                 raise RuntimeError(
+                     "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
+                 )
+             else:
+                 if not HAS_CAPTUM:
+                     raise ImportError(
+                         "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
+                     )
+                 lig = LayerIntegratedGradients(
+                     self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
+                 ) # initialize a Captum layer gradient integrator
+         else:
+             return_offsets_mapping = False
+             return_word_ids = False
+
+         X_test = self._check_X(X_test)
+         text = X_test["text"]
+         categorical_variables = X_test["categorical_variables"]
+
+         self.pytorch_model.eval().cpu()
+
+         tokenize_output = self.tokenizer.tokenize(
+             text.tolist(),
+             return_offsets_mapping=return_offsets_mapping,
+             return_word_ids=return_word_ids,
+         )
+
+         if not isinstance(tokenize_output, TokenizerOutput):
+             raise TypeError(
+                 f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method."
+             )
+
+         encoded_text = tokenize_output.input_ids # (batch_size, seq_len)
+         attention_mask = tokenize_output.attention_mask # (batch_size, seq_len)
+
+         if categorical_variables is not None:
+             categorical_vars = torch.tensor(
+                 categorical_variables, dtype=torch.float32
+             ) # (batch_size, num_categorical_features)
+         else:
+             categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32)
+
+         pred = self.pytorch_model(
+             encoded_text, attention_mask, categorical_vars
+         ) # forward pass, contains the prediction scores (len(text), num_classes)
+
+         label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities
+
+         label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
+
+         predictions = label_scores_topk.indices # get the top_k most likely predictions
+         confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores
+
+         if explain:
+             all_attributions = []
+             for k in range(top_k):
+                 attributions = lig.attribute(
+                     (encoded_text, attention_mask, categorical_vars),
+                     target=torch.Tensor(predictions[:, k]).long(),
+                 ) # (batch_size, seq_len)
+                 attributions = attributions.sum(dim=-1)
+                 all_attributions.append(attributions.detach().cpu())
+
+             all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len)
+
+             return {
+                 "prediction": predictions,
+                 "confidence": confidence,
+                 "attributions": all_attributions,
+                 "offset_mapping": tokenize_output.offset_mapping,
+                 "word_ids": tokenize_output.word_ids,
+             }
+         else:
+             return {
+                 "prediction": predictions,
+                 "confidence": confidence,
+             }
+
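
Note: a usage sketch for predict() above, reusing the ttc and X placeholders from the training example. Two details worth flagging: the returned dictionary uses the key "prediction" (singular), although the docstring lists "predictions", and explain=True requires captum plus a tokenizer that returns token IDs.

out = ttc.predict(X, top_k=2)
print(out["prediction"].shape)  # (n_samples, 2): top-2 class indices per row
print(out["confidence"])        # matching softmax scores, rounded to 2 decimals

out_explained = ttc.predict(X, top_k=1, explain=True)
print(out_explained["attributions"].shape)  # (n_samples, 1, seq_len): per-token attributions
print(sorted(out_explained.keys()))         # also includes "offset_mapping" and "word_ids"
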
+     def __repr__(self):
+         model_type = (
+             self.lightning_module.__repr__()
+             if hasattr(self, "lightning_module")
+             else self.pytorch_model.__repr__()
+         )
+
+         tokenizer_info = self.tokenizer.__repr__()
+
+         cat_forward_type = (
+             self.categorical_var_net.forward_type.name
+             if self.categorical_var_net is not None
+             else "None"
+         )
+
+         lines = [
+             "torchTextClassifiers(",
+             f" tokenizer = {tokenizer_info},",
+             f" model = {model_type},",
+             f" categorical_forward_type = {cat_forward_type},",
+             f" num_classes = {self.model_config.num_classes},",
+             f" embedding_dim = {self.embedding_dim},",
+             ")",
+         ]
+         return "\n".join(lines)