torchtextclassifiers 0.1.0__tar.gz → 1.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/PKG-INFO +16 -2
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/README.md +15 -1
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/pyproject.toml +12 -8
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/dataset/dataset.py +41 -3
- torchtextclassifiers-1.0.1/torchTextClassifiers/model/components/classification_head.py +61 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/lightning.py +4 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/torchTextClassifiers.py +69 -33
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/utilities/plot_explainability.py +17 -7
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/classification_head.py +0 -43
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/dataset/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/attention.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/categorical_var_net.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/text_embedder.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/model.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/WordPiece.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/base.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/ngram.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/utilities/__init__.py +0 -0
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: torchtextclassifiers
-Version: 0.1.0
+Version: 1.0.1
 Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
 Keywords: fastText,text classification,NLP,automatic coding,deep learning
 Author: Cédric Couralet, Meilame Tayebjee
@@ -26,15 +26,18 @@ Description-Content-Type: text/markdown
 
 # torchTextClassifiers
 
+[](https://inseefrlab.github.io/torchTextClassifiers/)
+
 A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
 
 ## 🚀 Features
 
-- **
+- **Complex input support**: Handle text data alongside categorical variables seamlessly.
 - **Unified yet highly customizable**:
   - Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
   - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module` !
   - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+- **Multiclass / multilabel classification support**: Support for both multiclass (only one label is true) and multi-label (several labels can be true) classification tasks.
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
 - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
   - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
@@ -55,6 +58,15 @@ uv sync
 pip install -e .
 ```
 
+## 📖 Documentation
+
+Full documentation is available at: **https://inseefrlab.github.io/torchTextClassifiers/**
+The documentation includes:
+- **Getting Started**: Installation and quick start guide
+- **Architecture**: Understanding the 3-layer design
+- **Tutorials**: Step-by-step guides for different use cases
+- **API Reference**: Complete API documentation
+
 ## 📝 Usage
 
 Checkout the [notebook](notebooks/example.ipynb) for a quick start.
@@ -71,3 +83,5 @@ See the [examples/](examples/) directory for:
 ## 📄 License
 
 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/README.md
RENAMED
@@ -1,14 +1,17 @@
 # torchTextClassifiers
 
+[](https://inseefrlab.github.io/torchTextClassifiers/)
+
 A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
 
 ## 🚀 Features
 
-- **
+- **Complex input support**: Handle text data alongside categorical variables seamlessly.
 - **Unified yet highly customizable**:
   - Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
   - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module` !
   - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+- **Multiclass / multilabel classification support**: Support for both multiclass (only one label is true) and multi-label (several labels can be true) classification tasks.
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
 - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
   - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
@@ -29,6 +32,15 @@ uv sync
 pip install -e .
 ```
 
+## 📖 Documentation
+
+Full documentation is available at: **https://inseefrlab.github.io/torchTextClassifiers/**
+The documentation includes:
+- **Getting Started**: Installation and quick start guide
+- **Architecture**: Understanding the 3-layer design
+- **Tutorials**: Step-by-step guides for different use cases
+- **API Reference**: Complete API documentation
+
 ## 📝 Usage
 
 Checkout the [notebook](notebooks/example.ipynb) for a quick start.
@@ -45,3 +57,5 @@ See the [examples/](examples/) directory for:
 ## 📄 License
 
 This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+
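The two label formats behind the new multiclass / multilabel bullet can be read off the dataset and _check_Y changes further down in this diff. A minimal sketch of the expected shapes (the variable names are illustrative, not the package's):

import numpy as np

# Multiclass (default): one integer class id per sample, as a numpy array.
y_multiclass = np.array([0, 2, 1, 2])

# Ragged multilabel (ragged_multilabel=True): a list of lists, each inner
# list holding the class ids that are true for that sample.
y_multilabel = [[0, 2], [1], [0, 1, 2]]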
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/pyproject.toml
RENAMED
@@ -18,12 +18,12 @@ dependencies = [
     "pytorch-lightning>=2.4.0",
 ]
 requires-python = ">=3.11"
-version="0.1.0"
+version="1.0.1"
 
 
 [dependency-groups]
 dev = [
-    "pytest >=
+    "pytest >=9.0.1,<10",
     "pandas",
     "scikit-learn",
     "nltk",
@@ -35,13 +35,17 @@ dev = [
     "ipywidgets>=8.1.8",
 ]
 docs = [
-    "sphinx>=
-    "sphinx-
-    "sphinx-autodoc-typehints>=
-    "sphinxcontrib-napoleon>=0.7",
+    "sphinx>=8.1.0",
+    "pydata-sphinx-theme>=0.16.0",
+    "sphinx-autodoc-typehints>=2.0.0",
     "sphinx-copybutton>=0.5.0",
-    "myst-parser>=0.
-    "sphinx-design>=0.
+    "myst-parser>=4.0.0",
+    "sphinx-design>=0.6.0",
+    "nbsphinx>=0.9.0",
+    "ipython>=8.0.0",
+    "pandoc>=2.0.0",
+    "linkify-it-py>=2.0.0",
+    "sphinxcontrib-images>=1.0.1"
 ]
 
 [project.optional-dependencies]
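Note that docs here is a PEP 735 dependency group, not a [project.optional-dependencies] extra, so it is not installable as a pip extra. With a recent uv (assumption: a version with dependency-group support, >= 0.4.27), it can be synced with:

uv sync --group docs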
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/dataset/dataset.py
RENAMED
@@ -1,3 +1,4 @@
+import logging
 import os
 from typing import List, Union
 
@@ -8,6 +9,7 @@ from torch.utils.data import DataLoader, Dataset
 from torchTextClassifiers.tokenizers import BaseTokenizer
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+logger = logging.getLogger(__name__)
 
 
 class TextClassificationDataset(Dataset):
@@ -16,7 +18,8 @@ class TextClassificationDataset(Dataset):
         texts: List[str],
         categorical_variables: Union[List[List[int]], np.array, None],
         tokenizer: BaseTokenizer,
-        labels: Union[List[int], None] = None,
+        labels: Union[List[int], List[List[int]], np.array, None] = None,
+        ragged_multilabel: bool = False,
     ):
         self.categorical_variables = categorical_variables
 
@@ -32,6 +35,23 @@ class TextClassificationDataset(Dataset):
         self.texts = texts
         self.tokenizer = tokenizer
         self.labels = labels
+        self.ragged_multilabel = ragged_multilabel
+
+        if self.ragged_multilabel and self.labels is not None:
+            max_value = int(max(max(row) for row in labels if row))
+            self.num_classes = max_value + 1
+
+            if max_value == 1:
+                try:
+                    labels = np.array(labels)
+                    logger.critical(
+                        """ragged_multilabel set to True but max label value is 1 and all samples have the same number of labels.
+                        If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong."""
+                    )
+                except ValueError:
+                    logger.warning(
+                        "ragged_multilabel set to True but max label value is 1. If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong."
+                    )
 
     def __len__(self):
         return len(self.texts)
@@ -59,10 +79,28 @@ class TextClassificationDataset(Dataset):
         )
 
     def collate_fn(self, batch):
-        text, *categorical_vars,
+        text, *categorical_vars, labels = zip(*batch)
 
         if self.labels is not None:
-
+            if self.ragged_multilabel:
+                # Pad labels to the max length in the batch
+                labels_padded = torch.nn.utils.rnn.pad_sequence(
+                    [torch.tensor(label) for label in labels],
+                    batch_first=True,
+                    padding_value=-1,  # use impossible class
+                ).int()
+
+                labels_tensor = torch.zeros(labels_padded.size(0), 6).float()
+                mask = labels_padded != -1
+
+                batch_size = labels_padded.size(0)
+                rows = torch.arange(batch_size).unsqueeze(1).expand_as(labels_padded)[mask]
+                cols = labels_padded[mask]
+
+                labels_tensor[rows, cols] = 1
+
+            else:
+                labels_tensor = torch.tensor(labels)
         else:
             labels_tensor = None
 
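The ragged-multilabel branch of collate_fn above turns each sample's list of class ids into a fixed-width multi-hot row. A standalone sketch of the same pad-then-scatter logic follows, with one caveat flagged: the released code hard-codes 6 output columns in torch.zeros(labels_padded.size(0), 6), where self.num_classes (computed in __init__) was presumably intended; the sketch makes the class count explicit instead.

import torch

def multi_hot_collate(labels, num_classes):
    # Pad the ragged per-sample label lists with -1, an impossible class id
    padded = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(l) for l in labels], batch_first=True, padding_value=-1
    ).long()
    target = torch.zeros(padded.size(0), num_classes).float()
    mask = padded != -1
    # Row index for every real (non-padding) label entry
    rows = torch.arange(padded.size(0)).unsqueeze(1).expand_as(padded)[mask]
    target[rows, padded[mask]] = 1.0  # set ones at the true class ids
    return target

print(multi_hot_collate([[0, 2], [1], [0, 1, 2]], num_classes=3))
# tensor([[1., 0., 1.],
#         [0., 1., 0.],
#         [1., 1., 1.]])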
torchtextclassifiers-1.0.1/torchTextClassifiers/model/components/classification_head.py
ADDED
@@ -0,0 +1,61 @@
+from typing import Optional
+
+import torch
+from torch import nn
+
+
+class ClassificationHead(nn.Module):
+    def __init__(
+        self,
+        input_dim: Optional[int] = None,
+        num_classes: Optional[int] = None,
+        net: Optional[nn.Module] = None,
+    ):
+        """
+        Classification head for text classification tasks.
+        It is a nn.Module that can either be a simple Linear layer or a custom neural network module.
+
+        Args:
+            input_dim (int, optional): Dimension of the input features. Required if net is not provided.
+            num_classes (int, optional): Number of output classes. Required if net is not provided.
+            net (nn.Module, optional): Custom neural network module to be used as the classification head.
+                If provided, input_dim and num_classes are inferred from this module.
+                Should be either an nn.Sequential with first and last layers being Linears or nn.Linear.
+        """
+        super().__init__()
+        if net is not None:
+            self.net = net
+
+            # --- Custom net should either be a Sequential or a Linear ---
+            if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)):
+                raise ValueError("net must be an nn.Sequential when provided.")
+
+            # --- If Sequential, Check first and last layers are Linear ---
+
+            if isinstance(net, nn.Sequential):
+                first = net[0]
+                last = net[-1]
+
+                if not isinstance(first, nn.Linear):
+                    raise TypeError(f"First layer must be nn.Linear, got {type(first).__name__}.")
+
+                if not isinstance(last, nn.Linear):
+                    raise TypeError(f"Last layer must be nn.Linear, got {type(last).__name__}.")
+
+                # --- Extract features ---
+                self.input_dim = first.in_features
+                self.num_classes = last.out_features
+            else:  # if not Sequential, it is a Linear
+                self.input_dim = net.in_features
+                self.num_classes = net.out_features
+
+        else:
+            assert (
+                input_dim is not None and num_classes is not None
+            ), "Either net or both input_dim and num_classes must be provided."
+            self.net = nn.Linear(input_dim, num_classes)
+            self.input_dim = input_dim
+            self.num_classes = num_classes
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.net(x)
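Both construction paths of the rewritten ClassificationHead, as a short usage sketch (the import path is assumed from the package layout shown in this diff). Unlike the 0.1.0 version deleted at the bottom of this diff, which inferred dimensions from any nested Linear layers, a custom net must now start and end with nn.Linear.

import torch
from torch import nn
from torchTextClassifiers.model.components.classification_head import ClassificationHead

# From explicit dimensions: a plain nn.Linear head.
head = ClassificationHead(input_dim=64, num_classes=5)

# From a custom net: dimensions are inferred from the first and last
# layers, which must both be nn.Linear.
custom = ClassificationHead(net=nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 5)))

print(custom.input_dim, custom.num_classes)  # 64 5
print(head(torch.randn(2, 64)).shape)        # torch.Size([2, 5])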
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/lightning.py
RENAMED
@@ -76,6 +76,10 @@ class TextClassificationModule(pl.LightningModule):
         targets = batch["labels"]
 
         outputs = self.forward(batch)
+
+        if isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
+            targets = targets.float()
+
         loss = self.loss(outputs, targets)
         self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
         accuracy = self.accuracy_fn(outputs, targets)
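The reason for the new cast: BCEWithLogitsLoss, the natural loss for multilabel targets, requires float targets, while an integer multi-hot matrix (as produced by the collate function above) would raise a dtype error. A minimal reproduction:

import torch

logits = torch.randn(2, 3)
targets = torch.tensor([[1, 0, 1], [0, 1, 0]])  # integer multi-hot targets
bce = torch.nn.BCEWithLogitsLoss()
# bce(logits, targets) fails with a Long-vs-Float dtype error;
# the cast added above avoids it:
print(bce(logits, targets.float()))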
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/torchTextClassifiers.py
RENAMED
@@ -100,6 +100,7 @@ class torchTextClassifiers:
         self,
         tokenizer: BaseTokenizer,
         model_config: ModelConfig,
+        ragged_multilabel: bool = False,
     ):
         """Initialize the torchTextClassifiers instance.
 
@@ -124,6 +125,7 @@ class torchTextClassifiers:
 
         self.model_config = model_config
         self.tokenizer = tokenizer
+        self.ragged_multilabel = ragged_multilabel
 
         if hasattr(self.tokenizer, "trained"):
             if not self.tokenizer.trained:
@@ -182,9 +184,9 @@
         self,
         X_train: np.ndarray,
         y_train: np.ndarray,
-        X_val: np.ndarray,
-        y_val: np.ndarray,
         training_config: TrainingConfig,
+        X_val: Optional[np.ndarray] = None,
+        y_val: Optional[np.ndarray] = None,
         verbose: bool = False,
     ) -> None:
         """Train the classifier using PyTorch Lightning.
@@ -196,6 +198,12 @@
         - Model training with early stopping
         - Best model loading after training
 
+        Note on Checkpoints:
+            After training, the best model checkpoint is automatically loaded.
+            This checkpoint contains the full training state (model weights,
+            optimizer, and scheduler state). Loading uses weights_only=False
+            as the checkpoint is self-generated and trusted.
+
         Args:
             X_train: Training input data
             y_train: Training labels
@@ -222,7 +230,14 @@
         """
         # Input validation
         X_train, y_train = self._check_XY(X_train, y_train)
-
+
+        if X_val is not None:
+            assert y_val is not None, "y_val must be provided if X_val is provided."
+        if y_val is not None:
+            assert X_val is not None, "X_val must be provided if y_val is provided."
+
+        if X_val is not None and y_val is not None:
+            X_val, y_val = self._check_XY(X_val, y_val)
 
         if (
             X_train["categorical_variables"] is not None
@@ -249,6 +264,11 @@
         if training_config.optimizer_params is not None:
             optimizer_params.update(training_config.optimizer_params)
 
+        if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel:
+            logger.warning(
+                "⚠️ You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks."
+            )
+
         self.lightning_module = TextClassificationModule(
             model=self.pytorch_model,
             loss=training_config.loss,
@@ -270,38 +290,43 @@
             texts=X_train["text"],
             categorical_variables=X_train["categorical_variables"],  # None if no cat vars
             tokenizer=self.tokenizer,
-            labels=y_train,
-
-        val_dataset = TextClassificationDataset(
-            texts=X_val["text"],
-            categorical_variables=X_val["categorical_variables"],  # None if no cat vars
-            tokenizer=self.tokenizer,
-            labels=y_val,
+            labels=y_train.tolist(),
+            ragged_multilabel=self.ragged_multilabel,
         )
-
         train_dataloader = train_dataset.create_dataloader(
             batch_size=training_config.batch_size,
             num_workers=training_config.num_workers,
             shuffle=True,
            **training_config.dataloader_params if training_config.dataloader_params else {},
         )
-
-
-
-
-
-
+
+        if X_val is not None and y_val is not None:
+            val_dataset = TextClassificationDataset(
+                texts=X_val["text"],
+                categorical_variables=X_val["categorical_variables"],  # None if no cat vars
+                tokenizer=self.tokenizer,
+                labels=y_val,
+                ragged_multilabel=self.ragged_multilabel,
+            )
+            val_dataloader = val_dataset.create_dataloader(
+                batch_size=training_config.batch_size,
+                num_workers=training_config.num_workers,
+                shuffle=False,
+                **training_config.dataloader_params if training_config.dataloader_params else {},
+            )
+        else:
+            val_dataloader = None
 
         # Setup trainer
         callbacks = [
             ModelCheckpoint(
-                monitor="val_loss",
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
                 save_top_k=1,
                 save_last=False,
                 mode="min",
             ),
             EarlyStopping(
-                monitor="val_loss",
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
                 patience=training_config.patience_early_stopping,
                 mode="min",
             ),
@@ -342,6 +367,7 @@
             best_model_path,
             model=self.pytorch_model,
             loss=training_config.loss,
+            weights_only=False,  # Required: checkpoint contains optimizer/scheduler state
         )
 
         self.pytorch_model = self.lightning_module.model.to(self.device)
@@ -352,7 +378,7 @@
         X = self._check_X(X)
         Y = self._check_Y(Y)
 
-        if X["text"].shape[0] != Y
+        if X["text"].shape[0] != len(Y):
             raise ValueError("X_train and y_train must have the same number of observations.")
 
         return X, Y
@@ -422,22 +448,32 @@
         return {"text": text, "categorical_variables": categorical_variables}
 
     def _check_Y(self, Y):
-
-
-
-
+        if self.ragged_multilabel:
+            assert isinstance(
+                Y, list
+            ), "Y must be a list of lists for ragged multilabel classification."
+            for row in Y:
+                assert isinstance(row, list), "Each element of Y must be a list of labels."
 
-
-            Y = Y.astype(int)
-        except ValueError:
-            logger.error("Y must be castable in integer format.")
+            return Y
 
-
-
-
+        else:
+            assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
+            assert (
+                len(Y.shape) == 1 or len(Y.shape) == 2
+            ), "Y must be a numpy array of shape (N,) or (N, num_labels)."
+
+            try:
+                Y = Y.astype(int)
+            except ValueError:
+                logger.error("Y must be castable in integer format.")
+
+            if Y.max() >= self.num_classes or Y.min() < 0:
+                raise ValueError(
+                    f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
+                )
 
-
+            return Y
 
     def predict(
         self,
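Two user-facing consequences of the changes above: train() now takes training_config third, with X_val/y_val keyword-optional (positional callers written against 0.1.0 need updating, and without a validation set the callbacks monitor train_loss), and _check_Y accepts two label formats. A standalone sketch of that validation (the helper name and the explicit num_classes argument are ours; the wrapper reads the class count from its model):

import numpy as np

def check_y(y, num_classes, ragged_multilabel=False):
    if ragged_multilabel:
        # Ragged multilabel: a list of lists of class ids
        assert isinstance(y, list) and all(isinstance(row, list) for row in y)
        return y
    # Multiclass: a numpy array of integer class ids in [0, num_classes - 1]
    y = np.asarray(y).astype(int)
    if y.max() >= num_classes or y.min() < 0:
        raise ValueError(f"labels outside [0, {num_classes - 1}]")
    return y

print(check_y(np.array([0, 2, 1]), num_classes=3))                    # [0 2 1]
print(check_y([[0, 2], [1]], num_classes=3, ragged_multilabel=True))  # [[0, 2], [1]]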
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/utilities/plot_explainability.py
RENAMED
@@ -53,8 +53,18 @@ def map_attributions_to_char(attributions, offsets, text):
         np.exp(attributions_per_char), axis=1, keepdims=True
     )  # softmax normalization
 
+def get_id_to_word(text, word_ids, offsets):
+    words = {}
+    for idx, word_id in enumerate(word_ids):
+        if word_id is None:
+            continue
+        start, end = offsets[idx]
+        words[int(word_id)] = text[start:end]
+
+    return words
+
 
-def map_attributions_to_word(attributions, word_ids):
+def map_attributions_to_word(attributions, text, word_ids, offsets):
     """
     Maps token-level attributions to word-level attributions based on word IDs.
     Args:
@@ -69,8 +79,9 @@ def map_attributions_to_word(attributions, word_ids):
         np.ndarray: Array of shape (top_k, num_words) containing word-level attributions.
             num_words is the number of unique words in the original text.
     """
-
+
     word_ids = np.array(word_ids)
+    words = get_id_to_word(text, word_ids, offsets)
 
     # Convert None to -1 for easier processing (PAD tokens)
     word_ids_int = np.array([x if x is not None else -1 for x in word_ids], dtype=int)
@@ -99,7 +110,7 @@
     )  # zero-out non-matching tokens and sum attributions for all tokens belonging to the same word
 
     # assert word_attributions.sum(axis=1) == attributions.sum(axis=1), "Sum of word attributions per top_k must equal sum of token attributions per top_k."
-    return np.exp(word_attributions) / np.sum(
+    return words, np.exp(word_attributions) / np.sum(
         np.exp(word_attributions), axis=1, keepdims=True
     )  # softmax normalization
 
@@ -131,7 +142,7 @@ def plot_attributions_at_char(
     fig, ax = plt.subplots(figsize=figsize)
     ax.bar(range(len(text)), attributions_per_char[i])
     ax.set_xticks(np.arange(len(text)))
-    ax.set_xticklabels(list(text), rotation=
+    ax.set_xticklabels(list(text), rotation=45)
     title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
     ax.set_title(title)
     ax.set_xlabel("Characters in Text")
@@ -142,7 +153,7 @@
 
 
 def plot_attributions_at_word(
-    text, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
+    text, words, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
 ):
     """
     Plots word-level attributions as a heatmap.
@@ -159,14 +170,13 @@
             "matplotlib is required for plotting. Please install it to use this function."
         )
 
-    words = text.split()
     top_k = attributions_per_word.shape[0]
     all_plots = []
    for i in range(top_k):
         fig, ax = plt.subplots(figsize=figsize)
         ax.bar(range(len(words)), attributions_per_word[i])
         ax.set_xticks(np.arange(len(words)))
-        ax.set_xticklabels(words, rotation=
+        ax.set_xticklabels(words, rotation=45)
         title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
         ax.set_title(title)
         ax.set_xlabel("Words in Text")
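What the new get_id_to_word contributes: a fast tokenizer returns one word id and one character-offset pair per token, and the helper maps each word id back to a span of the original text, so plot_attributions_at_word can label its ticks with real words instead of text.split(). A toy run of the same logic (the token values are illustrative):

text = "hello world"
word_ids = [0, 1, None]             # None marks special/padding tokens
offsets = [(0, 5), (6, 11), (0, 0)]

words = {}
for idx, word_id in enumerate(word_ids):
    if word_id is None:
        continue
    start, end = offsets[idx]
    words[int(word_id)] = text[start:end]

print(words)  # {0: 'hello', 1: 'world'}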
torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/classification_head.py
DELETED
@@ -1,43 +0,0 @@
-from typing import Optional
-
-import torch
-from torch import nn
-
-
-class ClassificationHead(nn.Module):
-    def __init__(
-        self,
-        input_dim: Optional[int] = None,
-        num_classes: Optional[int] = None,
-        net: Optional[nn.Module] = None,
-    ):
-        super().__init__()
-        if net is not None:
-            self.net = net
-            self.input_dim = net.in_features
-            self.num_classes = net.out_features
-        else:
-            assert (
-                input_dim is not None and num_classes is not None
-            ), "Either net or both input_dim and num_classes must be provided."
-            self.net = nn.Linear(input_dim, num_classes)
-            self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.net(x)
-
-    @staticmethod
-    def _get_linear_input_output_dims(module: nn.Module):
-        """
-        Returns (input_dim, output_dim) for any module containing Linear layers.
-        Works for Linear, Sequential, or nested models.
-        """
-        # Collect all Linear layers recursively
-        linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
-
-        if not linears:
-            raise ValueError("No Linear layers found in the given module.")
-
-        input_dim = linears[0].in_features
-        output_dim = linears[-1].out_features
-        return input_dim, output_dim
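For contrast with the 1.0.1 version earlier in this diff: this deleted helper collected every nested nn.Linear, so a head whose outer layers were not Linear was still accepted with silently inferred dimensions. An illustration of the old behavior (a standalone reimplementation of the deleted method):

from torch import nn

def old_get_dims(module):
    # Deleted 0.1.0 logic: collect all Linear layers recursively
    linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
    if not linears:
        raise ValueError("No Linear layers found in the given module.")
    return linears[0].in_features, linears[-1].out_features

net = nn.Sequential(nn.Flatten(), nn.Linear(64, 5), nn.Softmax(dim=-1))
print(old_get_dims(net))  # (64, 5): accepted in 0.1.0, rejected with TypeError in 1.0.1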
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/__init__.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/dataset/__init__.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/__init__.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/__init__.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/attention.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/categorical_var_net.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/components/text_embedder.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/model/model.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/WordPiece.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/__init__.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/base.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/tokenizers/ngram.py
RENAMED
File without changes
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.1}/torchTextClassifiers/utilities/__init__.py
RENAMED
File without changes