torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +114 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +43 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +166 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +463 -405
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
- torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
|
@@ -1,346 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Utility functions.
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import warnings
|
|
6
|
-
import difflib
|
|
7
|
-
from difflib import SequenceMatcher
|
|
8
|
-
|
|
9
|
-
import torch
|
|
10
|
-
import torch.nn.functional as F
|
|
11
|
-
|
|
12
|
-
from .preprocess import clean_text_feature
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def preprocess_token(token):
    """
    Strip tokenizer markers from a token and split it into pieces.

    The end-of-sentence marker ``"</s>"`` is removed first, then every
    remaining ``"<"`` and ``">"`` character; the result is split on
    whitespace.

    Args:
        token (str): Raw token string.

    Returns:
        List[str]: Whitespace-separated pieces of the cleaned token (empty
        list when nothing remains).
    """
    # "</s>" has to go before the bracket characters, otherwise "/s" would survive.
    without_eos = token.replace("</s>", "")
    without_brackets = without_eos.translate(str.maketrans("", "", "<>"))
    return without_brackets.split()
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
def map_processed_to_original(processed_words, original_words, n=1, cutoff=0.9):
    """
    Map each original word to its most similar processed word.

    Each original word is cleaned with ``clean_text_feature`` (stop words
    kept); words that clean to an empty string are skipped. The cleaned form
    is compared with every processed word using
    ``SequenceMatcher(...).ratio()`` (Ratcliff-Obershelp). The first processed
    word reaching the highest ratio is kept, provided that ratio is strictly
    positive and at least ``cutoff``.

    Args:
        processed_words (List[str]): Candidate processed words.
        original_words (List[str]): Words from the original text.
        n (int): Not used by this function; accepted for interface
            compatibility.
        cutoff (float): Minimum similarity ratio required to accept a match.

    Returns:
        Dict[str, str]: Original word -> best-matching processed word.
        Original words without an acceptable match are absent.
    """
    mapping = {}
    for raw_word in original_words:
        cleaned = clean_text_feature([raw_word], remove_stop_words=False)[0]
        if cleaned == "":
            continue

        # (ratio, candidate) pairs; max() returns the first pair among equal ratios.
        ratio_candidate_pairs = [
            (SequenceMatcher(None, candidate, cleaned).ratio(), candidate)
            for candidate in processed_words
        ]
        best_ratio, best_candidate = max(
            ratio_candidate_pairs, key=lambda pair: pair[0], default=(0, None)
        )

        if best_ratio > 0 and best_ratio >= cutoff:
            mapping[raw_word] = best_candidate
    return mapping
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
# Heuristic end-of-word test, called by match_word_to_token_indexes while it walks
# the token list. It returns the bool `flag` (False unless a check below sets it),
# which tells whether `target_token` is the last token of `word`, given the token
# that follows it (`next_token`).
# Checks, in source order:
#   - target_token ends with ">";
#   - next_token starts with "<";
#   - `word` occurs in target_token (-> True);
#   - `word` occurs in next_token (-> False);
#   - next_token[1] != word[0] (-> True);
#   - len(next_token) == min_n (-> True);
#   - next_token is one of all_processed_words (-> True).
# The ">" / "<" characters are presumably the tokenizer's word-boundary markers
# (TODO confirm against the tokenizer).
# NOTE(review): indentation was lost in this rendering. The nesting of the checks
# after the first `flag = True` cannot be recovered here; confirm against the
# original source before reasoning about the exact truth conditions.
# NOTE(review): target_token[-1], next_token[0], next_token[1] and word[0] raise
# IndexError on empty strings (next_token[1] also raises on a one-character next_token).
def test_end_of_word(all_processed_words, word, target_token, next_token, min_n):
|

72
|
-
flag = False
|

73
|
-
if target_token[-1] == ">":
|

74
|
-
if next_token[0] == "<":
|

75
|
-
if word in target_token:
|

76
|
-
flag = True
|

77
|
-
if word in next_token:
|

78
|
-
flag = False
|

79
|
-
if next_token[1] != word[0]:
|

80
|
-
flag = True
|

81
|
-
if len(next_token) == min_n:
|

