torchtextclassifiers 0.0.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +152 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +61 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +170 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +500 -413
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-1.0.0.dist-info/METADATA +87 -0
- torchtextclassifiers-1.0.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-1.0.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
@@ -1,752 +0,0 @@
-"""FastText model components.
-
-This module contains the PyTorch model, Lightning module, and dataset classes
-for FastText classification. Consolidates what was previously in pytorch_model.py,
-lightning_module.py, and dataset.py.
-"""
-
-import os
-import logging
-from typing import List, Union
-import torch
-import pytorch_lightning as pl
-from torch import nn
-from torchmetrics import Accuracy
-
-try:
-    from captum.attr import LayerIntegratedGradients
-    HAS_CAPTUM = True
-except ImportError:
-    HAS_CAPTUM = False
-
-from ...utilities.utils import (
-    compute_preprocessed_word_score,
-    compute_word_score,
-    explain_continuous,
-)
-from ...utilities.checkers import validate_categorical_inputs
-
-logger = logging.getLogger(__name__)
-
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(name)s - %(message)s",
-    datefmt="%Y-%m-%d %H:%M:%S",
-    handlers=[logging.StreamHandler()],
-)
-
-
-# ============================================================================
-# PyTorch Model
-# ============================================================================
-
-class FastTextModel(nn.Module):
-    """FastText PyTorch Model."""
-
-    def __init__(
-        self,
-        embedding_dim: int,
-        num_classes: int,
-        tokenizer=None,
-        num_rows: int = None,
-        categorical_vocabulary_sizes: List[int] = None,
-        categorical_embedding_dims: Union[List[int], int] = None,
-        padding_idx: int = 0,
-        sparse: bool = True,
-        direct_bagging: bool = False,
-    ):
-        """
-        Constructor for the FastTextModel class.
-
-        Args:
-            embedding_dim (int): Dimension of the text embedding space.
-            num_rows (int): Number of rows in the embedding matrix.
-            num_classes (int): Number of classes.
-            categorical_vocabulary_sizes (List[int]): List of the number of
-                modalities for additional categorical features.
-            padding_idx (int, optional): Padding index for the text
-                descriptions. Defaults to 0.
-            sparse (bool): Indicates if the Embedding layer is sparse.
-            direct_bagging (bool): Use EmbeddingBag instead of Embedding for the text embedding.
-        """
-        super(FastTextModel, self).__init__()
-
-        if isinstance(categorical_embedding_dims, int):
-            # if categorical_embedding_dims is an int, average the categorical
-            # embeddings before concatenating them to the sentence embedding
-            self.average_cat_embed = True
-        else:
-            self.average_cat_embed = False
-
-        categorical_vocabulary_sizes, categorical_embedding_dims, num_categorical_features = (
-            validate_categorical_inputs(
-                categorical_vocabulary_sizes,
-                categorical_embedding_dims,
-                num_categorical_features=None,
-            )
-        )
-
-        assert isinstance(categorical_embedding_dims, list) or categorical_embedding_dims is None, (
-            "categorical_embedding_dims must be a list of int at this stage"
-        )
-
-        if categorical_embedding_dims is None:
-            self.average_cat_embed = False
-
-        if tokenizer is None:
-            if num_rows is None:
-                raise ValueError(
-                    "Either tokenizer or num_rows must be provided (number of rows in the embedding matrix)."
-                )
-        else:
-            if num_rows is not None:
-                if num_rows != tokenizer.num_tokens:
-                    logger.warning(
-                        "num_rows is different from the number of tokens in the tokenizer. Using provided num_rows."
-                    )
-            else:
-                num_rows = tokenizer.num_tokens  # fall back to the tokenizer's vocabulary size
-
-        self.num_rows = num_rows
-
-        self.num_classes = num_classes
-        self.padding_idx = padding_idx
-        self.tokenizer = tokenizer
-        self.embedding_dim = embedding_dim
-        self.direct_bagging = direct_bagging
-        self.sparse = sparse
-
-        self.categorical_embedding_dims = categorical_embedding_dims
-
-        self.embeddings = (
-            nn.Embedding(
-                embedding_dim=embedding_dim,
-                num_embeddings=num_rows,
-                padding_idx=padding_idx,
-                sparse=sparse,
-            )
-            if not direct_bagging
-            else nn.EmbeddingBag(
-                embedding_dim=embedding_dim,
-                num_embeddings=num_rows,
-                padding_idx=padding_idx,
-                sparse=sparse,
-                mode="mean",
-            )
-        )
-
-        self.categorical_embedding_layers = {}
-
-        # Entry dim for the last layer:
-        # 1. embedding_dim if there are no categorical variables or the categorical embeddings are summed into the sentence embedding
-        # 2. embedding_dim + cat_embedding_dim if the categorical embeddings are averaged before concatenation (categorical_embedding_dims is an int)
-        # 3. embedding_dim + sum(categorical_embedding_dims) if each categorical embedding is concatenated individually (no averaging, categorical_embedding_dims is a list)
-        dim_in_last_layer = embedding_dim
-        if self.average_cat_embed:
-            dim_in_last_layer += categorical_embedding_dims[0]
-
-        if categorical_vocabulary_sizes is not None:
-            self.no_cat_var = False
-            for var_idx, num_rows in enumerate(categorical_vocabulary_sizes):
-                if categorical_embedding_dims is not None:
-                    emb = nn.Embedding(
-                        embedding_dim=categorical_embedding_dims[var_idx], num_embeddings=num_rows
-                    )  # concatenate to sentence embedding
-                    if not self.average_cat_embed:
-                        dim_in_last_layer += categorical_embedding_dims[var_idx]
-                else:
-                    emb = nn.Embedding(
-                        embedding_dim=embedding_dim, num_embeddings=num_rows
-                    )  # sum to sentence embedding
-                self.categorical_embedding_layers[var_idx] = emb
-                setattr(self, "emb_{}".format(var_idx), emb)
-        else:
-            self.no_cat_var = True
-
-        self.fc = nn.Linear(dim_in_last_layer, num_classes)
-
-    def forward(self, encoded_text: torch.Tensor, additional_inputs: torch.Tensor) -> torch.Tensor:
-        """
-        Memory-efficient forward pass implementation.
-
-        Args:
-            encoded_text (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text.
-            additional_inputs (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features).
-
-        Returns:
-            torch.Tensor: Model output scores for each class.
-        """
-        batch_size = encoded_text.size(0)
-
-        # Ensure correct dtype and device once
-        if encoded_text.dtype != torch.long:
-            encoded_text = encoded_text.to(torch.long)
-
-        # Compute text embeddings
-        if self.direct_bagging:
-            x_text = self.embeddings(encoded_text)  # (batch_size, embedding_dim)
-        else:
-            # Compute embeddings and average them in a memory-efficient way
-            x_text = self.embeddings(encoded_text)  # (batch_size, seq_len, embedding_dim)
-            # Calculate the non-zero token mask once
-            non_zero_mask = (x_text.sum(-1) != 0).float()  # (batch_size, seq_len)
-            token_counts = non_zero_mask.sum(-1, keepdim=True)  # (batch_size, 1)
-
-            # Sum and average in place
-            x_text = (x_text * non_zero_mask.unsqueeze(-1)).sum(
-                dim=1
-            )  # (batch_size, embedding_dim)
-            x_text = torch.div(x_text, token_counts.clamp(min=1.0))
-            x_text = torch.nan_to_num(x_text, 0.0)
-
-        # Handle categorical variables efficiently
-        if not self.no_cat_var and additional_inputs.numel() > 0:
-            cat_embeds = []
-            # Process categorical embeddings in batch
-            for i, (_, embed_layer) in enumerate(self.categorical_embedding_layers.items()):
-                cat_input = additional_inputs[:, i].long()
-
-                # Check if categorical values are within the valid range and clamp if needed
-                vocab_size = embed_layer.num_embeddings
-                max_val = cat_input.max().item()
-                min_val = cat_input.min().item()
-
-                if max_val >= vocab_size or min_val < 0:
-                    logger.warning(
-                        f"Categorical feature {i}: values range [{min_val}, {max_val}] exceeds vocabulary size {vocab_size}. Clamping to valid range [0, {vocab_size - 1}]."
-                    )
-                    # Clamp values to the valid range
-                    cat_input = torch.clamp(cat_input, 0, vocab_size - 1)
-
-                cat_embed = embed_layer(cat_input)
-                if cat_embed.dim() > 2:
-                    cat_embed = cat_embed.squeeze(1)
-                cat_embeds.append(cat_embed)
-
-            if cat_embeds:  # If we have categorical embeddings
-                if self.categorical_embedding_dims is not None:
-                    if self.average_cat_embed:
-                        # Stack and average in one operation
-                        x_cat = torch.stack(cat_embeds, dim=0).mean(dim=0)
-                        x_combined = torch.cat([x_text, x_cat], dim=1)
-                    else:
-                        # Optimize concatenation
-                        x_combined = torch.cat([x_text] + cat_embeds, dim=1)
-                else:
-                    # Sum embeddings efficiently
-                    x_combined = x_text + torch.stack(cat_embeds, dim=0).sum(dim=0)
-            else:
-                x_combined = x_text
-        else:
-            x_combined = x_text
-
-        # Final linear layer
-        return self.fc(x_combined)
-
-    def predict(
-        self,
-        text: List[str],
-        categorical_variables: List[List[int]],
-        top_k=1,
-        explain=False,
-        preprocess=True,
-    ):
-        """
-        Args:
-            text (List[str]): A list of text observations.
-            categorical_variables (List[List[int]]): Values of the categorical features for each observation.
-            top_k (int): For each sentence, return the top_k most likely predictions (default: 1).
-            explain (bool): Launch gradient integration to get an explanation of the prediction (default: False).
-            preprocess (bool): If True, preprocess text. Needs the unidecode library.
-
-        Returns:
-            if explain is False:
-                predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes for the query.
-                confidence (torch.Tensor, shape (len(text), top_k)): A tensor containing the corresponding confidence scores.
-            if explain is True:
-                predictions (torch.Tensor, shape (len(text), top_k)): Containing the top_k most likely codes for the query.
-                confidence (torch.Tensor, shape (len(text), top_k)): Corresponding confidence scores.
-                all_attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
-                x (torch.Tensor): A tensor containing the token indices of the text.
-                id_to_token_dicts (List[Dict[int, str]]): A list of dictionaries mapping token indices to tokens (one for each sentence).
-                token_to_id_dicts (List[Dict[str, int]]): A list of dictionaries mapping tokens to token indices: the reverse of those in id_to_token_dicts.
-                text (list[str]): A list containing the preprocessed text (one line for each sentence).
-        """
-
-        flag_change_embed = False
-        if explain:
-            if not HAS_CAPTUM:
-                raise ImportError(
-                    "Captum is not installed and is required for explainability. Run 'pip install torchFastText[explainability]'."
-                )
-            if self.direct_bagging:
-                # Get back the classical embedding layer for explainability
-                new_embed_layer = nn.Embedding(
-                    embedding_dim=self.embedding_dim,
-                    num_embeddings=self.num_rows,
-                    padding_idx=self.padding_idx,
-                    sparse=self.sparse,
-                )
-                new_embed_layer.load_state_dict(
-                    self.embeddings.state_dict()
-                )  # No issues, as the parameters are exactly the same
-                self.embeddings = new_embed_layer
-                self.direct_bagging = (
-                    False  # To inform the forward pass that we are not using EmbeddingBag anymore
-                )
-                flag_change_embed = True
-
-            lig = LayerIntegratedGradients(
-                self, self.embeddings
-            )  # initialize a Captum layer gradient integrator
-
-        self.eval()
-        batch_size = len(text)
-
-        indices_batch, id_to_token_dicts, token_to_id_dicts = self.tokenizer.tokenize(
-            text, text_tokens=False, preprocess=preprocess
-        )
-
-        padding_index = (
-            self.tokenizer.get_buckets() + self.tokenizer.get_nwords()
-        )  # padding index, the integer value of the padding token
-
-        padded_batch = torch.nn.utils.rnn.pad_sequence(
-            indices_batch,
-            batch_first=True,
-            padding_value=padding_index,
-        )  # (batch_size, seq_len) - Tokenized (int) + padded text
-
-        x = padded_batch
-
-        if not self.no_cat_var:
-            other_features = []
-            # Transpose categorical_variables to iterate over features instead of samples
-            categorical_variables_transposed = categorical_variables.T
-            for i, categorical_variable in enumerate(categorical_variables_transposed):
-                other_features.append(
-                    torch.tensor(categorical_variable).reshape(batch_size, -1).to(torch.int64)
-                )
-
-            other_features = torch.stack(other_features).reshape(batch_size, -1).long()
-        else:
-            other_features = torch.empty(batch_size)
-
-        pred = self(
-            x, other_features
-        )  # forward pass, contains the prediction scores (len(text), num_classes)
-        label_scores = pred.detach().cpu()
-        label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
-
-        predictions = label_scores_topk.indices  # get the top_k most likely predictions
-        confidence = torch.round(label_scores_topk.values, decimals=2)  # and their scores
-
-        if explain:
-            assert not self.direct_bagging, "Direct bagging should be False for explainability"
-            all_attributions = []
-            for k in range(top_k):
-                attributions = lig.attribute(
-                    (x, other_features), target=torch.Tensor(predictions[:, k]).long()
-                )  # (batch_size, seq_len)
-                attributions = attributions.sum(dim=-1)
-                all_attributions.append(attributions.detach().cpu())
-
-            all_attributions = torch.stack(all_attributions, dim=1)  # (batch_size, top_k, seq_len)
-
-            # Get back to the initial embedding layer:
-            # EmbeddingBag -> Embedding -> EmbeddingBag
-            # or keep Embedding with no change
-            if flag_change_embed:
-                new_embed_layer = nn.EmbeddingBag(
-                    embedding_dim=self.embedding_dim,
-                    num_embeddings=self.num_rows,
-                    padding_idx=self.padding_idx,
-                    sparse=self.sparse,
-                )
-                new_embed_layer.load_state_dict(
-                    self.embeddings.state_dict()
-                )  # No issues, as the parameters are exactly the same
-                self.embeddings = new_embed_layer
-                self.direct_bagging = True
-            return (
-                predictions,
-                confidence,
-                all_attributions,
-                x,
-                id_to_token_dicts,
-                token_to_id_dicts,
-                text,
-            )
-        else:
-            return predictions, confidence
-
-    def predict_and_explain(self, text, categorical_variables, top_k=1, n=5, cutoff=0.65):
-        """
-        Args:
-            text (List[str]): A list of sentences.
-            categorical_variables (List[List[int]]): Values of the categorical features for each observation.
-            top_k (int): For each sentence, return the top_k most likely predictions (default: 1).
-            n (int): Mapping processed to original words: max number of candidate processed words to consider per original word (default: 5).
-            cutoff (float): Mapping processed to original words: minimum similarity score to consider a candidate processed word (default: 0.65).
-
-        Returns:
-            predictions (torch.Tensor, shape (len(text), top_k)): Containing the top_k most likely codes for the query.
-            confidence (torch.Tensor, shape (len(text), top_k)): Corresponding confidence scores.
-            all_scores (List[List[List[float]]]): For each sentence, list of the top_k lists of attributions for each word in the sentence (one for each prediction).
-            all_scores_letters (List): For each sentence, the attributions at the letter level.
-        """
-
-        # Step 1: Get the predictions, confidence scores and attributions at token level
-        (
-            pred,
-            confidence,
-            all_attr,
-            tokenized_text,
-            id_to_token_dicts,
-            token_to_id_dicts,
-            processed_text,
-        ) = self.predict(
-            text=text, categorical_variables=categorical_variables, top_k=top_k, explain=True
-        )
-
-        tokenized_text_tokens = self.tokenizer._tokenized_text_in_tokens(
-            tokenized_text, id_to_token_dicts
-        )
-
-        # Step 2: Map the attributions at token level to the processed words
-        processed_word_to_score_dicts, processed_word_to_token_idx_dicts = (
-            compute_preprocessed_word_score(
-                processed_text,
-                tokenized_text_tokens,
-                all_attr,
-                id_to_token_dicts,
-                token_to_id_dicts,
-                min_n=self.tokenizer.min_n,
-                padding_index=self.padding_idx,
-                end_of_string_index=0,
-            )
-        )
-
-        # Step 3: Map the processed words to the original words
-        all_scores, orig_to_processed_mappings = compute_word_score(
-            processed_word_to_score_dicts, text, n=n, cutoff=cutoff
-        )
-
-        # Step 2bis: Get the attributions at letter level
-        all_scores_letters = explain_continuous(
-            text,
-            processed_text,
-            tokenized_text_tokens,
-            orig_to_processed_mappings,
-            processed_word_to_token_idx_dicts,
-            all_attr,
-            top_k,
-        )
-
-        return pred, confidence, all_scores, all_scores_letters
-
-
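The constructor above derives the classifier head's input width from how the categorical embeddings are combined: embedding_dim alone when they are summed, plus the shared dimension when averaged, plus the sum of all dimensions when concatenated. Below is a minimal sketch of how this removed 0.0.1 class could be exercised with dummy tensors; the import path is inferred from the file list above, and all dimensions are purely illustrative:

    import torch
    from torchTextClassifiers.classifiers.fasttext.model import FastTextModel

    model = FastTextModel(
        embedding_dim=64,
        num_classes=5,
        num_rows=2000,                         # rows of the hashed n-gram embedding matrix
        categorical_vocabulary_sizes=[10, 4],  # two additional categorical features
        categorical_embedding_dims=[8, 8],     # a list: each embedding is concatenated
        sparse=False,
    )

    encoded_text = torch.randint(0, 2000, (3, 12))  # (batch_size, seq_len) token indices
    categorical = torch.randint(0, 4, (3, 2))       # (batch_size, num_categorical_features)
    scores = model(encoded_text, categorical)       # (3, 5) raw class scores

With categorical_embedding_dims given as a list, this instance falls under case 3 in the constructor comments: the final linear layer sees 64 + 8 + 8 = 80 input features.
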
-# ============================================================================
-# PyTorch Lightning Module
-# ============================================================================
-
-class FastTextModule(pl.LightningModule):
-    """PyTorch Lightning Module for FastTextModel."""
-
-    def __init__(
-        self,
-        model: FastTextModel,
-        loss,
-        optimizer,
-        optimizer_params,
-        scheduler,
-        scheduler_params,
-        scheduler_interval="epoch",
-        **kwargs,
-    ):
-        """
-        Initialize FastTextModule.
-
-        Args:
-            model: Model.
-            loss: Loss.
-            optimizer: Optimizer.
-            optimizer_params: Optimizer parameters.
-            scheduler: Scheduler.
-            scheduler_params: Scheduler parameters.
-            scheduler_interval: Scheduler interval.
-        """
-        super().__init__()
-        self.save_hyperparameters(ignore=["model", "loss"])
-
-        self.model = model
-        self.loss = loss
-        self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes)
-        self.optimizer = optimizer
-        self.optimizer_params = optimizer_params
-        self.scheduler = scheduler
-        self.scheduler_params = scheduler_params
-        self.scheduler_interval = scheduler_interval
-
-    def forward(self, inputs) -> torch.Tensor:
-        """
-        Perform forward-pass.
-
-        Args:
-            inputs (List[torch.LongTensor]): Batch to perform forward-pass on.
-
-        Returns (torch.Tensor): Prediction.
-        """
-        return self.model(inputs[0], inputs[1])
-
-    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
-        """
-        Training step.
-
-        Args:
-            batch (List[torch.LongTensor]): Training batch.
-            batch_idx (int): Batch index.
-
-        Returns (torch.Tensor): Loss tensor.
-        """
-
-        inputs, targets = batch[:-1], batch[-1]
-        outputs = self.forward(inputs)
-        loss = self.loss(outputs, targets)
-        self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
-        accuracy = self.accuracy_fn(outputs, targets)
-        self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
-
-        torch.cuda.empty_cache()
-
-        return loss
-
-    def validation_step(self, batch, batch_idx: int):
-        """
-        Validation step.
-
-        Args:
-            batch (List[torch.LongTensor]): Validation batch.
-            batch_idx (int): Batch index.
-
-        Returns (torch.Tensor): Loss tensor.
-        """
-        inputs, targets = batch[:-1], batch[-1]
-        outputs = self.forward(inputs)
-        loss = self.loss(outputs, targets)
-        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
-
-        accuracy = self.accuracy_fn(outputs, targets)
-        self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
-        return loss
-
-    def test_step(self, batch, batch_idx: int):
-        """
-        Test step.
-
-        Args:
-            batch (List[torch.LongTensor]): Test batch.
-            batch_idx (int): Batch index.
-
-        Returns (Tuple[torch.Tensor, torch.Tensor]): Loss and accuracy.
-        """
-        inputs, targets = batch[:-1], batch[-1]
-        outputs = self.forward(inputs)
-        loss = self.loss(outputs, targets)
-
-        accuracy = self.accuracy_fn(outputs, targets)
-
-        return loss, accuracy
-
-    def configure_optimizers(self):
-        """
-        Configure the optimizer for PyTorch Lightning.
-
-        Returns: Optimizer and scheduler for PyTorch Lightning.
-        """
-        optimizer = self.optimizer(self.parameters(), **self.optimizer_params)
-
-        # Only use a monitored scheduler config for ReduceLROnPlateau, which needs a metric.
-        # For complex training setups, val_loss is sometimes not available every epoch.
-        if hasattr(self.scheduler, '__name__') and 'ReduceLROnPlateau' in self.scheduler.__name__:
-            # For ReduceLROnPlateau, use train_loss as it's always available
-            scheduler = self.scheduler(optimizer, **self.scheduler_params)
-            scheduler_config = {
-                "scheduler": scheduler,
-                "monitor": "train_loss",
-                "interval": self.scheduler_interval,
-            }
-            return [optimizer], [scheduler_config]
-        else:
-            # For other schedulers (StepLR, etc.), no monitoring needed
-            scheduler = self.scheduler(optimizer, **self.scheduler_params)
-            return [optimizer], [scheduler]
-
-
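Note that configure_optimizers() instantiates both the optimizer and the scheduler itself, so FastTextModule expected classes and parameter dicts rather than ready-made instances. A hedged wiring sketch under that assumption (Adam, CrossEntropyLoss and StepLR are illustrative choices; sparse=False because Adam does not support sparse gradients):

    import torch
    from torch import nn
    import pytorch_lightning as pl
    from torchTextClassifiers.classifiers.fasttext.model import FastTextModel, FastTextModule

    model = FastTextModel(embedding_dim=64, num_classes=5, num_rows=2000, sparse=False)
    module = FastTextModule(
        model=model,
        loss=nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam,                 # passed as a class, not an instance
        optimizer_params={"lr": 1e-3},
        scheduler=torch.optim.lr_scheduler.StepLR,  # not ReduceLROnPlateau: no monitor needed
        scheduler_params={"step_size": 1},
    )
    # trainer = pl.Trainer(max_epochs=3)
    # trainer.fit(module, train_dataloader, val_dataloader)

Since each step unpacks batch[:-1] as inputs and batch[-1] as targets, the dataloaders are expected to yield (padded_text, categorical_tensors, labels) tuples, which is what FastTextModelDataset.collate_fn below produces.
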
-# ============================================================================
-# Dataset
-# ============================================================================
-
-class FastTextModelDataset(torch.utils.data.Dataset):
-    """FastTextModelDataset class."""
-
-    def __init__(
-        self,
-        categorical_variables: List[List[int]],
-        texts: List[str],
-        tokenizer,  # NGramTokenizer
-        outputs: List[int] = None,
-        **kwargs,
-    ):
-        """
-        Constructor for the FastTextModelDataset class.
-
-        Args:
-            categorical_variables (List[List[int]]): The elements of this list
-                are the values of each categorical variable across the dataset.
-            texts (List[str]): List of text descriptions.
-            outputs (List[int]): List of outcomes.
-            tokenizer (Tokenizer): Tokenizer.
-        """
-
-        if categorical_variables is not None and len(categorical_variables) != len(texts):
-            raise ValueError("Categorical variables and texts must have the same length.")
-
-        if outputs is not None and len(outputs) != len(texts):
-            raise ValueError("Outputs and texts must have the same length.")
-
-        self.categorical_variables = categorical_variables
-        self.texts = texts
-        self.outputs = outputs
-        self.tokenizer = tokenizer
-
-    def __len__(self) -> int:
-        """
-        Returns length of the data.
-
-        Returns:
-            int: Number of observations.
-        """
-        return len(self.texts)
-
-    def __str__(self) -> str:
-        """
-        Returns description of the Dataset.
-
-        Returns:
-            str: Description.
-        """
-        return f"<FastTextModelDataset(N={len(self)})>"
-
-    def __getitem__(self, index: int) -> List:
-        """
-        Returns observation for a given index.
-
-        Args:
-            index (int): Index.
-
-        Returns:
-            Tuple: Text, categorical variables and, if available, the output for the given index.
-        """
-        categorical_variables = (
-            self.categorical_variables[index] if self.categorical_variables is not None else None
-        )
-        text = self.texts[index]
-
-        if self.outputs is not None:
-            y = self.outputs[index]
-            return text, categorical_variables, y
-        else:
-            return text, categorical_variables
-
-    def collate_fn(self, batch):
-        """
-        Efficient batch processing without explicit loops.
-
-        Args:
-            batch: Data batch.
-
-        Returns:
-            Tuple[torch.LongTensor]: Padded text, categorical variables and, if available, the outputs.
-        """
-
-        # Unzip the batch in one go using zip(*batch)
-        if self.outputs is not None:
-            text, *categorical_vars, y = zip(*batch)
-        else:
-            text, *categorical_vars = zip(*batch)
-
-        # Convert text to indices in parallel using map
-        indices_batch = list(map(lambda x: self.tokenizer.indices_matrix(x)[0], text))
-
-        # Get the padding index once
-        padding_index = self.tokenizer.get_buckets() + self.tokenizer.get_nwords()
-
-        # Pad sequences efficiently
-        padded_batch = torch.nn.utils.rnn.pad_sequence(
-            indices_batch,
-            batch_first=True,
-            padding_value=padding_index,
-        )
-
-        # Handle categorical variables efficiently
-        if self.categorical_variables is not None:
-            categorical_tensors = torch.stack(
-                [
-                    torch.tensor(cat_var, dtype=torch.float32)
-                    for cat_var in categorical_vars[
-                        0
-                    ]  # Access the first element since zip returns a tuple
-                ]
-            )
-        else:
-            categorical_tensors = torch.empty(
-                padded_batch.shape[0], 1, dtype=torch.float32, device=padded_batch.device
-            )
-
-        if self.outputs is not None:
-            # Convert labels to a tensor in one go
-            y = torch.tensor(y, dtype=torch.long)
-            return (padded_batch, categorical_tensors, y)
-        else:
-            return (padded_batch, categorical_tensors)
-
-    def create_dataloader(
-        self,
-        batch_size: int,
-        shuffle: bool = False,
-        drop_last: bool = False,
-        num_workers: int = os.cpu_count() - 1,
-        pin_memory: bool = True,
-        persistent_workers: bool = True,
-        **kwargs,
-    ) -> torch.utils.data.DataLoader:
-        """
-        Creates a DataLoader from the FastTextModelDataset.
-        Uses collate_fn() to tokenize and pad the sequences.
-
-        Args:
-            batch_size (int): Batch size.
-            shuffle (bool, optional): Shuffle option. Defaults to False.
-            drop_last (bool, optional): Drop last option. Defaults to False.
-            num_workers (int, optional): Number of workers. Defaults to os.cpu_count() - 1.
-            pin_memory (bool, optional): Set True if working on GPU, False if CPU. Defaults to True.
-            persistent_workers (bool, optional): Set True for training, False for inference. Defaults to True.
-            **kwargs: Additional arguments for the PyTorch DataLoader.
-
-        Returns:
-            torch.utils.data.DataLoader: DataLoader.
-        """
-
-        logger.info(f"Creating DataLoader with {num_workers} workers.")
-
-        # persistent_workers requires num_workers > 0
-        if num_workers == 0:
-            persistent_workers = False
-
-        return torch.utils.data.DataLoader(
-            dataset=self,
-            batch_size=batch_size,
-            collate_fn=self.collate_fn,
-            shuffle=shuffle,
-            drop_last=drop_last,
-            pin_memory=pin_memory,
-            num_workers=num_workers,
-            persistent_workers=persistent_workers,
-            **kwargs,
-        )
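To round off the removed file, a hypothetical end-to-end batching sketch. The NGramTokenizer import follows the 0.0.1 file list above and its constructor arguments are deliberately elided (they live in the removed tokenizer.py); collate_fn only relies on the indices_matrix(), get_buckets() and get_nwords() methods used above:

    from torchTextClassifiers.classifiers.fasttext.model import FastTextModelDataset
    from torchTextClassifiers.classifiers.fasttext.tokenizer import NGramTokenizer

    texts = ["first text sample", "second text sample"]
    categorical_variables = [[0, 1], [2, 0]]  # one row of categorical values per observation
    labels = [0, 3]

    tokenizer = NGramTokenizer(...)  # constructor arguments elided, see tokenizer.py
    dataset = FastTextModelDataset(
        categorical_variables=categorical_variables,
        texts=texts,
        tokenizer=tokenizer,
        outputs=labels,
    )
    loader = dataset.create_dataloader(batch_size=2, num_workers=0)  # num_workers=0: no worker processes
    padded_text, categorical_tensors, y = next(iter(loader))

In 1.0.0, these responsibilities move to the new torchTextClassifiers/dataset/dataset.py and torchTextClassifiers/tokenizers/ modules listed at the top of this diff.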