torchtextclassifiers 0.0.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +152 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +61 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +170 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +500 -413
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-1.0.0.dist-info/METADATA +87 -0
- torchtextclassifiers-1.0.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-1.0.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
|
@@ -1,346 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
NGramTokenizer class.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import ctypes
|
|
6
|
-
import json
|
|
7
|
-
from typing import List, Tuple, Type, Dict
|
|
8
|
-
|
|
9
|
-
import numpy as np
|
|
10
|
-
import torch
|
|
11
|
-
from torch import Tensor
|
|
12
|
-
from concurrent.futures import ThreadPoolExecutor
|
|
13
|
-
from dataclasses import dataclass
|
|
14
|
-
from queue import Queue
|
|
15
|
-
import multiprocessing
|
|
16
|
-
|
|
17
|
-
from ...utilities.preprocess import clean_text_feature
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class NGramTokenizer:
    """
    FastText-style tokenizer mapping text to embedding-row indices.

    A sentence is turned into three kinds of indices:

    * whole-word ids for words seen at least ``min_count`` times in the
      training text (ids start at 1),
    * hashed character n-gram ids (``hash % num_tokens + nwords``),
    * hashed word n-gram ids (same offset and modulus).

    Index ``0`` is used for the end-of-sentence token ``"</s>"`` and
    ``padding_index`` is ``num_tokens + nwords``.

    NOTE(review): the last word id (``nwords``) coincides with the first
    hashed bucket (``0 + nwords``). This layout is kept unchanged because
    trained embedding matrices depend on it.
    """

    def __init__(
        self,
        min_count: int,
        min_n: int,
        max_n: int,
        num_tokens: int,
        len_word_ngrams: int,
        training_text: List[str],
        **kwargs,
    ):
        """
        Constructor for the NGramTokenizer class.

        Args:
            min_count (int): Minimum number of times a word has to be
                in the training data to be given an embedding.
            min_n (int): Minimum length of character n-grams.
            max_n (int): Maximum length of character n-grams.
            num_tokens (int): Number of hash buckets shared by character
                and word n-grams.
            len_word_ngrams (int): Maximum length of word n-grams
                (stored as ``self.word_ngrams``).
            training_text (List[str]): List of training texts.
            **kwargs: Accepted and ignored, so that extra keys (e.g. from
                a JSON config passed by ``from_json``) do not raise.

        Raises:
            ValueError: If `min_n` is 1 or smaller.
            ValueError: If `max_n` is 7 or higher.
        """
        if min_n < 2:
            raise ValueError("`min_n` parameter must be greater than 1.")
        if max_n > 6:
            raise ValueError("`max_n` parameter must be smaller than 7.")

        self.min_count = min_count
        self.min_n = min_n
        self.max_n = max_n
        self.num_tokens = num_tokens
        self.word_ngrams = len_word_ngrams

        # NOTE(review): the vocabulary is built with split(" ") while
        # indices_matrix() uses split(); runs of spaces therefore add an
        # empty-string "word" here. Kept as-is because changing it would
        # shift every hashed index of already-trained models.
        word_counts = {}
        for sentence in training_text:
            for word in sentence.split(" "):
                word_counts[word] = word_counts.get(word, 0) + 1

        # Ids are assigned in first-seen order, starting at 1 (0 is "</s>").
        kept_words = (word for word, count in word_counts.items() if count >= min_count)
        self.word_id_mapping = {word: idx for idx, word in enumerate(kept_words, start=1)}
        self.nwords = len(self.word_id_mapping)

        self.padding_index = self.num_tokens + self.get_nwords()

    def __str__(self) -> str:
        """
        Returns description of the NGramTokenizer.

        Returns:
            str: Description.
        """
        return f"<NGramTokenizer(min_n={self.min_n}, max_n={self.max_n}, num_tokens={self.num_tokens}, word_ngrams={self.word_ngrams}, nwords={self.nwords})>"

    def get_nwords(self) -> int:
        """
        Return number of words kept in training data.

        Returns:
            int: Number of words.
        """
        return self.nwords

    def get_buckets(self) -> int:
        """
        Return number of buckets for tokenizer.

        Returns:
            int: Number of buckets.
        """
        return self.num_tokens

    @staticmethod
    def get_ngram_list(word: str, n: int) -> List[str]:
        """
        Return the list of character n-grams for a word with a
        given n.

        Args:
            word (str): Word.
            n (int): Length of the n-grams.

        Returns:
            List[str]: List of character n-grams (empty when the word is
                shorter than ``n``).
        """
        return [word[i : i + n] for i in range(len(word) - n + 1)]

    @staticmethod
    def get_hash(subword: str) -> int:
        """
        Return hash for a given subword.

        FNV-1a-style 32-bit hash: each character's code point is truncated
        to a signed 8-bit value, XOR-ed into the state, then the state is
        multiplied by 16777619 modulo 2**32.

        Args:
            subword (str): Character n-gram.

        Returns:
            int: Corresponding hash (unsigned 32-bit).
        """
        h = ctypes.c_uint32(2166136261).value
        for c in subword:
            c = ctypes.c_int8(ord(c)).value
            h = ctypes.c_uint32(h ^ c).value
            h = ctypes.c_uint32(h * 16777619).value
        return h

    @staticmethod
    def get_word_ngram_id(hashes: Tuple[int, ...], bucket: int, nwords: int) -> int:
        """
        Get word ngram index in the embedding matrix.

        The word hashes are reinterpreted as signed 32-bit integers and
        combined in unsigned 64-bit arithmetic (multiply by 116049371,
        then add the next hash).

        Args:
            hashes (Tuple[int, ...]): Word hashes (at least one).
            bucket (int): Number of hash buckets.
            nwords (int): Number of words in the vocabulary.

        Returns:
            int: Word ngram index, i.e. ``hash % bucket + nwords``.
        """
        hashes = [ctypes.c_int32(hash_value).value for hash_value in hashes]
        h = ctypes.c_uint64(hashes[0]).value
        for j in range(1, len(hashes)):
            h = ctypes.c_uint64((h * 116049371)).value
            h = ctypes.c_uint64(h + hashes[j]).value
        return h % bucket + nwords

    def get_subword_index(self, subword: str) -> int:
        """
        Return the row index from the embedding matrix which
        corresponds to a character n-gram.

        Args:
            subword (str): Character n-gram.

        Returns:
            int: Index.
        """
        return self.get_hash(subword) % self.num_tokens + self.nwords

    def get_word_index(self, word: str) -> int:
        """
        Return the row index from the embedding matrix which
        corresponds to a word.

        Args:
            word (str): Word.

        Returns:
            int: Index.

        Raises:
            KeyError: If the word is not in the vocabulary.
        """
        return self.word_id_mapping[word]

    def get_subwords(self, word: str) -> Tuple[List[str], List[int]]:
        """
        Return all subwords tokens and indices for a given word.
        Also adds the whole word token and index if the word is in word_id_mapping
        (==> the word is in initial vocabulary + seen at least MIN_COUNT times).
        Adds tags "<" and ">" to the word.

        Args:
            word (str): Word.

        Returns:
            Tuple[List[str], List[int]]: Tuple of tokens and indices.
        """
        word_with_tags = "<" + word + ">"

        # Character n-grams of the tagged word, excluding the full word
        # (with or without tags): the whole word gets its own vocabulary id.
        tokens = []
        for n in range(self.min_n, self.max_n + 1):
            tokens += [
                ngram
                for ngram in self.get_ngram_list(word_with_tags, n)
                if ngram != word_with_tags and ngram != word
            ]
        indices = [self.get_subword_index(token) for token in tokens]

        # Prepend the word token and index only if the word is in the vocabulary.
        if word in self.word_id_mapping:
            tokens = [word] + tokens
            indices = [self.get_word_index(word)] + indices

        return (tokens, indices)

    def indices_matrix(self, sentence: str) -> tuple[torch.Tensor, dict, dict]:
        """
        Returns a tensor of token indices for a text description.

        Order of the indices: subword (and word) indices of every word,
        then ``0`` for ``"</s>"``, then word n-gram indices for
        n = 2..``word_ngrams``.

        Args:
            sentence (str): Text description.

        Returns:
            tuple: (torch.Tensor of indices, id_to_token dict, token_to_id dict)
        """
        # Pre-split the sentence once
        words = sentence.split()
        words.append("</s>")  # Add end of string token

        indices = []
        all_tokens_id = {}

        # "</s>" is excluded from subword processing
        for word in words[:-1]:
            tokens, ind = self.get_subwords(word)
            indices.extend(ind)
            all_tokens_id.update(zip(tokens, ind))

        # Add </s> token
        indices.append(0)
        all_tokens_id["</s>"] = 0

        if self.word_ngrams > 1:
            # Hash every word (including "</s>") once, then slide a window.
            word_hashes = [self.get_hash(word) for word in words]

            word_ngram_ids = []
            for n in range(2, self.word_ngrams + 1):
                for i in range(len(words) - n + 1):
                    gram_hashes = tuple(word_hashes[i : i + n])
                    word_ngram_id = int(
                        self.get_word_ngram_id(gram_hashes, self.num_tokens, self.nwords)
                    )

                    gram = " ".join(words[i : i + n])
                    all_tokens_id[gram] = word_ngram_id
                    word_ngram_ids.append(word_ngram_id)

            indices.extend(word_ngram_ids)

        # Reverse mapping; when several tokens share an id (hash collision)
        # only the last one inserted is kept.
        id_to_token = {v: k for k, v in all_tokens_id.items()}

        return torch.tensor(indices, dtype=torch.long), id_to_token, all_tokens_id

    def tokenize(self, text: list[str], text_tokens=True, preprocess=False):
        """
        Tokenize a list of sentences.

        Args:
            text (list[str]): List of sentences.
            text_tokens (bool): If True, also return the tokenized text as
                string tokens.
            preprocess (bool): If True, pass the text through
                ``clean_text_feature`` first (the original note says this
                needs the unidecode library).

        Returns:
            tuple: If ``text_tokens`` is True,
                ``(tokens, indices, id_to_token_dicts, token_to_id_dicts)``;
                otherwise ``(indices, id_to_token_dicts, token_to_id_dicts)``.
                ``indices`` is a list with one ``torch.long`` tensor per
                sentence; the dict lists have one entry per sentence.
        """

        if preprocess:
            text = clean_text_feature(text)

        tokenized_text = []
        id_to_token_dicts = []
        token_to_id_dicts = []
        for sentence in text:
            # tokenize and convert to token indices
            all_ind, id_to_token, token_to_id = self.indices_matrix(sentence)
            tokenized_text.append(all_ind)
            id_to_token_dicts.append(id_to_token)
            token_to_id_dicts.append(token_to_id)

        if text_tokens:
            tokenized_text_tokens = self._tokenized_text_in_tokens(
                tokenized_text, id_to_token_dicts
            )
            return tokenized_text_tokens, tokenized_text, id_to_token_dicts, token_to_id_dicts
        else:
            return tokenized_text, id_to_token_dicts, token_to_id_dicts

    def _tokenized_text_in_tokens(self, tokenized_text, id_to_token_dicts):
        """
        Convert tokenized text in int format to tokens in str format (given a mapping dictionary).
        Private method. Used in tokenizer.tokenize and pytorch_model.predict()

        Args:
            tokenized_text (list): List of tokenized text in int format
                (each element iterable of objects exposing ``.item()``).
            id_to_token_dicts (list[Dict]): List of dictionaries mapping token indices to tokens.

            Both lists have the same length (number of sentences).

        Returns:
            list[list[str]]: List of tokenized text in str format, with
                padding indices dropped.

        """

        return [
            [
                id_to_token_dicts[i][token_id]
                # .item() is evaluated once per token
                for token_id in (raw_id.item() for raw_id in tokenized_sentence)
                if token_id != self.padding_index
            ]
            for i, tokenized_sentence in enumerate(tokenized_text)
        ]

    def get_vocab(self):
        """Return the word -> id mapping of the words kept from the training text."""
        return self.word_id_mapping

    @classmethod
    def from_json(cls: Type["NGramTokenizer"], filepath: str, training_text) -> "NGramTokenizer":
        """
        Build a tokenizer from constructor arguments stored in a JSON file.

        Args:
            filepath (str): Path to a JSON object whose keys are constructor
                keyword arguments (everything except ``training_text``).
            training_text: Training texts used to build the vocabulary.

        Returns:
            NGramTokenizer: The new tokenizer.
        """
        # Explicit UTF-8: do not depend on the platform locale encoding.
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(**data, training_text=training_text)
|
|
@@ -1,216 +0,0 @@
|
|
|
1
|
-
from typing import Optional
|
|
2
|
-
from ..base import BaseClassifierWrapper
|
|
3
|
-
from .core import FastTextConfig
|
|
4
|
-
from .tokenizer import NGramTokenizer
|
|
5
|
-
from .model import FastTextModel, FastTextModule, FastTextModelDataset
|
|
6
|
-
from ...utilities.checkers import check_X, check_Y
|
|
7
|
-
import logging
|
|
8
|
-
import numpy as np
|
|
9
|
-
import torch
|
|
10
|
-
from torch.optim import SGD, Adam
|
|
11
|
-
|
|
12
|
-
logger = logging.getLogger()
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class FastTextWrapper(BaseClassifierWrapper):
    """Wrapper for FastText classifier.

    Ties together a ``FastTextConfig``, an ``NGramTokenizer``, the
    ``FastTextModel`` PyTorch model and its ``FastTextModule`` Lightning
    module.
    """

    def __init__(self, config: FastTextConfig):
        """Store the configuration; the tokenizer is built later."""
        super().__init__(config)
        self.config: FastTextConfig = config
        self.tokenizer: Optional[NGramTokenizer] = None  # FastText-specific tokenizer

    def prepare_text_features(self, training_text: np.ndarray) -> None:
        """Build NGram tokenizer for FastText from the training text."""
        self.tokenizer = NGramTokenizer(
            self.config.min_count,
            self.config.min_n,
            self.config.max_n,
            self.config.num_tokens,
            self.config.len_word_ngrams,
            training_text,
        )

    def build_tokenizer(self, training_text: np.ndarray) -> None:
        """Legacy method for backward compatibility."""
        self.prepare_text_features(training_text)

    def _build_pytorch_model(self) -> None:
        """Build FastText PyTorch model.

        Reconciles ``config.num_rows`` with the tokenizer's
        ``padding_index + 1`` (the larger value wins), aligns the
        tokenizer's padding index with ``num_rows - 1``, then instantiates
        ``FastTextModel``.

        Raises:
            ValueError: If neither a tokenizer nor ``config.num_rows`` is
                available.
        """
        if self.config.num_rows is None:
            if self.tokenizer is None:
                raise ValueError(
                    "Please provide a tokenizer or num_rows."
                )
            else:
                self.config.num_rows = self.tokenizer.padding_index + 1
        else:
            if self.tokenizer is not None:
                if self.config.num_rows != self.tokenizer.padding_index + 1:
                    logger.warning(
                        f"Divergent values for num_rows: {self.config.num_rows} and {self.tokenizer.padding_index + 1}. "
                        f"Using max value."
                    )
                    self.config.num_rows = max(self.config.num_rows, self.tokenizer.padding_index + 1)

        # The padding row is always the last row of the embedding matrix.
        self.padding_idx = self.config.num_rows - 1

        # Update tokenizer padding index if necessary
        if self.tokenizer is not None and self.padding_idx != self.tokenizer.padding_index:
            self.tokenizer.padding_index = self.padding_idx

        self.pytorch_model = FastTextModel(
            tokenizer=self.tokenizer,
            embedding_dim=self.config.embedding_dim,
            num_rows=self.config.num_rows,
            num_classes=self.config.num_classes,
            categorical_vocabulary_sizes=self.config.categorical_vocabulary_sizes,
            categorical_embedding_dims=self.config.categorical_embedding_dims,
            padding_idx=self.padding_idx,
            sparse=self.config.sparse,
            direct_bagging=self.config.direct_bagging,
        )

    def _check_and_init_lightning(
        self,
        optimizer=None,
        optimizer_params=None,
        lr=None,
        scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
        scheduler_params=None,
        patience_scheduler=3,
        loss=None,
    ) -> None:
        """Initialize Lightning module for FastText.

        Args:
            optimizer: Optimizer class. When None, ``SGD`` is used if
                ``config.sparse`` is truthy, ``Adam`` otherwise, and
                ``optimizer_params`` is ignored in favour of ``{"lr": lr}``.
            optimizer_params: Keyword arguments for a caller-supplied
                optimizer. When None, ``{"lr": lr}`` is used if ``lr`` is
                given, otherwise an empty dict.
            lr: Learning rate. With the default optimizer it falls back to
                ``config.learning_rate`` if present, else ``4e-3``.
            scheduler: Learning-rate scheduler class.
            scheduler_params: Keyword arguments for the scheduler. When
                None, ``{"mode": "min", "patience": patience_scheduler}``.
            patience_scheduler: Patience used in the default scheduler
                parameters.
            loss: Loss module. When None, a new
                ``torch.nn.CrossEntropyLoss()`` is created for this call.
        """
        if loss is None:
            # Built per call: a module instance in the signature default
            # would be created once and shared by every wrapper.
            loss = torch.nn.CrossEntropyLoss()

        if optimizer is None:
            if lr is None:
                lr = getattr(self.config, 'learning_rate', 4e-3)  # Use config or default
            self.optimizer = SGD if self.config.sparse else Adam
            self.optimizer_params = {"lr": lr}
        else:
            self.optimizer = optimizer
            if optimizer_params is not None:
                # Previously dropped: self.optimizer_params was left unset
                # (or stale) whenever the caller supplied both arguments.
                self.optimizer_params = optimizer_params
            elif lr is not None:
                self.optimizer_params = {"lr": lr}
            else:
                logger.warning("No optimizer parameters provided. Using defaults.")
                self.optimizer_params = {}

        self.scheduler = scheduler

        if scheduler_params is None:
            logger.warning("No scheduler parameters provided. Using defaults.")
            self.scheduler_params = {
                "mode": "min",
                "patience": patience_scheduler,
            }
        else:
            self.scheduler_params = scheduler_params

        self.loss = loss

        self.lightning_module = FastTextModule(
            model=self.pytorch_model,
            loss=self.loss,
            optimizer=self.optimizer,
            optimizer_params=self.optimizer_params,
            scheduler=self.scheduler,
            scheduler_params=self.scheduler_params,
            scheduler_interval="epoch",
        )

    def predict(self, X: np.ndarray, top_k=1, preprocess=False, verbose=False) -> np.ndarray:
        """Make predictions with FastText model.

        Args:
            X: Input features, split by ``check_X`` into text and optional
                categorical variables.
            top_k: Number of predictions per sample; when 1 the last
                dimension is squeezed out.
            preprocess: Forwarded to the model's ``predict``.
            verbose: Currently unused.

        Returns:
            The predictions, converted with ``.numpy()`` when available.
            The confidence values returned by the model are discarded.

        Raises:
            Exception: If the model is not trained, or if the number of
                categorical variables differs from
                ``config.num_categorical_features``.
        """
        if not self.trained:
            raise Exception("Model must be trained first.")

        text, categorical_variables, no_cat_var = check_X(X)
        if categorical_variables is not None:
            if categorical_variables.shape[1] != self.config.num_categorical_features:
                raise Exception(
                    "X must have the same number of categorical variables as training data."
                )
        else:
            assert self.pytorch_model.no_cat_var == True

        predictions, confidence = self.pytorch_model.predict(
            text, categorical_variables, top_k=top_k, preprocess=preprocess
        )

        # Return just predictions, squeeze out the top_k dimension if top_k=1
        if top_k == 1:
            predictions = predictions.squeeze(-1)

        # Convert to numpy array for consistency
        if hasattr(predictions, 'numpy'):
            predictions = predictions.numpy()

        return predictions

    def validate(self, X: np.ndarray, Y: np.ndarray, batch_size=256, num_workers=12) -> float:
        """Validate FastText model.

        Args:
            X: Input features, passed to ``predict``.
            Y: Ground-truth labels, normalised by ``check_Y``.
            batch_size: Currently unused.
            num_workers: Currently unused.

        Returns:
            float: Accuracy, i.e. the mean of ``predictions == y``.

        Raises:
            Exception: If the model is not trained.
        """
        if not self.trained:
            raise Exception("Model must be trained first.")

        # Use predict method which handles input validation and returns just predictions
        predictions = self.predict(X)
        y = check_Y(Y)

        # Convert predictions to numpy if it's a tensor
        if hasattr(predictions, 'numpy'):
            predictions = predictions.numpy()

        # Calculate accuracy
        accuracy = (predictions == y).mean()
        return float(accuracy)

    def predict_and_explain(self, X: np.ndarray, top_k=1):
        """Predict and explain with FastText model.

        Performs the same input checks as ``predict`` and returns whatever
        the model's ``predict_and_explain`` returns.

        Raises:
            Exception: If the model is not trained, or if the number of
                categorical variables differs from
                ``config.num_categorical_features``.
        """
        if not self.trained:
            raise Exception("Model must be trained first.")

        text, categorical_variables, no_cat_var = check_X(X)
        if categorical_variables is not None:
            if categorical_variables.shape[1] != self.config.num_categorical_features:
                raise Exception(
                    f"X must have the same number of categorical variables as training data ({self.config.num_categorical_features})."
                )
        else:
            assert self.pytorch_model.no_cat_var == True

        return self.pytorch_model.predict_and_explain(text, categorical_variables, top_k=top_k)

    def create_dataset(self, texts: np.ndarray, labels: np.ndarray, categorical_variables: np.ndarray = None):
        """Create FastText dataset bound to this wrapper's tokenizer."""
        return FastTextModelDataset(
            categorical_variables=categorical_variables,
            texts=texts,
            outputs=labels,
            tokenizer=self.tokenizer,
        )

    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
        """Create FastText dataloader by delegating to ``dataset.create_dataloader``."""
        return dataset.create_dataloader(batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)

    def load_best_model(self, checkpoint_path: str) -> None:
        """Load best FastText model from checkpoint.

        Reads ``self.pytorch_model``, ``self.loss``, ``self.optimizer``,
        ``self.optimizer_params``, ``self.scheduler`` and
        ``self.scheduler_params``, so these must already be set. The
        restored model is moved to CPU, put in eval mode, and the wrapper
        is marked as trained.
        """
        self.lightning_module = FastTextModule.load_from_checkpoint(
            checkpoint_path,
            model=self.pytorch_model,
            loss=self.loss,
            optimizer=self.optimizer,
            optimizer_params=self.optimizer_params,
            scheduler=self.scheduler,
            scheduler_params=self.scheduler_params,
            scheduler_interval="epoch",
        )
        self.pytorch_model = self.lightning_module.model.to("cpu")
        self.trained = True
        self.pytorch_model.eval()

    @classmethod
    def get_config_class(cls):
        """Return the configuration class for FastText wrapper."""
        return FastTextConfig
|
|
216
|
-
|