82
|
-
flag = True
|

83
|
-
if next_token in all_processed_words:
|

84
|
-
flag = True
|

85
|
-
|

86
|
-
return flag
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
def _regular_chunk_fallback(tokenized_sentence_tokens, processed_words):
    # Best-effort mapping used when word boundaries cannot be found: word i gets the i-th equal-size chunk of token positions.
    warnings.warn("Error in the tokenization of the sentence", stacklevel=2)
    n_tokens = len(tokenized_sentence_tokens)
    chunk = n_tokens // len(processed_words)
    return {
        word: range(idx * chunk, min((idx + 1) * chunk, n_tokens))
        for idx, word in enumerate(processed_words)
    }


def match_word_to_token_indexes(sentence, tokenized_sentence_tokens, min_n):
    """
    Match words to token indexes in a sentence.

    The sentence is cleaned with ``clean_text_feature`` (stop words kept) and
    split into words. The code expects the tokens to be laid out as: the
    tokens of each word in word order, then ``"</s>"``, then further tokens
    (word n-grams, according to the original inline comment). Each word gets
    the contiguous run of tokens up to the one for which ``test_end_of_word``
    reports the end of the word. Every token after ``"</s>"`` is then credited
    to each word it contains as a substring.

    If a word's end cannot be found before the tokens run out, a warning is
    issued and a fallback mapping is returned instead. In it, the i-th word
    gets the i-th chunk of ``len(tokens) // len(words)`` consecutive token
    positions, as a ``range``.

    Args:
        sentence (str): Preprocessed sentence.
        tokenized_sentence_tokens (List[str]): Tokens of the sentence, in order.
        min_n (int): Minimum character n-gram length, forwarded to
            ``test_end_of_word``.

    Returns:
        Dict[str, List[int]]: Mapping from word to list of token indexes
        (``range`` objects in the fallback case).

    Raises:
        AssertionError: If the token following the last word is not
            ``"</s>"``, or if some token other than ``"</s>"`` is assigned to
            no word.
    """
    n_tokens = len(tokenized_sentence_tokens)
    processed_sentence = clean_text_feature([sentence], remove_stop_words=False)[0]
    processed_words = processed_sentence.split()

    res = {}
    pointer_token = 0
    # Word tokens appear in sentence order, so one forward-moving pointer suffices.
    for word in processed_words:
        res.setdefault(word, [])
        start = pointer_token

        # Advance until test_end_of_word reports that the current token closes `word`.
        while True:
            if pointer_token >= n_tokens - 1:
                # No (token, next token) pair left to inspect; reading
                # tokens[pointer_token + 1] here would raise IndexError.
                return _regular_chunk_fallback(tokenized_sentence_tokens, processed_words)
            if test_end_of_word(
                processed_words,
                word,
                tokenized_sentence_tokens[pointer_token],
                tokenized_sentence_tokens[pointer_token + 1],
                min_n=min_n,
            ):
                break
            pointer_token += 1

        pointer_token += 1  # step past the token that closes the word
        res[word] += list(range(start, pointer_token))

    # Here we have reached the end of the sentence.
    end_of_string_position = pointer_token
    assert (
        pointer_token < n_tokens and tokenized_sentence_tokens[pointer_token] == "</s>"
    ), f"Expected '</s>' at token index {pointer_token}"

    # Tokens after "</s>" are credited to every word they contain.
    for token_index in range(end_of_string_position + 1, n_tokens):
        token = tokenized_sentence_tokens[token_index]
        for word in processed_words:
            if word in token:
                res[word].append(token_index)

    # Sanity check: every token except "</s>" must belong to at least one word.
    covered = {end_of_string_position}
    for indexes in res.values():
        covered.update(indexes)
    missing = sorted(set(range(n_tokens)) - covered)
    assert not missing, f"Token indexes not assigned to any word: {missing}"

    return res
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
# at text level
|
|
163
|
-
def compute_preprocessed_word_score(
    preprocessed_text,
    tokenized_text_tokens,
    scores,
    id_to_token_dicts,
    token_to_id_dicts,
    min_n,
    padding_index=2009603,
    end_of_string_index=0,
):
    """
    Aggregate token-level scores into scores per preprocessed word.

    For every sentence, tokens are assigned to words with
    ``match_word_to_token_indexes``. Then, for each top-k score row, a word's
    score is the sum of the scores of its tokens. Words of the sentence with
    no assigned token keep a score of 0.

    Args:
        preprocessed_text (List[str]): Preprocessed sentences.
        tokenized_text_tokens (List[List[str]]): For each sentence, its tokens
            in string form.
        scores (Sequence[torch.Tensor]): For each sentence, token scores
            indexed as ``scores[idx][k][token_index]`` (one row per top-k
            prediction).
        id_to_token_dicts: Unused; kept for interface compatibility.
        token_to_id_dicts: Unused; kept for interface compatibility.
        min_n (int): Minimum character n-gram length, forwarded to
            ``match_word_to_token_indexes``.
        padding_index (int): Unused; kept for interface compatibility.
        end_of_string_index (int): Unused; kept for interface compatibility.

    Returns:
        Tuple[List[List[Dict[str, float]]], List[Dict[str, List[int]]]]:
            - for each sentence, one ``{word: score}`` dict per top-k row;
            - for each sentence, the word -> token indexes mapping used.
    """
    word_to_score_dicts = []
    word_to_token_idx_dicts = []

    for idx, sentence in enumerate(preprocessed_text):
        tokenized_sentence_tokens = tokenized_text_tokens[idx]  # sentence level, List[str]
        word_to_token_idx = match_word_to_token_indexes(sentence, tokenized_sentence_tokens, min_n)
        score_sentence_topk = scores[idx]

        # Index tensors do not depend on k, so build them once; int64 (long) is
        # the canonical dtype for advanced indexing.
        word_index_tensors = {
            word: torch.as_tensor(list(token_idx), dtype=torch.long)
            for word, token_idx in word_to_token_idx.items()
        }

        word_to_score_topk = []
        for k in range(len(score_sentence_topk)):
            score_sentence = score_sentence_topk[k]
            # Words without any assigned token keep the initial 0.
            word_to_score = dict.fromkeys(sentence.split(), 0)
            for word, index_tensor in word_index_tensors.items():
                word_to_score[word] = torch.sum(score_sentence[index_tensor]).item()
            word_to_score_topk.append(word_to_score)

        word_to_score_dicts.append(word_to_score_topk)
        word_to_token_idx_dicts.append(word_to_token_idx)

    return word_to_score_dicts, word_to_token_idx_dicts
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
def compute_word_score(word_to_score_dicts, text, n=5, cutoff=0.75):
    """
    Project preprocessed-word scores back onto the words of the original text.

    For each sentence, the original words (whitespace split, standalone ","
    removed) are matched to preprocessed words with
    ``map_processed_to_original``. Each original word takes the score of its
    matched preprocessed word; unmatched words (e.g. stop words) get 0. Each
    top-k row of scores is then softmax-normalised, and unmatched words are
    reset to 0.

    Args:
        word_to_score_dicts (List[List[Dict[str, float]]]): For each sentence,
            list of top_k mappings from preprocessed word to score.
        text (List[str]): List of original sentences.
        n (int): Forwarded to ``map_processed_to_original``.
        cutoff (float): Minimum similarity ratio for a match.

    Returns:
        Tuple[List[List[torch.Tensor]], List[Dict[str, str]]]:
            - for each sentence, one tensor per top-k row holding one score
              per original word;
            - for each sentence, the original -> preprocessed word mapping.
    """
    all_scores_text = []
    mappings = []
    for sentence_idx, word_to_score_topk in enumerate(word_to_score_dicts):
        processed_words = list(word_to_score_topk[0].keys())
        original_words = [word for word in text[sentence_idx].split() if word != ","]
        mapping = map_processed_to_original(
            processed_words, original_words, n=n, cutoff=cutoff
        )  # Dict[str, str]: original word -> processed word
        mappings.append(mapping)

        # Positions of original words with no processed counterpart (e.g. stop words).
        unmatched_idx = [pos for pos, word in enumerate(original_words) if word not in mapping]

        all_scores_topk = []
        for word_to_score in word_to_score_topk:  # iteration over top_k (the preds)
            raw_scores = [
                word_to_score[mapping[word]] if word in mapping else 0
                for word in original_words
            ]
            # Force a floating dtype: if every word is unmatched the list holds only
            # integer zeros, torch.tensor would infer int64 and softmax would raise.
            scores = torch.tensor(raw_scores, dtype=torch.get_default_dtype())
            # NOTE: unmatched words still contribute exp(0) to the softmax
            # denominator before being zeroed, so the kept scores need not sum to 1.
            scores = F.softmax(scores, dim=-1)  # length = len(original_words)
            scores[unmatched_idx] = 0
            all_scores_topk.append(scores)  # length top_k

        all_scores_text.append(all_scores_topk)  # length = len(text)

    return all_scores_text, mappings
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
def explain_continuous(
    text, processed_text, tokenized_text_tokens, mappings, word_to_token_idx_dicts, all_attr, top_k
):
    """
    Spread token attribution scores over the letters of the original words.

    For each sentence, the original words are taken with standalone ","
    removed. For each of the ``top_k`` attribution rows, every original word
    receives one score per letter. Each token assigned to the word (through
    ``mappings`` and ``word_to_token_idx_dicts``) is stripped of its markers
    and aligned with the cleaned word using
    ``SequenceMatcher.find_longest_match``. The token's attribution is then
    added to the matched letters. The letter scores of the whole sentence are
    concatenated and softmax-normalised, and letters whose raw score is
    exactly 0 are reset to 0. Words without a mapping (e.g. stop words) get
    all-zero letters.

    Args:
        text (List[str]): List of original sentences.
        processed_text (List[str]): List of preprocessed sentences (only its
            length drives the loop).
        tokenized_text_tokens (List[List[str]]): Tokens of each sentence.
        mappings (List[Dict[str, str]]): Per sentence, original word ->
            preprocessed word.
        word_to_token_idx_dicts (List[Dict[str, List[int]]]): Per sentence,
            preprocessed word -> token indexes.
        all_attr (torch.Tensor): Token attributions indexed as
            ``all_attr[sentence, k, token]``.
        top_k (int): Number of attribution rows to process per sentence.

    Returns:
        torch.Tensor: Letter scores stacked as
        ``(n_sentences, top_k, n_letters)``.

    Raises:
        RuntimeError: From ``torch.cat`` when a sentence has no words, or from
            ``torch.stack`` when sentences have different letter counts.
    """
    all_scores_text = []
    for idx in range(len(processed_text)):
        tokenized_sentence_tokens = tokenized_text_tokens[idx]
        mapping = mappings[idx]
        word_to_token_idx = word_to_token_idx_dicts[idx]
        original_words = [word for word in text[idx].split() if word != ","]

        # Token indexes of each mapped original word.
        original_to_token_idxs = {
            original: word_to_token_idx[mapping[original]]
            for original in original_words
            if original in mapping
        }

        # Token/letter alignment does not depend on k: compute it once per word.
        # None marks unmapped words, whose letters all score 0.
        word_spans = []
        for original_word in original_words:
            if original_word not in original_to_token_idxs:
                word_spans.append(None)
                continue
            cleaned_word = clean_text_feature([original_word], remove_stop_words=False)[0]
            spans = []
            for pos_token in original_to_token_idxs[original_word]:
                # preprocess_token returns a *list* of pieces. Aligning a string
                # against that list compares single characters with whole strings
                # and almost never matches, so align against the joined text.
                token_text = " ".join(preprocess_token(tokenized_sentence_tokens[pos_token]))
                start, _, size = SequenceMatcher(None, cleaned_word, token_text).find_longest_match()
                spans.append((pos_token, start, size))
            word_spans.append(spans)

        scores_for_k = []
        for k in range(top_k):
            scores_for_words = []
            for original_word, spans in zip(original_words, word_spans):
                scores_letter = torch.zeros(len(original_word), dtype=torch.float32)
                for pos_token, start, size in spans or ():
                    # Add the token score to the letters it aligns with.
                    scores_letter[start : start + size] += all_attr[idx, k, pos_token].item()
                scores_for_words.append(scores_letter)

            all_scores_letter = torch.cat(scores_for_words)
            scores = F.softmax(all_scores_letter, dim=-1)
            # Letters that received no attribution stay at exactly 0.
            scores[all_scores_letter == 0] = 0
            scores_for_k.append(scores)

        all_scores_text.append(torch.stack(scores_for_k))

    return torch.stack(all_scores_text)
