torchtextclassifiers 0.0.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +152 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +61 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +170 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +500 -413
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-1.0.0.dist-info/METADATA +87 -0
- torchtextclassifiers-1.0.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-1.0.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
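
Below is a minimal sketch (not the tooling behind this page) of how a comparable per-file diff can be reproduced locally with only the Python standard library. The wheel file names and the chosen member path are assumptions taken from the listing above; download both wheels first, e.g. with `pip download torchtextclassifiers==0.0.1 --no-deps`.

import difflib
import zipfile

OLD_WHEEL = "torchtextclassifiers-0.0.1-py3-none-any.whl"
NEW_WHEEL = "torchtextclassifiers-1.0.0-py3-none-any.whl"
MEMBER = "torchTextClassifiers/torchTextClassifiers.py"  # present in both versions per the listing

def read_member(wheel_path: str, member: str) -> list[str]:
    # A wheel is a zip archive, so zipfile can read any member directly.
    with zipfile.ZipFile(wheel_path) as zf:
        if member not in zf.namelist():
            return []  # treat a missing file as empty (added or deleted member)
        return zf.read(member).decode("utf-8").splitlines(keepends=True)

diff = difflib.unified_diff(
    read_member(OLD_WHEEL, MEMBER),
    read_member(NEW_WHEEL, MEMBER),
    fromfile=f"0.0.1/{MEMBER}",
    tofile=f"1.0.0/{MEMBER}",
)
print("".join(diff), end="")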
torchTextClassifiers/classifiers/simple_text_classifier.py (deleted)
@@ -1,191 +0,0 @@
-"""
-Simple text classifier example that doesn't require a tokenizer.
-
-This demonstrates how to create a classifier wrapper that uses
-different text preprocessing approaches.
-"""
-
-from typing import Optional, Dict, Any
-from dataclasses import dataclass, asdict
-import numpy as np
-import torch
-import torch.nn as nn
-from sklearn.feature_extraction.text import TfidfVectorizer
-from torch.utils.data import Dataset, DataLoader
-import pytorch_lightning as pl
-from torch.optim import Adam
-
-from .base import BaseClassifierWrapper, BaseClassifierConfig
-
-
-@dataclass
-class SimpleTextConfig(BaseClassifierConfig):
-    """Configuration for simple text classifier using TF-IDF."""
-
-    hidden_dim: int = 128
-    num_classes: Optional[int] = None
-    max_features: int = 10000
-    learning_rate: float = 1e-3
-    dropout_rate: float = 0.1
-
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "SimpleTextConfig":
-        return cls(**data)
-
-
-class SimpleTextDataset(Dataset):
-    """Dataset for simple text classifier."""
-
-    def __init__(self, features: np.ndarray, labels: np.ndarray):
-        self.features = torch.FloatTensor(features)
-        self.labels = torch.LongTensor(labels)
-
-    def __len__(self):
-        return len(self.features)
-
-    def __getitem__(self, idx):
-        return self.features[idx], self.labels[idx]
-
-
-class SimpleTextModel(nn.Module):
-    """Simple neural network for text classification using TF-IDF features."""
-
-    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout_rate: float = 0.1):
-        super().__init__()
-
-        self.network = nn.Sequential(
-            nn.Linear(input_dim, hidden_dim),
-            nn.ReLU(),
-            nn.Dropout(dropout_rate),
-            nn.Linear(hidden_dim, hidden_dim // 2),
-            nn.ReLU(),
-            nn.Dropout(dropout_rate),
-            nn.Linear(hidden_dim // 2, num_classes)
-        )
-
-    def forward(self, x):
-        return self.network(x)
-
-
-class SimpleTextModule(pl.LightningModule):
-    """Lightning module for simple text classifier."""
-
-    def __init__(self, model: nn.Module, learning_rate: float = 1e-3):
-        super().__init__()
-        self.model = model
-        self.learning_rate = learning_rate
-        self.loss_fn = nn.CrossEntropyLoss()
-
-    def forward(self, x):
-        return self.model(x)
-
-    def training_step(self, batch, batch_idx):
-        features, labels = batch
-        logits = self(features)
-        loss = self.loss_fn(logits, labels)
-        self.log('train_loss', loss)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        features, labels = batch
-        logits = self(features)
-        loss = self.loss_fn(logits, labels)
-        self.log('val_loss', loss)
-        return loss
-
-    def configure_optimizers(self):
-        return Adam(self.parameters(), lr=self.learning_rate)
-
-
-class SimpleTextWrapper(BaseClassifierWrapper):
-    """Wrapper for simple text classifier that uses TF-IDF instead of tokenization."""
-
-    def __init__(self, config: SimpleTextConfig):
-        super().__init__(config)
-        self.config: SimpleTextConfig = config
-        self.vectorizer: Optional[TfidfVectorizer] = None
-
-    def prepare_text_features(self, training_text: np.ndarray) -> None:
-        """Prepare TF-IDF vectorizer instead of tokenizer."""
-        self.vectorizer = TfidfVectorizer(
-            max_features=self.config.max_features,
-            lowercase=True,
-            stop_words='english'
-        )
-        # Fit the vectorizer on training text
-        self.vectorizer.fit(training_text)
-
-    def _build_pytorch_model(self) -> None:
-        """Build the PyTorch model."""
-        if self.vectorizer is None:
-            raise ValueError("Must call prepare_text_features first")
-
-        input_dim = len(self.vectorizer.get_feature_names_out())
-
-        self.pytorch_model = SimpleTextModel(
-            input_dim=input_dim,
-            hidden_dim=self.config.hidden_dim,
-            num_classes=self.config.num_classes,
-            dropout_rate=self.config.dropout_rate
-        )
-
-    def _check_and_init_lightning(self, **kwargs) -> None:
-        """Initialize Lightning module."""
-        self.lightning_module = SimpleTextModule(
-            model=self.pytorch_model,
-            learning_rate=self.config.learning_rate
-        )
-
-    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
-        """Make predictions."""
-        if not self.trained:
-            raise Exception("Model must be trained first.")
-
-        # Extract text from X (assuming first column is text)
-        text_data = X[:, 0] if X.ndim > 1 else X
-
-        # Transform text to TF-IDF features
-        features = self.vectorizer.transform(text_data).toarray()
-        features_tensor = torch.FloatTensor(features)
-
-        self.pytorch_model.eval()
-        with torch.no_grad():
-            logits = self.pytorch_model(features_tensor)
-            predictions = torch.argmax(logits, dim=1)
-
-        return predictions.numpy()
-
-    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        """Validate the model."""
-        predictions = self.predict(X)
-        accuracy = (predictions == Y).mean()
-        return float(accuracy)
-
-    def create_dataset(self, texts: np.ndarray, labels: np.ndarray, categorical_variables: Optional[np.ndarray] = None):
-        """Create dataset."""
-        # Transform text to TF-IDF features
-        features = self.vectorizer.transform(texts).toarray()
-        return SimpleTextDataset(features, labels)
-
-    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
-        """Create dataloader."""
-        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
-
-    def load_best_model(self, checkpoint_path: str) -> None:
-        """Load best model from checkpoint."""
-        self.lightning_module = SimpleTextModule.load_from_checkpoint(
-            checkpoint_path,
-            model=self.pytorch_model,
-            learning_rate=self.config.learning_rate
-        )
-        self.pytorch_model = self.lightning_module.model
-        self.trained = True
-        self.pytorch_model.eval()
-
-    @classmethod
-    def get_config_class(cls):
-        """Return the configuration class."""
-        return SimpleTextConfig
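
For context, here is a hedged sketch of how these removed pieces composed in 0.0.1. The toy texts and labels are invented for illustration, and the import only resolves against a 0.0.1 install; it mirrors what SimpleTextWrapper assembled internally (TF-IDF features, dataset, dataloader, Lightning module).

import numpy as np
import pytorch_lightning as pl
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.utils.data import DataLoader

# Only valid against the 0.0.1 wheel; these classes no longer exist in 1.0.0.
from torchTextClassifiers.classifiers.simple_text_classifier import (
    SimpleTextDataset,
    SimpleTextModel,
    SimpleTextModule,
)

texts = np.array(["good product", "arrived broken", "great value", "awful quality"])
labels = np.array([1, 0, 1, 0])

# TF-IDF stands in for a tokenizer, as in SimpleTextWrapper.prepare_text_features.
vectorizer = TfidfVectorizer(max_features=10000, lowercase=True, stop_words="english")
features = vectorizer.fit_transform(texts).toarray()

dataset = SimpleTextDataset(features, labels)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

model = SimpleTextModel(input_dim=features.shape[1], hidden_dim=128, num_classes=2)
module = SimpleTextModule(model, learning_rate=1e-3)

# One epoch as a smoke test of the training loop.
pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False).fit(module, loader)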
torchTextClassifiers/factories.py (deleted)
@@ -1,34 +0,0 @@
-"""Generic factories for different classifier types."""
-
-from typing import Dict, Any, Optional, Type, Callable
-from .classifiers.base import BaseClassifierConfig
-
-# Registry of config factories for different classifier types
-CONFIG_FACTORIES: Dict[str, Callable[[dict], BaseClassifierConfig]] = {}
-
-
-def register_config_factory(classifier_type: str, factory_func: Callable[[dict], BaseClassifierConfig]):
-    """Register a config factory for a classifier type."""
-    CONFIG_FACTORIES[classifier_type] = factory_func
-
-
-def create_config_from_dict(classifier_type: str, config_dict: dict) -> BaseClassifierConfig:
-    """Create a config object from dictionary based on classifier type."""
-    if classifier_type not in CONFIG_FACTORIES:
-        raise ValueError(f"Unsupported classifier type: {classifier_type}")
-
-    return CONFIG_FACTORIES[classifier_type](config_dict)
-
-
-# Register FastText factory
-def _register_fasttext_factory():
-    """Register FastText config factory."""
-    try:
-        from .classifiers.fasttext.core import FastTextFactory
-        register_config_factory("fasttext", FastTextFactory.from_dict)
-    except ImportError:
-        pass  # FastText module not available
-
-
-# Auto-register available factories
-_register_fasttext_factory()
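
As a hedged illustration of the registry this file provided (0.0.1 only), the SimpleTextConfig from the previous file can be registered and then rebuilt from a plain dict; the "simple_text" key below is an invented example name, not one the package registered itself.

# 0.0.1-era imports; both modules are removed in 1.0.0.
from torchTextClassifiers.factories import create_config_from_dict, register_config_factory
from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextConfig

# Register a factory under an example key of our choosing.
register_config_factory("simple_text", SimpleTextConfig.from_dict)

# Rebuild a typed config from a plain dictionary (e.g. loaded from JSON).
config = create_config_from_dict(
    "simple_text",
    {"hidden_dim": 64, "num_classes": 3, "max_features": 5000},
)
print(config)

# Unregistered keys fail fast with ValueError("Unsupported classifier type: ...").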
torchTextClassifiers/utilities/checkers.py (deleted)
@@ -1,108 +0,0 @@
-import logging
-import json
-from typing import Optional, Union, Type, List
-
-import numpy as np
-
-logger = logging.getLogger(__name__)
-
-
-def check_X(X):
-    assert isinstance(X, np.ndarray), (
-        "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables."
-    )
-
-    try:
-        if X.ndim > 1:
-            text = X[:, 0].astype(str)
-        else:
-            text = X[:].astype(str)
-    except ValueError:
-        logger.error("The first column of X must be castable in string format.")
-
-    if len(X.shape) == 1 or (len(X.shape) == 2 and X.shape[1] == 1):
-        no_cat_var = True
-    else:
-        no_cat_var = False
-
-    if not no_cat_var:
-        try:
-            categorical_variables = X[:, 1:].astype(int)
-        except ValueError:
-            logger.error(
-                f"Columns {1} to {X.shape[1] - 1} of X_train must be castable in integer format."
-            )
-    else:
-        categorical_variables = None
-
-    return text, categorical_variables, no_cat_var
-
-
-def check_Y(Y):
-    assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
-    assert len(Y.shape) == 1 or (len(Y.shape) == 2 and Y.shape[1] == 1), (
-        "Y must be a numpy array of shape (N,) or (N,1)."
-    )
-
-    try:
-        Y = Y.astype(int)
-    except ValueError:
-        logger.error("Y must be castable in integer format.")
-
-    return Y
-
-
-def validate_categorical_inputs(
-    categorical_vocabulary_sizes: List[int],
-    categorical_embedding_dims: Union[List[int], int],
-    num_categorical_features: int = None,
-):
-    if categorical_vocabulary_sizes is None:
-        logger.warning("No categorical_vocabulary_sizes. It will be inferred later.")
-        return None, None, None
-
-    else:
-        if not isinstance(categorical_vocabulary_sizes, list):
-            raise TypeError("categorical_vocabulary_sizes must be a list of int")
-
-    if isinstance(categorical_embedding_dims, list):
-        if len(categorical_vocabulary_sizes) != len(categorical_embedding_dims):
-            raise ValueError(
-                "Categorical vocabulary sizes and their embedding dimensions must have the same length"
-            )
-
-    if num_categorical_features is not None:
-        if len(categorical_vocabulary_sizes) != num_categorical_features:
-            raise ValueError(
-                "len(categorical_vocabulary_sizes) must be equal to num_categorical_features"
-            )
-    else:
-        num_categorical_features = len(categorical_vocabulary_sizes)
-
-    assert num_categorical_features is not None, (
-        "num_categorical_features should be inferred at this point."
-    )
-
-    # "Transform" embedding dims into a suitable list, or stay None
-    if categorical_embedding_dims is not None:
-        if isinstance(categorical_embedding_dims, int):
-            categorical_embedding_dims = [categorical_embedding_dims] * num_categorical_features
-        elif not isinstance(categorical_embedding_dims, list):
-            raise TypeError("categorical_embedding_dims must be an int or a list of int")
-
-    assert isinstance(categorical_embedding_dims, list) or categorical_embedding_dims is None, (
-        "categorical_embedding_dims must be a list of int at this point"
-    )
-
-    return categorical_vocabulary_sizes, categorical_embedding_dims, num_categorical_features
-
-
-class NumpyJSONEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, np.integer):
-            return int(obj)
-        if isinstance(obj, np.floating):
-            return float(obj)
-        if isinstance(obj, np.ndarray):
-            return obj.tolist()
-        return super().default(obj)
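
A short, hedged example of what the removed checkers did (0.0.1 install assumed; the small arrays below are made up for illustration):

import json

import numpy as np

# 0.0.1-era import; the module is removed in 1.0.0.
from torchTextClassifiers.utilities.checkers import NumpyJSONEncoder, check_X, check_Y

# First column holds the text, remaining columns hold integer-coded categorical variables.
X = np.array([["red cotton shirt", 0, 2], ["blue denim jeans", 1, 0]], dtype=object)
Y = np.array([1, 0])

text, categorical_variables, no_cat_var = check_X(X)
labels = check_Y(Y)

print(text.tolist())                   # ['red cotton shirt', 'blue denim jeans']
print(categorical_variables.tolist())  # [[0, 2], [1, 0]]
print(no_cat_var)                      # False

# NumpyJSONEncoder makes numpy scalars and arrays JSON-serialisable.
print(json.dumps({"labels": labels}, cls=NumpyJSONEncoder))  # {"labels": [1, 0]}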
torchTextClassifiers/utilities/preprocess.py (deleted)
@@ -1,82 +0,0 @@
-"""
-Processing fns.
-"""
-
-import string
-
-import numpy as np
-
-try:
-    import nltk
-    from nltk.corpus import stopwords as ntlk_stopwords
-    from nltk.stem.snowball import SnowballStemmer
-
-    HAS_NLTK = True
-except ImportError:
-    HAS_NLTK = False
-
-try:
-    import unidecode
-
-    HAS_UNIDECODE = True
-except ImportError:
-    HAS_UNIDECODE = False
-
-
-def clean_text_feature(text: list[str], remove_stop_words=True):
-    """
-    Cleans a text feature.
-
-    Args:
-        text (list[str]): List of text descriptions.
-        remove_stop_words (bool): If True, remove stopwords.
-
-    Returns:
-        list[str]: List of cleaned text descriptions.
-
-    """
-    if not HAS_NLTK:
-        raise ImportError(
-            "nltk is not installed and is required for preprocessing. Run 'pip install torchFastText[preprocess]'."
-        )
-    if not HAS_UNIDECODE:
-        raise ImportError(
-            "unidecode is not installed and is required for preprocessing. Run 'pip install torchFastText[preprocess]'."
-        )
-
-    # Define stopwords and stemmer
-    nltk.download("stopwords", quiet=True)
-    stopwords = tuple(ntlk_stopwords.words("french")) + tuple(string.ascii_lowercase)
-    stemmer = SnowballStemmer(language="french")
-
-    # Remove accented characters
-    text = np.vectorize(unidecode.unidecode)(np.array(text))
-
-    # To lowercase
-    text = np.char.lower(text)
-
-    # Remove one-letter words
-    def mylambda(x):
-        return " ".join([w for w in x.split() if len(w) > 1])
-
-    text = np.vectorize(mylambda)(text)
-
-    # Remove duplicate words and stopwords in texts
-    # Stem words
-    libs_token = [lib.split() for lib in text.tolist()]
-    libs_token = [
-        sorted(set(libs_token[i]), key=libs_token[i].index) for i in range(len(libs_token))
-    ]
-    if remove_stop_words:
-        text = [
-            " ".join([stemmer.stem(word) for word in libs_token[i] if word not in stopwords])
-            for i in range(len(libs_token))
-        ]
-    else:
-        text = [
-            " ".join([stemmer.stem(word) for word in libs_token[i]]) for i in range(len(libs_token))
-        ]
-
-    # Return cleaned text
-    return text
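
Finally, a hedged usage sketch for the removed clean_text_feature (0.0.1, with the nltk and unidecode extras installed); the French sample strings are invented and the exact stemmed output is only indicative:

# 0.0.1-era import; the module is removed in 1.0.0.
from torchTextClassifiers.utilities.preprocess import clean_text_feature

raw = [
    "Vente de légumes et de fruits frais",
    "Réparation de véhicules automobiles",
]

# Accents stripped, text lowercased, one-letter words / duplicates / French stopwords
# removed, remaining words Snowball-stemmed. Returns a list of cleaned strings.
cleaned = clean_text_feature(raw, remove_stop_words=True)
print(cleaned)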