torchtextclassifiers 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/torchTextClassifiers/dataset/dataset.py
+++ b/torchTextClassifiers/dataset/dataset.py
@@ -1,3 +1,4 @@
+import logging
 import os
 from typing import List, Union
 
@@ -8,6 +9,7 @@ from torch.utils.data import DataLoader, Dataset
 from torchTextClassifiers.tokenizers import BaseTokenizer
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+logger = logging.getLogger(__name__)
 
 
 class TextClassificationDataset(Dataset):
@@ -16,7 +18,8 @@ class TextClassificationDataset(Dataset):
         texts: List[str],
         categorical_variables: Union[List[List[int]], np.array, None],
         tokenizer: BaseTokenizer,
-        labels: Union[List[int], None] = None,
+        labels: Union[List[int], List[List[int]], np.array, None] = None,
+        ragged_multilabel: bool = False,
     ):
         self.categorical_variables = categorical_variables
 
@@ -32,6 +35,23 @@ class TextClassificationDataset(Dataset):
         self.texts = texts
         self.tokenizer = tokenizer
         self.labels = labels
+        self.ragged_multilabel = ragged_multilabel
+
+        if self.ragged_multilabel and self.labels is not None:
+            max_value = int(max(max(row) for row in labels if row))
+            self.num_classes = max_value + 1
+
+            if max_value == 1:
+                try:
+                    labels = np.array(labels)
+                    logger.critical(
+                        """ragged_multilabel set to True but max label value is 1 and all samples have the same number of labels.
+                        If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong."""
+                    )
+                except ValueError:
+                    logger.warning(
+                        "ragged_multilabel set to True but max label value is 1. If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong."
+                    )
 
     def __len__(self):
         return len(self.texts)
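
For context, a minimal sketch (with made-up label values) of the label format that `ragged_multilabel=True` expects, mirroring the `num_classes` inference above:

```python
# Hypothetical ragged labels: sample 0 belongs to classes 2 and 5,
# sample 1 to class 0 only, sample 2 to classes 1, 3 and 4.
labels = [[2, 5], [0], [1, 3, 4]]

# num_classes is inferred exactly as in __init__ above: max label value + 1.
num_classes = int(max(max(row) for row in labels if row)) + 1
assert num_classes == 6

# A list like [[0, 1], [1, 0]] is ambiguous: it could be ragged class
# indices or rows that are already one-hot encoded, which is what the
# logger.critical / logger.warning branch above tries to flag.
```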
@@ -59,10 +79,28 @@ class TextClassificationDataset(Dataset):
         )
 
     def collate_fn(self, batch):
-        text, *categorical_vars, y = zip(*batch)
+        text, *categorical_vars, labels = zip(*batch)
 
         if self.labels is not None:
-            labels_tensor = torch.tensor(y, dtype=torch.long)
+            if self.ragged_multilabel:
+                # Pad labels to the max length in the batch
+                labels_padded = torch.nn.utils.rnn.pad_sequence(
+                    [torch.tensor(label) for label in labels],
+                    batch_first=True,
+                    padding_value=-1,  # use an impossible class as padding
+                ).int()
+
+                labels_tensor = torch.zeros(labels_padded.size(0), self.num_classes).float()
+                mask = labels_padded != -1
+
+                batch_size = labels_padded.size(0)
+                rows = torch.arange(batch_size).unsqueeze(1).expand_as(labels_padded)[mask]
+                cols = labels_padded[mask]
+
+                labels_tensor[rows, cols] = 1
+
+            else:
+                labels_tensor = torch.tensor(labels)
         else:
             labels_tensor = None
 
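Standalone, the padding-plus-advanced-indexing trick in `collate_fn` converts ragged index lists into a multi-hot matrix. A runnable sketch, where `labels` and `num_classes` stand in for the batch and `self.num_classes`:

```python
import torch

labels = [[2, 5], [0], [1, 3, 4]]  # ragged batch of class indices
num_classes = 6                    # stands in for self.num_classes

# Pad to the longest sample with -1, an index no real class uses.
padded = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(l) for l in labels], batch_first=True, padding_value=-1
).int()
# padded == [[2, 5, -1], [0, -1, -1], [1, 3, 4]]

multi_hot = torch.zeros(padded.size(0), num_classes)
mask = padded != -1  # True wherever a real label sits
rows = torch.arange(padded.size(0)).unsqueeze(1).expand_as(padded)[mask]
multi_hot[rows, padded[mask]] = 1
# multi_hot == [[0., 0., 1., 0., 0., 1.],
#               [1., 0., 0., 0., 0., 0.],
#               [0., 1., 0., 1., 1., 0.]]
```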
--- a/torchTextClassifiers/model/components/classification_head.py
+++ b/torchTextClassifiers/model/components/classification_head.py
@@ -11,33 +11,51 @@ class ClassificationHead(nn.Module):
         num_classes: Optional[int] = None,
         net: Optional[nn.Module] = None,
     ):
+        """
+        Classification head for text classification tasks.
+        It is an nn.Module that can be either a simple Linear layer or a custom neural network module.
+
+        Args:
+            input_dim (int, optional): Dimension of the input features. Required if net is not provided.
+            num_classes (int, optional): Number of output classes. Required if net is not provided.
+            net (nn.Module, optional): Custom neural network module to be used as the classification head.
+                If provided, input_dim and num_classes are inferred from this module.
+                Should be either an nn.Sequential whose first and last layers are nn.Linear, or an nn.Linear.
+        """
         super().__init__()
         if net is not None:
             self.net = net
-            self.input_dim = net.in_features
-            self.num_classes = net.out_features
+
+            # --- Custom net must be either a Sequential or a Linear ---
+            if not isinstance(net, (nn.Sequential, nn.Linear)):
+                raise ValueError("net must be an nn.Sequential or nn.Linear when provided.")
+
+            # --- If Sequential, check that the first and last layers are Linear ---
+            if isinstance(net, nn.Sequential):
+                first = net[0]
+                last = net[-1]
+
+                if not isinstance(first, nn.Linear):
+                    raise TypeError(f"First layer must be nn.Linear, got {type(first).__name__}.")
+
+                if not isinstance(last, nn.Linear):
+                    raise TypeError(f"Last layer must be nn.Linear, got {type(last).__name__}.")
+
+                # --- Extract feature dimensions ---
+                self.input_dim = first.in_features
+                self.num_classes = last.out_features
+            else:  # if not a Sequential, it is a Linear
+                self.input_dim = net.in_features
+                self.num_classes = net.out_features
+
         else:
             assert (
                 input_dim is not None and num_classes is not None
             ), "Either net or both input_dim and num_classes must be provided."
             self.net = nn.Linear(input_dim, num_classes)
-            self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
+            self.input_dim = input_dim
+            self.num_classes = num_classes
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.net(x)
-
-    @staticmethod
-    def _get_linear_input_output_dims(module: nn.Module):
-        """
-        Returns (input_dim, output_dim) for any module containing Linear layers.
-        Works for Linear, Sequential, or nested models.
-        """
-        # Collect all Linear layers recursively
-        linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
-
-        if not linears:
-            raise ValueError("No Linear layers found in the given module.")
-
-        input_dim = linears[0].in_features
-        output_dim = linears[-1].out_features
-        return input_dim, output_dim
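
The new validation makes the head's contract explicit: either pass dimensions, or pass a module whose first and last layers announce them. A usage sketch; the import path is an assumption based on the package layout listed in RECORD below:

```python
import torch.nn as nn

# Assumed import path (see torchTextClassifiers/model/components/ in RECORD):
from torchTextClassifiers.model.components import ClassificationHead

# Dimensions given directly: the head is a single nn.Linear.
head = ClassificationHead(input_dim=128, num_classes=6)

# Custom net: first and last layers must be nn.Linear, and
# input_dim / num_classes are read off those two layers.
head = ClassificationHead(
    net=nn.Sequential(
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 6),
    )
)
assert head.input_dim == 128 and head.num_classes == 6
```

Compared with the removed `_get_linear_input_output_dims`, which scanned every nested Linear, the new constructor rejects shapes it cannot reason about instead of guessing.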
--- a/torchTextClassifiers/model/lightning.py
+++ b/torchTextClassifiers/model/lightning.py
@@ -76,6 +76,10 @@ class TextClassificationModule(pl.LightningModule):
         targets = batch["labels"]
 
         outputs = self.forward(batch)
+
+        if isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
+            targets = targets.float()
+
         loss = self.loss(outputs, targets)
         self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
         accuracy = self.accuracy_fn(outputs, targets)
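
The cast matters because the two losses expect different target dtypes. A quick illustration, independent of the package:

```python
import torch

logits = torch.randn(4, 6)  # raw scores for 4 samples, 6 classes

# CrossEntropyLoss takes integer class indices of shape (N,).
ce = torch.nn.CrossEntropyLoss()
ce(logits, torch.tensor([2, 0, 5, 1]))

# BCEWithLogitsLoss takes float multi-hot targets of shape (N, C),
# hence the targets.float() cast added in training_step above
# (integer targets raise a dtype error here).
bce = torch.nn.BCEWithLogitsLoss()
targets = torch.zeros(4, 6)  # float by default
targets[torch.arange(4), torch.tensor([2, 0, 5, 1])] = 1
targets[0, 5] = 1  # multilabel: sample 0 also belongs to class 5
bce(logits, targets)
```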
--- a/torchTextClassifiers/torchTextClassifiers.py
+++ b/torchTextClassifiers/torchTextClassifiers.py
@@ -100,6 +100,7 @@ class torchTextClassifiers:
         self,
         tokenizer: BaseTokenizer,
         model_config: ModelConfig,
+        ragged_multilabel: bool = False,
     ):
         """Initialize the torchTextClassifiers instance.
 
@@ -124,6 +125,7 @@ class torchTextClassifiers:
 
         self.model_config = model_config
         self.tokenizer = tokenizer
+        self.ragged_multilabel = ragged_multilabel
 
         if hasattr(self.tokenizer, "trained"):
             if not self.tokenizer.trained:
@@ -182,9 +184,9 @@ class torchTextClassifiers:
         self,
         X_train: np.ndarray,
         y_train: np.ndarray,
-        X_val: np.ndarray,
-        y_val: np.ndarray,
         training_config: TrainingConfig,
+        X_val: Optional[np.ndarray] = None,
+        y_val: Optional[np.ndarray] = None,
         verbose: bool = False,
     ) -> None:
         """Train the classifier using PyTorch Lightning.
@@ -222,7 +224,14 @@ class torchTextClassifiers:
         """
         # Input validation
         X_train, y_train = self._check_XY(X_train, y_train)
-        X_val, y_val = self._check_XY(X_val, y_val)
+
+        if X_val is not None:
+            assert y_val is not None, "y_val must be provided if X_val is provided."
+        if y_val is not None:
+            assert X_val is not None, "X_val must be provided if y_val is provided."
+
+        if X_val is not None and y_val is not None:
+            X_val, y_val = self._check_XY(X_val, y_val)
 
         if (
             X_train["categorical_variables"] is not None
@@ -249,6 +258,11 @@ class torchTextClassifiers:
         if training_config.optimizer_params is not None:
             optimizer_params.update(training_config.optimizer_params)
 
+        if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel:
+            logger.warning(
+                "⚠️ ragged_multilabel is set to True but CrossEntropyLoss is in use. We recommend torch.nn.BCEWithLogitsLoss for multilabel classification tasks."
+            )
+
         self.lightning_module = TextClassificationModule(
             model=self.pytorch_model,
             loss=training_config.loss,
@@ -270,38 +284,43 @@ class torchTextClassifiers:
             texts=X_train["text"],
             categorical_variables=X_train["categorical_variables"],  # None if no cat vars
             tokenizer=self.tokenizer,
-            labels=y_train,
+            labels=y_train.tolist(),
+            ragged_multilabel=self.ragged_multilabel,
         )
-        val_dataset = TextClassificationDataset(
-            texts=X_val["text"],
-            categorical_variables=X_val["categorical_variables"],  # None if no cat vars
-            tokenizer=self.tokenizer,
-            labels=y_val,
-        )
-
         train_dataloader = train_dataset.create_dataloader(
             batch_size=training_config.batch_size,
             num_workers=training_config.num_workers,
             shuffle=True,
             **training_config.dataloader_params if training_config.dataloader_params else {},
         )
-        val_dataloader = val_dataset.create_dataloader(
-            batch_size=training_config.batch_size,
-            num_workers=training_config.num_workers,
-            shuffle=False,
-            **training_config.dataloader_params if training_config.dataloader_params else {},
-        )
+
+        if X_val is not None and y_val is not None:
+            val_dataset = TextClassificationDataset(
+                texts=X_val["text"],
+                categorical_variables=X_val["categorical_variables"],  # None if no cat vars
+                tokenizer=self.tokenizer,
+                labels=y_val,
+                ragged_multilabel=self.ragged_multilabel,
+            )
+            val_dataloader = val_dataset.create_dataloader(
+                batch_size=training_config.batch_size,
+                num_workers=training_config.num_workers,
+                shuffle=False,
+                **training_config.dataloader_params if training_config.dataloader_params else {},
+            )
+        else:
+            val_dataloader = None
 
         # Setup trainer
         callbacks = [
             ModelCheckpoint(
-                monitor="val_loss",
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
                 save_top_k=1,
                 save_last=False,
                 mode="min",
             ),
             EarlyStopping(
-                monitor="val_loss",
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
                 patience=training_config.patience_early_stopping,
                 mode="min",
             ),
@@ -352,7 +371,7 @@ class torchTextClassifiers:
         X = self._check_X(X)
         Y = self._check_Y(Y)
 
-        if X["text"].shape[0] != Y.shape[0]:
+        if X["text"].shape[0] != len(Y):
             raise ValueError("X_train and y_train must have the same number of observations.")
 
         return X, Y
@@ -422,22 +441,32 @@ class torchTextClassifiers:
         return {"text": text, "categorical_variables": categorical_variables}
 
     def _check_Y(self, Y):
-        assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
-        assert len(Y.shape) == 1 or (
-            len(Y.shape) == 2 and Y.shape[1] == 1
-        ), "Y must be a numpy array of shape (N,) or (N,1)."
+        if self.ragged_multilabel:
+            assert isinstance(
+                Y, list
+            ), "Y must be a list of lists for ragged multilabel classification."
+            for row in Y:
+                assert isinstance(row, list), "Each element of Y must be a list of labels."
 
-        try:
-            Y = Y.astype(int)
-        except ValueError:
-            logger.error("Y must be castable in integer format.")
+            return Y
 
-        if Y.max() >= self.num_classes or Y.min() < 0:
-            raise ValueError(
-                f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
-            )
+        else:
+            assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N, num_labels)."
+            assert (
+                len(Y.shape) == 1 or len(Y.shape) == 2
+            ), "Y must be a numpy array of shape (N,) or (N, num_labels)."
+
+            try:
+                Y = Y.astype(int)
+            except ValueError:
+                logger.error("Y must be castable to integer format.")
+
+            if Y.max() >= self.num_classes or Y.min() < 0:
+                raise ValueError(
+                    f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
+                )
 
-        return Y
+            return Y
 
     def predict(
         self,
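
Put together, `_check_Y` now accepts three target layouts. A small sketch of each; the values are illustrative:

```python
import numpy as np

# Default (ragged_multilabel=False): integer class indices...
y_indices = np.array([2, 0, 5, 1])           # shape (N,)
# ...or an already-encoded 2-D array, e.g. one-hot / multi-hot rows:
y_encoded = np.eye(6, dtype=int)[y_indices]  # shape (N, num_labels)

# With ragged_multilabel=True: a plain list of lists of class indices,
# validated here and multi-hot encoded later by the dataset's collate_fn.
y_ragged = [[2, 5], [0], [1, 3, 4], [1]]
```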
--- a/torchtextclassifiers-0.1.0.dist-info/METADATA
+++ b/torchtextclassifiers-1.0.0.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: torchtextclassifiers
-Version: 0.1.0
+Version: 1.0.0
 Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
 Keywords: fastText,text classification,NLP,automatic coding,deep learning
 Author: Cédric Couralet, Meilame Tayebjee
@@ -26,15 +26,18 @@ Description-Content-Type: text/markdown
 
 # torchTextClassifiers
 
+[![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://inseefrlab.github.io/torchTextClassifiers/)
+
 A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
 
 ## 🚀 Features
 
-- **Mixed input support**: Handle text data alongside categorical variables seamlessly.
+- **Complex input support**: Handle text data alongside categorical variables seamlessly.
 - **Unified yet highly customizable**:
   - Use any tokenizer from HuggingFace or the original fastText ngram tokenizer.
   - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module`s!
   - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+- **Multiclass / multilabel support**: Handle both multiclass (exactly one label is true) and multilabel (several labels can be true) classification tasks.
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
 - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
   - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you.
@@ -55,6 +58,15 @@ uv sync
 pip install -e .
 ```
 
+## 📖 Documentation
+
+Full documentation is available at: **https://inseefrlab.github.io/torchTextClassifiers/**
+The documentation includes:
+- **Getting Started**: Installation and quick start guide
+- **Architecture**: Understanding the 3-layer design
+- **Tutorials**: Step-by-step guides for different use cases
+- **API Reference**: Complete API documentation
+
 ## 📝 Usage
 
 Check out the [notebook](notebooks/example.ipynb) for a quick start.
@@ -71,3 +83,5 @@ See the [examples/](examples/) directory for:
 ## 📄 License
 
 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+
--- a/torchtextclassifiers-0.1.0.dist-info/RECORD
+++ b/torchtextclassifiers-1.0.0.dist-info/RECORD
@@ -1,21 +1,21 @@
 torchTextClassifiers/__init__.py,sha256=TM2AjZ4KDqpgwMKiT0X5daNZvLDj6WECz_OFf8M4lgA,906
 torchTextClassifiers/dataset/__init__.py,sha256=dyCz48pO6zRC-2qh4753Hj70W2MZGXdX3RbgutvyOng,76
-torchTextClassifiers/dataset/dataset.py,sha256=n7V4JNtcuqb2ugx7hxkAohEPHqEuxv46jYU47KiUbno,3295
+torchTextClassifiers/dataset/dataset.py,sha256=sQm6msPinr8pbyO1yKXFVVmfSIStrWe0bBOflwWO4iE,5101
 torchTextClassifiers/model/__init__.py,sha256=lFY1Mb1J0tFhe4_PsDOEHhnVl3dXj59K4Zxnwy2KkS4,146
 torchTextClassifiers/model/components/__init__.py,sha256=-IT_6fCHZkRw6Hu7GdVeCt685P4PuGaY6VdYQV5M8mE,447
 torchTextClassifiers/model/components/attention.py,sha256=hhSMh_CvpR-hiP8hoCg4Fr_TovGlJpC_RHs3iW-Pnpc,4199
 torchTextClassifiers/model/components/categorical_var_net.py,sha256=no0QDidKCw1rlbJzD7S-Srhzn5P6vETGRT5Er-gzMnM,5699
-torchTextClassifiers/model/components/classification_head.py,sha256=lPndu5FPC-bOZ2H4Yq0EnzWrOzPFJdBb_KUx5wyZBb4,1445
+torchTextClassifiers/model/components/classification_head.py,sha256=myuEc5wFQ5gw_f519cUZ1Z7AMuQF7Vshq_B3aRt5xRE,2501
 torchTextClassifiers/model/components/text_embedder.py,sha256=tY2pXAt4IvayyvRpjiKGg5vGz_Q2-p_TOL6Jg2p8hYE,9058
-torchTextClassifiers/model/lightning.py,sha256=z5mq10_hNp-UK66Aqpcablg3BDYnjF9Gch0HaGoJ6cM,5265
+torchTextClassifiers/model/lightning.py,sha256=dOJzyGbqwFxriAtrIjC14E1f107YMtpiR65-OJy_Pc4,5367
 torchTextClassifiers/model/model.py,sha256=jjGjvK7C2Wly0e4S6gTC8Ty8y-o8reU-aniBqYS73Cc,6100
 torchTextClassifiers/tokenizers/WordPiece.py,sha256=HMHYV2SiwShlhWMQ6LXH4MtZE5GSsaNA2DlD340ABGE,3289
 torchTextClassifiers/tokenizers/__init__.py,sha256=I8IQ2-t85RVlZFwLjDFF_Te2S9uiwlymQDWx-3GeF-Y,334
 torchTextClassifiers/tokenizers/base.py,sha256=OY6GIhI4KTdvvKq3VZowf64H7lAmdQyq4scZ10HxP3A,7570
 torchTextClassifiers/tokenizers/ngram.py,sha256=lHI8dtuCGWh0o7V58TJx_mTVIHm8udl6XuWccxgJPew,16375
-torchTextClassifiers/torchTextClassifiers.py,sha256=E2XVGAky_SMAw6BAMswA3c08rKyOpGEW_dv1BqQlJrU,21141
+torchTextClassifiers/torchTextClassifiers.py,sha256=ZbFWiobiqNFYCXj3Ht8jLQJDIozApZ5n__E7V64P5q4,22631
 torchTextClassifiers/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 torchTextClassifiers/utilities/plot_explainability.py,sha256=8YhyiMupdiIZp4jT7uvlcJNf69Fyr9HXfjUiNyMSYxE,6931
-torchtextclassifiers-0.1.0.dist-info/WHEEL,sha256=ELhySV62sOro8I5wRaLaF3TWxhBpkcDkdZUdAYLy_Hk,78
-torchtextclassifiers-0.1.0.dist-info/METADATA,sha256=fvPTUIS-M4LgURVzC1CUTb8IrKyZiBzWRAE1heTafEE,2988
-torchtextclassifiers-0.1.0.dist-info/RECORD,,
+torchtextclassifiers-1.0.0.dist-info/WHEEL,sha256=AaqHSNJgTyoT6I9ETCXrbV_7cVSjA_q07lkDGeNjGdQ,79
+torchtextclassifiers-1.0.0.dist-info/METADATA,sha256=4loF6DgPJCHVIcBu8gAzvYQonhoZhTzmuMk0yqUAPrc,3666
+torchtextclassifiers-1.0.0.dist-info/RECORD,,
--- a/torchtextclassifiers-0.1.0.dist-info/WHEEL
+++ b/torchtextclassifiers-1.0.0.dist-info/WHEEL
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: uv 0.9.3
+Generator: uv 0.9.12
 Root-Is-Purelib: true
 Tag: py3-none-any