|
|
@@ -1,187 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.3
|
|
2
|
-
Name: torchtextclassifiers
|
|
3
|
-
Version: 0.0.1
|
|
4
|
-
Summary: An implementation of the https://github.com/facebookresearch/fastText supervised learning algorithm for text classification using PyTorch.
|
|
5
|
-
Keywords: fastText,text classification,NLP,automatic coding,deep learning
|
|
6
|
-
Author: Tom Seimandi, Julien Pramil, Meilame Tayebjee, Cédric Couralet
|
|
7
|
-
Author-email: Tom Seimandi <tom.seimandi@gmail.com>, Julien Pramil <julien.pramil@insee.fr>, Meilame Tayebjee <meilame.tayebjee@insee.fr>, Cédric Couralet <cedric.couralet@insee.fr>
|
|
8
|
-
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
-
Classifier: Operating System :: OS Independent
|
|
11
|
-
Requires-Dist: numpy>=1.26.4
|
|
12
|
-
Requires-Dist: pytorch-lightning>=2.4.0
|
|
13
|
-
Requires-Dist: unidecode ; extra == 'explainability'
|
|
14
|
-
Requires-Dist: nltk ; extra == 'explainability'
|
|
15
|
-
Requires-Dist: captum ; extra == 'explainability'
|
|
16
|
-
Requires-Dist: unidecode ; extra == 'preprocess'
|
|
17
|
-
Requires-Dist: nltk ; extra == 'preprocess'
|
|
18
|
-
Requires-Python: >=3.11
|
|
19
|
-
Provides-Extra: explainability
|
|
20
|
-
Provides-Extra: preprocess
|
|
21
|
-
Description-Content-Type: text/markdown
|
|
22
|
-
|
|
23
|
-
# torchTextClassifiers
|
|
24
|
-
|
|
25
|
-
A unified, extensible framework for text classification built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
|
|
26
|
-
|
|
27
|
-
## 🚀 Features
|
|
28
|
-
|
|
29
|
-
- **Unified API**: Consistent interface for different classifier wrappers
|
|
30
|
-
- **Extensible**: Easy to add new classifier implementations through the wrapper pattern
|
|
31
|
-
- **FastText Support**: Built-in FastText classifier with n-gram tokenization
|
|
32
|
-
- **Flexible Preprocessing**: Each classifier can implement its own text preprocessing approach
|
|
33
|
-
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
## 📦 Installation
|
|
37
|
-
|
|
38
|
-
```bash
|
|
39
|
-
# Clone the repository
|
|
40
|
-
git clone https://github.com/InseeFrLab/torchTextClassifiers.git
|
|
41
|
-
cd torchTextClassifiers
|
|
42
|
-
|
|
43
|
-
# Install with uv (recommended)
|
|
44
|
-
uv sync
|
|
45
|
-
|
|
46
|
-
# Or install with pip
|
|
47
|
-
pip install -e .
|
|
48
|
-
```
|
|
49
|
-
|
|
50
|
-
## 🎯 Quick Start
|
|
51
|
-
|
|
52
|
-
### Basic FastText Classification
|
|
53
|
-
|
|
54
|
-
```python
|
|
55
|
-
import numpy as np
|
|
56
|
-
from torchTextClassifiers import create_fasttext
|
|
57
|
-
|
|
58
|
-
# Create a FastText classifier
|
|
59
|
-
classifier = create_fasttext(
|
|
60
|
-
embedding_dim=100,
|
|
61
|
-
sparse=False,
|
|
62
|
-
num_tokens=10000,
|
|
63
|
-
min_count=2,
|
|
64
|
-
min_n=3,
|
|
65
|
-
max_n=6,
|
|
66
|
-
len_word_ngrams=2,
|
|
67
|
-
num_classes=2
|
|
68
|
-
)
|
|
69
|
-
|
|
70
|
-
# Prepare your data
|
|
71
|
-
X_train = np.array([
|
|
72
|
-
"This is a positive example",
|
|
73
|
-
"This is a negative example",
|
|
74
|
-
"Another positive case",
|
|
75
|
-
"Another negative case"
|
|
76
|
-
])
|
|
77
|
-
y_train = np.array([1, 0, 1, 0])
|
|
78
|
-
|
|
79
|
-
X_val = np.array([
|
|
80
|
-
"Validation positive",
|
|
81
|
-
"Validation negative"
|
|
82
|
-
])
|
|
83
|
-
y_val = np.array([1, 0])
|
|
84
|
-
|
|
85
|
-
# Build the model
|
|
86
|
-
classifier.build(X_train, y_train)
|
|
87
|
-
|
|
88
|
-
# Train the model
|
|
89
|
-
classifier.train(
|
|
90
|
-
X_train, y_train, X_val, y_val,
|
|
91
|
-
num_epochs=50,
|
|
92
|
-
batch_size=32,
|
|
93
|
-
patience_train=5,
|
|
94
|
-
verbose=True
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
# Make predictions
|
|
98
|
-
X_test = np.array(["This is a test sentence"])
|
|
99
|
-
predictions = classifier.predict(X_test)
|
|
100
|
-
print(f"Predictions: {predictions}")
|
|
101
|
-
|
|
102
|
-
# Validate on test set
|
|
103
|
-
accuracy = classifier.validate(X_test, np.array([1]))
|
|
104
|
-
print(f"Accuracy: {accuracy:.3f}")
|
|
105
|
-
```
|
|
106
|
-
|
|
107
|
-
### Custom Classifier Implementation
|
|
108
|
-
|
|
109
|
-
```python
|
|
110
|
-
import numpy as np
|
|
111
|
-
from torchTextClassifiers import torchTextClassifiers
|
|
112
|
-
from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextWrapper, SimpleTextConfig
|
|
113
|
-
|
|
114
|
-
# Example: TF-IDF based classifier (alternative to tokenization)
|
|
115
|
-
config = SimpleTextConfig(
|
|
116
|
-
hidden_dim=128,
|
|
117
|
-
num_classes=2,
|
|
118
|
-
max_features=5000,
|
|
119
|
-
learning_rate=1e-3,
|
|
120
|
-
dropout_rate=0.2
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
# Create classifier with TF-IDF preprocessing
|
|
124
|
-
wrapper = SimpleTextWrapper(config)
|
|
125
|
-
classifier = torchTextClassifiers(wrapper)
|
|
126
|
-
|
|
127
|
-
# Text data
|
|
128
|
-
X_train = np.array(["Great product!", "Terrible service", "Love it!"])
|
|
129
|
-
y_train = np.array([1, 0, 1])
|
|
130
|
-
|
|
131
|
-
# Build and train
|
|
132
|
-
classifier.build(X_train, y_train)
|
|
133
|
-
# ... continue with training
|
|
134
|
-
```
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
### Training Customization
|
|
138
|
-
|
|
139
|
-
```python
|
|
140
|
-
# Custom PyTorch Lightning trainer parameters
|
|
141
|
-
trainer_params = {
|
|
142
|
-
'accelerator': 'gpu',
|
|
143
|
-
'devices': 1,
|
|
144
|
-
'precision': 16, # Mixed precision training
|
|
145
|
-
'gradient_clip_val': 1.0,
|
|
146
|
-
}
|
|
147
|
-
|
|
148
|
-
classifier.train(
|
|
149
|
-
X_train, y_train, X_val, y_val,
|
|
150
|
-
num_epochs=100,
|
|
151
|
-
batch_size=64,
|
|
152
|
-
patience_train=10,
|
|
153
|
-
trainer_params=trainer_params,
|
|
154
|
-
verbose=True
|
|
155
|
-
)
|
|
156
|
-
```
|
|
157
|
-
|
|
158
|
-
## 🔬 Testing
|
|
159
|
-
|
|
160
|
-
Run the test suite:
|
|
161
|
-
|
|
162
|
-
```bash
|
|
163
|
-
# Run all tests
|
|
164
|
-
uv run pytest
|
|
165
|
-
|
|
166
|
-
# Run with coverage
|
|
167
|
-
uv run pytest --cov=torchTextClassifiers
|
|
168
|
-
|
|
169
|
-
# Run specific test file
|
|
170
|
-
uv run pytest tests/test_torchTextClassifiers.py -v
|
|
171
|
-
```
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
## 📚 Examples
|
|
175
|
-
|
|
176
|
-
See the [examples/](examples/) directory for:
|
|
177
|
-
- Basic text classification
|
|
178
|
-
- Multi-class classification
|
|
179
|
-
- Mixed features (text + categorical)
|
|
180
|
-
- Custom classifier implementation
|
|
181
|
-
- Advanced training configurations
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
## 📄 License
|
|
186
|
-
|
|
187
|
-
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
torchTextClassifiers/__init__.py,sha256=dc77f92c57d9a0782777f83e955157be26ab0bce60434877a7361d1492978279,2228
|
|
2
|
-
torchTextClassifiers/classifiers/base.py,sha256=549669aca59fcdbca53d6c240e40e1f282d71dd99d9eb18010d37ae2a5843ce6,2796
|
|
3
|
-
torchTextClassifiers/classifiers/fasttext/__init__.py,sha256=e326a8f1f6018ea57715f94b5d14c1b18254115088911bb4e7c4f472d2ec6044,778
|
|
4
|
-
torchTextClassifiers/classifiers/fasttext/core.py,sha256=0b9d27c67f8eedbf6e9425943b10404bb6763709190351df01667ce3fc32f7f6,9943
|
|
5
|
-
torchTextClassifiers/classifiers/fasttext/model.py,sha256=4a3cd5b5403c5437e5c7d953dbc0a44b8e57ce5918b32b8b50227a8449c441b2,29858
|
|
6
|
-
torchTextClassifiers/classifiers/fasttext/tokenizer.py,sha256=d58c1ac0cbf7e62d21f3277a5fcb77fe9c7e74551df600843ce82fab5ad5664b,11422
|
|
7
|
-
torchTextClassifiers/classifiers/fasttext/wrapper.py,sha256=372903cb9313f8f79791ea4664226c10cffc4d2ec41f657153645de6339cbbfb,8816
|
|
8
|
-
torchTextClassifiers/classifiers/simple_text_classifier.py,sha256=d81afd256d451de212646bc99f8d8f790fb9e144c8fd93f44085acaed8c68be3,6725
|
|
9
|
-
torchTextClassifiers/factories.py,sha256=608d545d55be38ecbd89e80ff655140e4d7b3ae1696d6c1d3812fea2dddde88d,1296
|
|
10
|
-
torchTextClassifiers/torchTextClassifiers.py,sha256=fca4f7ca881d9d76711892c38ac6548f38d8376ad05878fabfbe9b08ca49090d,20496
|
|
11
|
-
torchTextClassifiers/utilities/__init__.py,sha256=17df83700c131f2f4b5acc619ccafa0dcb55139f2a27cf00f6c682880a2b3746,21
|
|
12
|
-
torchTextClassifiers/utilities/checkers.py,sha256=53494be4b95691090f70fda5498cc11f05adac042617d5da114ea60ea3e35444,3733
|
|
13
|
-
torchTextClassifiers/utilities/preprocess.py,sha256=bba939a19a82e5ebc49509f2c8c5716b71975d502babbe89b236470655295390,2230
|
|
14
|
-
torchTextClassifiers/utilities/utils.py,sha256=81ff0aeee829c0729d9eb1b37d7bc6e37d4bec0e65dbd199482e8da9584663ac,13567
|
|
15
|
-
torchtextclassifiers-0.0.1.dist-info/WHEEL,sha256=b70116f4076fa664af162441d2ba3754dbb4ec63e09d563bdc1e9ab023cce400,78
|
|
16
|
-
torchtextclassifiers-0.0.1.dist-info/METADATA,sha256=48862621e58dace60467867aef55399bda366a33ccd0c1bc7080a7ac60d05a39,4990
|
|
17
|
-
torchtextclassifiers-0.0.1.dist-info/RECORD,,
|