torchtextclassifiers 0.1.0__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/PKG-INFO +16 -2
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/README.md +15 -1
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/pyproject.toml +12 -8
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/dataset/dataset.py +41 -3
- torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/classification_head.py +61 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/lightning.py +4 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/torchTextClassifiers.py +62 -33
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/classification_head.py +0 -43
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/dataset/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/components/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/components/attention.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/components/categorical_var_net.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/components/text_embedder.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/model.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/tokenizers/WordPiece.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/tokenizers/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/tokenizers/base.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/tokenizers/ngram.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/utilities/__init__.py +0 -0
- {torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/utilities/plot_explainability.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: torchtextclassifiers
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
|
|
5
5
|
Keywords: fastText,text classification,NLP,automatic coding,deep learning
|
|
6
6
|
Author: Cédric Couralet, Meilame Tayebjee
|
|
@@ -26,15 +26,18 @@ Description-Content-Type: text/markdown
|
|
|
26
26
|
|
|
27
27
|
# torchTextClassifiers
|
|
28
28
|
|
|
29
|
+
[](https://inseefrlab.github.io/torchTextClassifiers/)
|
|
30
|
+
|
|
29
31
|
A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
|
|
30
32
|
|
|
31
33
|
## 🚀 Features
|
|
32
34
|
|
|
33
|
-
- **
|
|
35
|
+
- **Complex input support**: Handle text data alongside categorical variables seamlessly.
|
|
34
36
|
- **Unified yet highly customizable**:
|
|
35
37
|
- Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
|
|
36
38
|
- Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module`s!
|
|
37
39
|
- The `TextClassificationModel` class combines these components and can be extended for custom behavior.
|
|
40
|
+
- **Multiclass / multilabel classification support**: Support for both multiclass (exactly one label is true) and multilabel (several labels can be true) classification tasks.
|
|
38
41
|
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
|
|
39
42
|
- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
|
|
40
43
|
- The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
|
|
@@ -55,6 +58,15 @@ uv sync
|
|
|
55
58
|
pip install -e .
|
|
56
59
|
```
|
|
57
60
|
|
|
61
|
+
## 📖 Documentation
|
|
62
|
+
|
|
63
|
+
Full documentation is available at: **https://inseefrlab.github.io/torchTextClassifiers/**
|
|
64
|
+
The documentation includes:
|
|
65
|
+
- **Getting Started**: Installation and quick start guide
|
|
66
|
+
- **Architecture**: Understanding the 3-layer design
|
|
67
|
+
- **Tutorials**: Step-by-step guides for different use cases
|
|
68
|
+
- **API Reference**: Complete API documentation
|
|
69
|
+
|
|
58
70
|
## 📝 Usage
|
|
59
71
|
|
|
60
72
|
Check out the [notebook](notebooks/example.ipynb) for a quick start.
|
|
@@ -71,3 +83,5 @@ See the [examples/](examples/) directory for:
|
|
|
71
83
|
## 📄 License
|
|
72
84
|
|
|
73
85
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
86
|
+
|
|
87
|
+
|
|
@@ -1,14 +1,17 @@
|
|
|
1
1
|
# torchTextClassifiers
|
|
2
2
|
|
|
3
|
+
[](https://inseefrlab.github.io/torchTextClassifiers/)
|
|
4
|
+
|
|
3
5
|
A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
|
|
4
6
|
|
|
5
7
|
## 🚀 Features
|
|
6
8
|
|
|
7
|
-
- **
|
|
9
|
+
- **Complex input support**: Handle text data alongside categorical variables seamlessly.
|
|
8
10
|
- **Unified yet highly customizable**:
|
|
9
11
|
- Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
|
|
10
12
|
- Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures - including **self-attention**. All of them are `torch.nn.Module`s!
|
|
11
13
|
- The `TextClassificationModel` class combines these components and can be extended for custom behavior.
|
|
14
|
+
- **Multiclass / multilabel classification support**: Support for both multiclass (exactly one label is true) and multilabel (several labels can be true) classification tasks.
|
|
12
15
|
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
|
|
13
16
|
- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
|
|
14
17
|
- The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
|
|
@@ -29,6 +32,15 @@ uv sync
|
|
|
29
32
|
pip install -e .
|
|
30
33
|
```
|
|
31
34
|
|
|
35
|
+
## 📖 Documentation
|
|
36
|
+
|
|
37
|
+
Full documentation is available at: **https://inseefrlab.github.io/torchTextClassifiers/**
|
|
38
|
+
The documentation includes:
|
|
39
|
+
- **Getting Started**: Installation and quick start guide
|
|
40
|
+
- **Architecture**: Understanding the 3-layer design
|
|
41
|
+
- **Tutorials**: Step-by-step guides for different use cases
|
|
42
|
+
- **API Reference**: Complete API documentation
|
|
43
|
+
|
|
32
44
|
## 📝 Usage
|
|
33
45
|
|
|
34
46
|
Check out the [notebook](notebooks/example.ipynb) for a quick start.
|
|
@@ -45,3 +57,5 @@ See the [examples/](examples/) directory for:
|
|
|
45
57
|
## 📄 License
|
|
46
58
|
|
|
47
59
|
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
60
|
+
|
|
61
|
+
|
|
@@ -18,12 +18,12 @@ dependencies = [
|
|
|
18
18
|
"pytorch-lightning>=2.4.0",
|
|
19
19
|
]
|
|
20
20
|
requires-python = ">=3.11"
|
|
21
|
-
version="0.1.0"
|
|
21
|
+
version="1.0.0"
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
[dependency-groups]
|
|
25
25
|
dev = [
|
|
26
|
-
"pytest >=
|
|
26
|
+
"pytest >=9.0.1,<10",
|
|
27
27
|
"pandas",
|
|
28
28
|
"scikit-learn",
|
|
29
29
|
"nltk",
|
|
@@ -35,13 +35,17 @@ dev = [
|
|
|
35
35
|
"ipywidgets>=8.1.8",
|
|
36
36
|
]
|
|
37
37
|
docs = [
|
|
38
|
-
"sphinx>=
|
|
39
|
-
"sphinx-
|
|
40
|
-
"sphinx-autodoc-typehints>=
|
|
41
|
-
"sphinxcontrib-napoleon>=0.7",
|
|
38
|
+
"sphinx>=8.1.0",
|
|
39
|
+
"pydata-sphinx-theme>=0.16.0",
|
|
40
|
+
"sphinx-autodoc-typehints>=2.0.0",
|
|
42
41
|
"sphinx-copybutton>=0.5.0",
|
|
43
|
-
"myst-parser>=0.
|
|
44
|
-
"sphinx-design>=0.
|
|
42
|
+
"myst-parser>=4.0.0",
|
|
43
|
+
"sphinx-design>=0.6.0",
|
|
44
|
+
"nbsphinx>=0.9.0",
|
|
45
|
+
"ipython>=8.0.0",
|
|
46
|
+
"pandoc>=2.0.0",
|
|
47
|
+
"linkify-it-py>=2.0.0",
|
|
48
|
+
"sphinxcontrib-images>=1.0.1"
|
|
45
49
|
]
|
|
46
50
|
|
|
47
51
|
[project.optional-dependencies]
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/dataset/dataset.py
RENAMED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
from typing import List, Union
|
|
3
4
|
|
|
@@ -8,6 +9,7 @@ from torch.utils.data import DataLoader, Dataset
|
|
|
8
9
|
from torchTextClassifiers.tokenizers import BaseTokenizer
|
|
9
10
|
|
|
10
11
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
11
13
|
|
|
12
14
|
|
|
13
15
|
class TextClassificationDataset(Dataset):
|
|
@@ -16,7 +18,8 @@ class TextClassificationDataset(Dataset):
|
|
|
16
18
|
texts: List[str],
|
|
17
19
|
categorical_variables: Union[List[List[int]], np.array, None],
|
|
18
20
|
tokenizer: BaseTokenizer,
|
|
19
|
-
labels: Union[List[int], None] = None,
|
|
21
|
+
labels: Union[List[int], List[List[int]], np.array, None] = None,
|
|
22
|
+
ragged_multilabel: bool = False,
|
|
20
23
|
):
|
|
21
24
|
self.categorical_variables = categorical_variables
|
|
22
25
|
|
|
@@ -32,6 +35,23 @@ class TextClassificationDataset(Dataset):
|
|
|
32
35
|
self.texts = texts
|
|
33
36
|
self.tokenizer = tokenizer
|
|
34
37
|
self.labels = labels
|
|
38
|
+
self.ragged_multilabel = ragged_multilabel
|
|
39
|
+
|
|
40
|
+
if self.ragged_multilabel and self.labels is not None:
|
|
41
|
+
max_value = int(max(max(row) for row in labels if row))
|
|
42
|
+
self.num_classes = max_value + 1
|
|
43
|
+
|
|
44
|
+
if max_value == 1:
|
|
45
|
+
try:
|
|
46
|
+
labels = np.array(labels)
|
|
47
|
+
logger.critical(
|
|
48
|
+
"""ragged_multilabel set to True but max label value is 1 and all samples have the same number of labels.
|
|
49
|
+
If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong."""
|
|
50
|
+
)
|
|
51
|
+
except ValueError:
|
|
52
|
+
logger.warning(
|
|
53
|
+
"ragged_multilabel set to True but max label value is 1. If your labels are already one-hot encoded, set ragged_multilabel to False. Otherwise computations are likely to be wrong."
|
|
54
|
+
)
|
|
35
55
|
|
|
36
56
|
def __len__(self):
|
|
37
57
|
return len(self.texts)
|
|
@@ -59,10 +79,28 @@ class TextClassificationDataset(Dataset):
|
|
|
59
79
|
)
|
|
60
80
|
|
|
61
81
|
def collate_fn(self, batch):
|
|
62
|
-
text, *categorical_vars,
|
|
82
|
+
text, *categorical_vars, labels = zip(*batch)
|
|
63
83
|
|
|
64
84
|
if self.labels is not None:
|
|
65
|
-
|
|
85
|
+
if self.ragged_multilabel:
|
|
86
|
+
# Pad labels to the max length in the batch
|
|
87
|
+
labels_padded = torch.nn.utils.rnn.pad_sequence(
|
|
88
|
+
[torch.tensor(label) for label in labels],
|
|
89
|
+
batch_first=True,
|
|
90
|
+
padding_value=-1, # use impossible class
|
|
91
|
+
).int()
|
|
92
|
+
|
|
93
|
+
labels_tensor = torch.zeros(labels_padded.size(0), 6).float()
|
|
94
|
+
mask = labels_padded != -1
|
|
95
|
+
|
|
96
|
+
batch_size = labels_padded.size(0)
|
|
97
|
+
rows = torch.arange(batch_size).unsqueeze(1).expand_as(labels_padded)[mask]
|
|
98
|
+
cols = labels_padded[mask]
|
|
99
|
+
|
|
100
|
+
labels_tensor[rows, cols] = 1
|
|
101
|
+
|
|
102
|
+
else:
|
|
103
|
+
labels_tensor = torch.tensor(labels)
|
|
66
104
|
else:
|
|
67
105
|
labels_tensor = None
|
|
68
106
|
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ClassificationHead(nn.Module):
    def __init__(
        self,
        input_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
        net: Optional[nn.Module] = None,
    ):
        """
        Classification head for text classification tasks.

        It is a nn.Module that can either be a simple Linear layer or a custom neural network module.

        Args:
            input_dim (int, optional): Dimension of the input features. Required if net is not provided.
            num_classes (int, optional): Number of output classes. Required if net is not provided.
            net (nn.Module, optional): Custom neural network module to be used as the classification head.
                Must be either an ``nn.Linear`` or a non-empty ``nn.Sequential`` whose first and last
                layers are ``nn.Linear``. When provided, ``input_dim`` and ``num_classes`` are inferred
                from it (first layer's ``in_features`` / last layer's ``out_features``) and any
                explicitly passed values are ignored.

        Raises:
            ValueError: If ``net`` is neither an ``nn.Linear`` nor an ``nn.Sequential``, or is an
                empty ``nn.Sequential``.
            TypeError: If ``net`` is an ``nn.Sequential`` whose first or last layer is not ``nn.Linear``.
            AssertionError: If ``net`` is omitted and ``input_dim`` or ``num_classes`` is missing.
        """
        super().__init__()
        if net is not None:
            # --- Custom net should either be a Sequential or a Linear ---
            if not isinstance(net, (nn.Sequential, nn.Linear)):
                raise ValueError(
                    "net must be an nn.Linear or an nn.Sequential when provided, "
                    f"got {type(net).__name__}."
                )

            if isinstance(net, nn.Sequential):
                # Guard explicitly: indexing an empty Sequential would raise an opaque IndexError.
                if len(net) == 0:
                    raise ValueError("net must not be an empty nn.Sequential.")
                first, last = net[0], net[-1]

                # --- If Sequential, check first and last layers are Linear ---
                if not isinstance(first, nn.Linear):
                    raise TypeError(f"First layer must be nn.Linear, got {type(first).__name__}.")
                if not isinstance(last, nn.Linear):
                    raise TypeError(f"Last layer must be nn.Linear, got {type(last).__name__}.")
            else:  # a single Linear is both the first and the last layer
                first = last = net

            # Register the module only once it has passed validation.
            self.net = net
            self.input_dim = first.in_features
            self.num_classes = last.out_features

        else:
            assert (
                input_dim is not None and num_classes is not None
            ), "Either net or both input_dim and num_classes must be provided."
            self.net = nn.Linear(input_dim, num_classes)
            self.input_dim = input_dim
            self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.net(x)``.

        The last layer is always an ``nn.Linear``, so the outputs are unnormalized scores
        of size ``num_classes`` along the last dimension.
        """
        return self.net(x)
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/lightning.py
RENAMED
|
@@ -76,6 +76,10 @@ class TextClassificationModule(pl.LightningModule):
|
|
|
76
76
|
targets = batch["labels"]
|
|
77
77
|
|
|
78
78
|
outputs = self.forward(batch)
|
|
79
|
+
|
|
80
|
+
if isinstance(self.loss, torch.nn.BCEWithLogitsLoss):
|
|
81
|
+
targets = targets.float()
|
|
82
|
+
|
|
79
83
|
loss = self.loss(outputs, targets)
|
|
80
84
|
self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
|
|
81
85
|
accuracy = self.accuracy_fn(outputs, targets)
|
|
@@ -100,6 +100,7 @@ class torchTextClassifiers:
|
|
|
100
100
|
self,
|
|
101
101
|
tokenizer: BaseTokenizer,
|
|
102
102
|
model_config: ModelConfig,
|
|
103
|
+
ragged_multilabel: bool = False,
|
|
103
104
|
):
|
|
104
105
|
"""Initialize the torchTextClassifiers instance.
|
|
105
106
|
|
|
@@ -124,6 +125,7 @@ class torchTextClassifiers:
|
|
|
124
125
|
|
|
125
126
|
self.model_config = model_config
|
|
126
127
|
self.tokenizer = tokenizer
|
|
128
|
+
self.ragged_multilabel = ragged_multilabel
|
|
127
129
|
|
|
128
130
|
if hasattr(self.tokenizer, "trained"):
|
|
129
131
|
if not self.tokenizer.trained:
|
|
@@ -182,9 +184,9 @@ class torchTextClassifiers:
|
|
|
182
184
|
self,
|
|
183
185
|
X_train: np.ndarray,
|
|
184
186
|
y_train: np.ndarray,
|
|
185
|
-
X_val: np.ndarray,
|
|
186
|
-
y_val: np.ndarray,
|
|
187
187
|
training_config: TrainingConfig,
|
|
188
|
+
X_val: Optional[np.ndarray] = None,
|
|
189
|
+
y_val: Optional[np.ndarray] = None,
|
|
188
190
|
verbose: bool = False,
|
|
189
191
|
) -> None:
|
|
190
192
|
"""Train the classifier using PyTorch Lightning.
|
|
@@ -222,7 +224,14 @@ class torchTextClassifiers:
|
|
|
222
224
|
"""
|
|
223
225
|
# Input validation
|
|
224
226
|
X_train, y_train = self._check_XY(X_train, y_train)
|
|
225
|
-
|
|
227
|
+
|
|
228
|
+
if X_val is not None:
|
|
229
|
+
assert y_val is not None, "y_val must be provided if X_val is provided."
|
|
230
|
+
if y_val is not None:
|
|
231
|
+
assert X_val is not None, "X_val must be provided if y_val is provided."
|
|
232
|
+
|
|
233
|
+
if X_val is not None and y_val is not None:
|
|
234
|
+
X_val, y_val = self._check_XY(X_val, y_val)
|
|
226
235
|
|
|
227
236
|
if (
|
|
228
237
|
X_train["categorical_variables"] is not None
|
|
@@ -249,6 +258,11 @@ class torchTextClassifiers:
|
|
|
249
258
|
if training_config.optimizer_params is not None:
|
|
250
259
|
optimizer_params.update(training_config.optimizer_params)
|
|
251
260
|
|
|
261
|
+
if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel:
|
|
262
|
+
logger.warning(
|
|
263
|
+
"⚠️ You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks."
|
|
264
|
+
)
|
|
265
|
+
|
|
252
266
|
self.lightning_module = TextClassificationModule(
|
|
253
267
|
model=self.pytorch_model,
|
|
254
268
|
loss=training_config.loss,
|
|
@@ -270,38 +284,43 @@ class torchTextClassifiers:
|
|
|
270
284
|
texts=X_train["text"],
|
|
271
285
|
categorical_variables=X_train["categorical_variables"], # None if no cat vars
|
|
272
286
|
tokenizer=self.tokenizer,
|
|
273
|
-
labels=y_train,
|
|
287
|
+
labels=y_train.tolist(),
|
|
288
|
+
ragged_multilabel=self.ragged_multilabel,
|
|
274
289
|
)
|
|
275
|
-
val_dataset = TextClassificationDataset(
|
|
276
|
-
texts=X_val["text"],
|
|
277
|
-
categorical_variables=X_val["categorical_variables"], # None if no cat vars
|
|
278
|
-
tokenizer=self.tokenizer,
|
|
279
|
-
labels=y_val,
|
|
280
|
-
)
|
|
281
|
-
|
|
282
290
|
train_dataloader = train_dataset.create_dataloader(
|
|
283
291
|
batch_size=training_config.batch_size,
|
|
284
292
|
num_workers=training_config.num_workers,
|
|
285
293
|
shuffle=True,
|
|
286
294
|
**training_config.dataloader_params if training_config.dataloader_params else {},
|
|
287
295
|
)
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
296
|
+
|
|
297
|
+
if X_val is not None and y_val is not None:
|
|
298
|
+
val_dataset = TextClassificationDataset(
|
|
299
|
+
texts=X_val["text"],
|
|
300
|
+
categorical_variables=X_val["categorical_variables"], # None if no cat vars
|
|
301
|
+
tokenizer=self.tokenizer,
|
|
302
|
+
labels=y_val,
|
|
303
|
+
ragged_multilabel=self.ragged_multilabel,
|
|
304
|
+
)
|
|
305
|
+
val_dataloader = val_dataset.create_dataloader(
|
|
306
|
+
batch_size=training_config.batch_size,
|
|
307
|
+
num_workers=training_config.num_workers,
|
|
308
|
+
shuffle=False,
|
|
309
|
+
**training_config.dataloader_params if training_config.dataloader_params else {},
|
|
310
|
+
)
|
|
311
|
+
else:
|
|
312
|
+
val_dataloader = None
|
|
294
313
|
|
|
295
314
|
# Setup trainer
|
|
296
315
|
callbacks = [
|
|
297
316
|
ModelCheckpoint(
|
|
298
|
-
monitor="val_loss",
|
|
317
|
+
monitor="val_loss" if val_dataloader is not None else "train_loss",
|
|
299
318
|
save_top_k=1,
|
|
300
319
|
save_last=False,
|
|
301
320
|
mode="min",
|
|
302
321
|
),
|
|
303
322
|
EarlyStopping(
|
|
304
|
-
monitor="val_loss",
|
|
323
|
+
monitor="val_loss" if val_dataloader is not None else "train_loss",
|
|
305
324
|
patience=training_config.patience_early_stopping,
|
|
306
325
|
mode="min",
|
|
307
326
|
),
|
|
@@ -352,7 +371,7 @@ class torchTextClassifiers:
|
|
|
352
371
|
X = self._check_X(X)
|
|
353
372
|
Y = self._check_Y(Y)
|
|
354
373
|
|
|
355
|
-
if X["text"].shape[0] != Y
|
|
374
|
+
if X["text"].shape[0] != len(Y):
|
|
356
375
|
raise ValueError("X_train and y_train must have the same number of observations.")
|
|
357
376
|
|
|
358
377
|
return X, Y
|
|
@@ -422,22 +441,32 @@ class torchTextClassifiers:
|
|
|
422
441
|
return {"text": text, "categorical_variables": categorical_variables}
|
|
423
442
|
|
|
424
443
|
def _check_Y(self, Y):
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
444
|
+
if self.ragged_multilabel:
|
|
445
|
+
assert isinstance(
|
|
446
|
+
Y, list
|
|
447
|
+
), "Y must be a list of lists for ragged multilabel classification."
|
|
448
|
+
for row in Y:
|
|
449
|
+
assert isinstance(row, list), "Each element of Y must be a list of labels."
|
|
429
450
|
|
|
430
|
-
|
|
431
|
-
Y = Y.astype(int)
|
|
432
|
-
except ValueError:
|
|
433
|
-
logger.error("Y must be castable in integer format.")
|
|
451
|
+
return Y
|
|
434
452
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
453
|
+
else:
|
|
454
|
+
assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
|
|
455
|
+
assert (
|
|
456
|
+
len(Y.shape) == 1 or len(Y.shape) == 2
|
|
457
|
+
), "Y must be a numpy array of shape (N,) or (N, num_labels)."
|
|
458
|
+
|
|
459
|
+
try:
|
|
460
|
+
Y = Y.astype(int)
|
|
461
|
+
except ValueError:
|
|
462
|
+
logger.error("Y must be castable in integer format.")
|
|
463
|
+
|
|
464
|
+
if Y.max() >= self.num_classes or Y.min() < 0:
|
|
465
|
+
raise ValueError(
|
|
466
|
+
f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
|
|
467
|
+
)
|
|
439
468
|
|
|
440
|
-
|
|
469
|
+
return Y
|
|
441
470
|
|
|
442
471
|
def predict(
|
|
443
472
|
self,
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import nn
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class ClassificationHead(nn.Module):
|
|
8
|
-
def __init__(
|
|
9
|
-
self,
|
|
10
|
-
input_dim: Optional[int] = None,
|
|
11
|
-
num_classes: Optional[int] = None,
|
|
12
|
-
net: Optional[nn.Module] = None,
|
|
13
|
-
):
|
|
14
|
-
super().__init__()
|
|
15
|
-
if net is not None:
|
|
16
|
-
self.net = net
|
|
17
|
-
self.input_dim = net.in_features
|
|
18
|
-
self.num_classes = net.out_features
|
|
19
|
-
else:
|
|
20
|
-
assert (
|
|
21
|
-
input_dim is not None and num_classes is not None
|
|
22
|
-
), "Either net or both input_dim and num_classes must be provided."
|
|
23
|
-
self.net = nn.Linear(input_dim, num_classes)
|
|
24
|
-
self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
|
|
25
|
-
|
|
26
|
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
27
|
-
return self.net(x)
|
|
28
|
-
|
|
29
|
-
@staticmethod
|
|
30
|
-
def _get_linear_input_output_dims(module: nn.Module):
|
|
31
|
-
"""
|
|
32
|
-
Returns (input_dim, output_dim) for any module containing Linear layers.
|
|
33
|
-
Works for Linear, Sequential, or nested models.
|
|
34
|
-
"""
|
|
35
|
-
# Collect all Linear layers recursively
|
|
36
|
-
linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
|
|
37
|
-
|
|
38
|
-
if not linears:
|
|
39
|
-
raise ValueError("No Linear layers found in the given module.")
|
|
40
|
-
|
|
41
|
-
input_dim = linears[0].in_features
|
|
42
|
-
output_dim = linears[-1].out_features
|
|
43
|
-
return input_dim, output_dim
|
|
File without changes
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/dataset/__init__.py
RENAMED
|
File without changes
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/model/model.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/tokenizers/base.py
RENAMED
|
File without changes
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/tokenizers/ngram.py
RENAMED
|
File without changes
|
{torchtextclassifiers-0.1.0 → torchtextclassifiers-1.0.0}/torchTextClassifiers/utilities/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|