torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +114 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +43 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +166 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +463 -405
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
- torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
torchTextClassifiers/utilities/plot_explainability.py
@@ -0,0 +1,184 @@
+from typing import List, Optional
+
+import numpy as np
+import torch
+
+try:
+    from matplotlib import pyplot as plt
+
+    HAS_PYPLOT = True
+except ImportError:
+    HAS_PYPLOT = False
+
+
+def map_attributions_to_char(attributions, offsets, text):
+    """
+    Maps token-level attributions to character-level attributions based on token offsets.
+    Args:
+        attributions (np.ndarray): Array of shape (top_k, seq_len) or (seq_len,) containing token-level attributions.
+            Output from:
+            >>> ttc.predict(X, top_k=top_k, explain=True)["attributions"]
+        offsets (list of tuples): List of (start, end) offsets for each token in the original text.
+            Output from:
+            >>> ttc.predict(X, top_k=top_k, explain=True)["offset_mapping"]
+            Also from:
+            >>> ttc.tokenizer.tokenize(text, return_offsets_mapping=True)["offset_mapping"]
+        text (str): The original input text.
+
+    Returns:
+        np.ndarray: Array of shape (top_k, text_len) containing character-level attributions.
+            text_len is the number of characters in the original text.
+
+    """
+
+    if isinstance(text, list):
+        raise ValueError("text must be a single string, not a list of strings.")
+
+    assert isinstance(text, str), "text must be a string."
+
+    if isinstance(attributions, torch.Tensor):
+        attributions = attributions.cpu().numpy()
+
+    if attributions.ndim == 1:
+        attributions = attributions[None, :]
+
+    attributions_per_char = np.zeros((attributions.shape[0], len(text)))  # top_k, text_len
+
+    for token_idx, (start, end) in enumerate(offsets):
+        if start == end:  # skip special tokens
+            continue
+        attributions_per_char[:, start:end] = attributions[:, token_idx][:, None]
+
+    return np.exp(attributions_per_char) / np.sum(
+        np.exp(attributions_per_char), axis=1, keepdims=True
+    )  # softmax normalization
+
+
+def map_attributions_to_word(attributions, word_ids):
+    """
+    Maps token-level attributions to word-level attributions based on word IDs.
+    Args:
+        attributions (np.ndarray): Array of shape (top_k, seq_len) or (seq_len,) containing token-level attributions.
+            Output from:
+            >>> ttc.predict(X, top_k=top_k, explain=True)["attributions"]
+        word_ids (list of int or None): List of word IDs for each token in the original text.
+            Output from:
+            >>> ttc.predict(X, top_k=top_k, explain=True)["word_ids"]
+
+    Returns:
+        np.ndarray: Array of shape (top_k, num_words) containing word-level attributions.
+            num_words is the number of unique words in the original text.
+    """
+
+    word_ids = np.array(word_ids)
+
+    # Convert None to -1 for easier processing (PAD tokens)
+    word_ids_int = np.array([x if x is not None else -1 for x in word_ids], dtype=int)
+
+    # Filter out PAD tokens from attributions and word_ids
+    attributions = attributions[
+        torch.arange(attributions.shape[0])[:, None],
+        torch.tensor(np.where(word_ids_int != -1)[0])[None, :],
+    ]
+    word_ids_int = word_ids_int[word_ids_int != -1]
+    unique_word_ids = np.unique(word_ids_int)
+    num_unique_words = len(unique_word_ids)
+
+    top_k = attributions.shape[0]
+    attr_with_word_id = np.concatenate(
+        (attributions[:, :, None], np.tile(word_ids_int[None, :], reps=(top_k, 1))[:, :, None]),
+        axis=-1,
+    )  # top_k, seq_len, 2
+    # last dim is 2: 0 is the attribution of the token, 1 is the word_id the token is associated to
+
+    word_attributions = np.zeros((top_k, num_unique_words))
+    for word_id in unique_word_ids:
+        mask = attr_with_word_id[:, :, 1] == word_id  # top_k, seq_len
+        word_attributions[:, word_id] = (attr_with_word_id[:, :, 0] * mask).sum(
+            axis=1
+        )  # zero-out non-matching tokens and sum attributions for all tokens belonging to the same word
+
+    # assert word_attributions.sum(axis=1) == attributions.sum(axis=1), "Sum of word attributions per top_k must equal sum of token attributions per top_k."
+    return np.exp(word_attributions) / np.sum(
+        np.exp(word_attributions), axis=1, keepdims=True
+    )  # softmax normalization
+
+
+def plot_attributions_at_char(
+    text: str,
+    attributions_per_char: np.ndarray,
+    figsize=(10, 2),
+    titles: Optional[List[str]] = None,
+):
+    """
+    Plots character-level attributions as bar charts, one figure per prediction.
+    Args:
+        text (str): The original input text.
+        attributions_per_char (np.ndarray): Array of shape (top_k, text_len) containing character-level attributions.
+            Output from map_attributions_to_char function.
+        titles (list of str, optional): Titles of the plots, one per prediction.
+        figsize (tuple): Figure size for the plot.
+    """
+
+    if not HAS_PYPLOT:
+        raise ImportError(
+            "matplotlib is required for plotting. Please install it to use this function."
+        )
+    top_k = attributions_per_char.shape[0]
+
+    all_plots = []
+    for i in range(top_k):
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.bar(range(len(text)), attributions_per_char[i])
+        ax.set_xticks(np.arange(len(text)))
+        ax.set_xticklabels(list(text), rotation=90)
+        title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
+        ax.set_title(title)
+        ax.set_xlabel("Characters in Text")
+        ax.set_ylabel("Top Predictions")
+        all_plots.append(fig)
+
+    return all_plots
+
+
+def plot_attributions_at_word(
+    text, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
+):
+    """
+    Plots word-level attributions as bar charts, one figure per prediction.
+    Args:
+        text (str): The original input text.
+        attributions_per_word (np.ndarray): Array of shape (top_k, num_words) containing word-level attributions.
+            Output from map_attributions_to_word function.
+        titles (list of str, optional): Titles of the plots, one per prediction.
+        figsize (tuple): Figure size for the plot.
+    """
+
+    if not HAS_PYPLOT:
+        raise ImportError(
+            "matplotlib is required for plotting. Please install it to use this function."
+        )
+
+    words = text.split()
+    top_k = attributions_per_word.shape[0]
+    all_plots = []
+    for i in range(top_k):
+        fig, ax = plt.subplots(figsize=figsize)
+        ax.bar(range(len(words)), attributions_per_word[i])
+        ax.set_xticks(np.arange(len(words)))
+        ax.set_xticklabels(words, rotation=90)
+        title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
+        ax.set_title(title)
+        ax.set_xlabel("Words in Text")
+        ax.set_ylabel("Attributions")
+        all_plots.append(fig)
+
+    return all_plots
+
+
+def figshow(figure):
+    # https://stackoverflow.com/questions/53088212/create-multiple-figures-in-pyplot-but-only-show-one
+    for i in plt.get_fignums():
+        if figure != plt.figure(i):
+            plt.close(plt.figure(i))
+    plt.show()
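To make the intended flow of these helpers concrete, here is a minimal, hypothetical usage sketch. The module path and function names come from the diff above; the text, offsets, word IDs, and attribution values are invented stand-ins for the `ttc.predict(X, top_k=..., explain=True)` output referenced in the docstrings, and matplotlib (the `explainability` extra) must be installed for the plotting calls.

```python
# Illustrative sketch only: hand-made stand-ins for explain=True output.
import numpy as np

from torchTextClassifiers.utilities.plot_explainability import (
    figshow,
    map_attributions_to_char,
    map_attributions_to_word,
    plot_attributions_at_char,
    plot_attributions_at_word,
)

text = "hello world"
# Pretend tokenization: one special token plus the pieces "hello", "wor", "ld".
offsets = [(0, 0), (0, 5), (6, 9), (9, 11)]  # (start, end) character span per token
word_ids = [None, 0, 1, 1]                   # word index per token, None for special/PAD
attributions = np.array(
    [[0.0, 2.0, 0.5, 0.3],                   # top-1 prediction, one score per token
     [0.0, 0.1, 1.5, 0.7]]                   # top-2 prediction
)

char_attr = map_attributions_to_char(attributions, offsets, text)  # (top_k, len(text))
word_attr = map_attributions_to_word(attributions, word_ids)       # (top_k, num_words)

char_figs = plot_attributions_at_char(text, char_attr, titles=["Top 1", "Top 2"])
word_figs = plot_attributions_at_word(text, word_attr, titles=["Top 1", "Top 2"])
figshow(word_figs[0])  # display one figure and close the others
```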
torchtextclassifiers-0.1.0.dist-info/METADATA
@@ -0,0 +1,73 @@
+Metadata-Version: 2.3
+Name: torchtextclassifiers
+Version: 0.1.0
+Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
+Keywords: fastText,text classification,NLP,automatic coding,deep learning
+Author: Cédric Couralet, Meilame Tayebjee
+Author-email: Cédric Couralet <cedric.couralet@insee.fr>, Meilame Tayebjee <meilame.tayebjee@insee.fr>
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Requires-Dist: numpy>=1.26.4
+Requires-Dist: pytorch-lightning>=2.4.0
+Requires-Dist: unidecode ; extra == 'explainability'
+Requires-Dist: nltk ; extra == 'explainability'
+Requires-Dist: captum ; extra == 'explainability'
+Requires-Dist: tokenizers>=0.22.1 ; extra == 'huggingface'
+Requires-Dist: transformers>=4.57.1 ; extra == 'huggingface'
+Requires-Dist: datasets>=4.3.0 ; extra == 'huggingface'
+Requires-Dist: unidecode ; extra == 'preprocess'
+Requires-Dist: nltk ; extra == 'preprocess'
+Requires-Python: >=3.11
+Provides-Extra: explainability
+Provides-Extra: huggingface
+Provides-Extra: preprocess
+Description-Content-Type: text/markdown
+
+# torchTextClassifiers
+
+A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
+## 🚀 Features
+
+- **Mixed input support**: Handle text data alongside categorical variables seamlessly.
+- **Unified yet highly customizable**:
+  - Use any tokenizer from HuggingFace or the original fastText n-gram tokenizer.
+  - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. All of them are `torch.nn.Module` subclasses!
+  - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
+- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
+  - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you.
+- **Additional features**: explainability using Captum.
+
+
+## 📦 Installation
+
+```bash
+# Clone the repository
+git clone https://github.com/InseeFrLab/torchTextClassifiers.git
+cd torchTextClassifiers
+
+# Install with uv (recommended)
+uv sync
+
+# Or install with pip
+pip install -e .
+```
+
+## 📝 Usage
+
+Check out the [notebook](notebooks/example.ipynb) for a quick start.
+
+## 📚 Examples
+
+See the [examples/](examples/) directory for:
+- Basic text classification
+- Multi-class classification
+- Mixed features (text + categorical)
+- Advanced training configurations
+- Prediction and explainability
+
+## 📄 License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
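The component classes named in the README above (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`, `TextClassificationModel`) appear in this diff only by name, so their real signatures are not shown here. The sketch below illustrates the general pattern the README describes (pooled text embeddings concatenated with categorical embeddings, feeding a linear classification head) using plain `torch.nn`; every class and argument name in it is an assumption, not the library's API.

```python
# Generic sketch of the "text + categorical variables" pattern from the README.
# NOT the library's API: names, shapes, and wiring are illustrative assumptions.
import torch
from torch import nn


class ToyTextClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, cat_vocab_sizes, cat_embed_dim, num_classes):
        super().__init__()
        # Text branch: token embeddings mean-pooled into one vector per example.
        self.text_embedding = nn.EmbeddingBag(vocab_size, embed_dim, mode="mean")
        # Categorical branch: one embedding table per categorical variable.
        self.cat_embeddings = nn.ModuleList(
            [nn.Embedding(size, cat_embed_dim) for size in cat_vocab_sizes]
        )
        # Classification head on the concatenated representation.
        self.head = nn.Linear(embed_dim + cat_embed_dim * len(cat_vocab_sizes), num_classes)

    def forward(self, token_ids, categorical_vars):
        # token_ids: (batch, seq_len) int tensor; categorical_vars: (batch, n_cat) int tensor
        text_repr = self.text_embedding(token_ids)
        cat_repr = [emb(categorical_vars[:, i]) for i, emb in enumerate(self.cat_embeddings)]
        features = torch.cat([text_repr, *cat_repr], dim=1)
        return self.head(features)


model = ToyTextClassifier(vocab_size=100, embed_dim=16, cat_vocab_sizes=[4, 7],
                          cat_embed_dim=8, num_classes=3)
logits = model(torch.randint(0, 100, (2, 12)), torch.tensor([[1, 3], [0, 6]]))
print(logits.shape)  # torch.Size([2, 3])
```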
torchtextclassifiers-0.1.0.dist-info/RECORD
@@ -0,0 +1,21 @@
+torchTextClassifiers/__init__.py,sha256=TM2AjZ4KDqpgwMKiT0X5daNZvLDj6WECz_OFf8M4lgA,906
+torchTextClassifiers/dataset/__init__.py,sha256=dyCz48pO6zRC-2qh4753Hj70W2MZGXdX3RbgutvyOng,76
+torchTextClassifiers/dataset/dataset.py,sha256=n7V4JNtcuqb2ugx7hxkAohEPHqEuxv46jYU47KiUbno,3295
+torchTextClassifiers/model/__init__.py,sha256=lFY1Mb1J0tFhe4_PsDOEHhnVl3dXj59K4Zxnwy2KkS4,146
+torchTextClassifiers/model/components/__init__.py,sha256=-IT_6fCHZkRw6Hu7GdVeCt685P4PuGaY6VdYQV5M8mE,447
+torchTextClassifiers/model/components/attention.py,sha256=hhSMh_CvpR-hiP8hoCg4Fr_TovGlJpC_RHs3iW-Pnpc,4199
+torchTextClassifiers/model/components/categorical_var_net.py,sha256=no0QDidKCw1rlbJzD7S-Srhzn5P6vETGRT5Er-gzMnM,5699
+torchTextClassifiers/model/components/classification_head.py,sha256=lPndu5FPC-bOZ2H4Yq0EnzWrOzPFJdBb_KUx5wyZBb4,1445
+torchTextClassifiers/model/components/text_embedder.py,sha256=tY2pXAt4IvayyvRpjiKGg5vGz_Q2-p_TOL6Jg2p8hYE,9058
+torchTextClassifiers/model/lightning.py,sha256=z5mq10_hNp-UK66Aqpcablg3BDYnjF9Gch0HaGoJ6cM,5265
+torchTextClassifiers/model/model.py,sha256=jjGjvK7C2Wly0e4S6gTC8Ty8y-o8reU-aniBqYS73Cc,6100
+torchTextClassifiers/tokenizers/WordPiece.py,sha256=HMHYV2SiwShlhWMQ6LXH4MtZE5GSsaNA2DlD340ABGE,3289
+torchTextClassifiers/tokenizers/__init__.py,sha256=I8IQ2-t85RVlZFwLjDFF_Te2S9uiwlymQDWx-3GeF-Y,334
+torchTextClassifiers/tokenizers/base.py,sha256=OY6GIhI4KTdvvKq3VZowf64H7lAmdQyq4scZ10HxP3A,7570
+torchTextClassifiers/tokenizers/ngram.py,sha256=lHI8dtuCGWh0o7V58TJx_mTVIHm8udl6XuWccxgJPew,16375
+torchTextClassifiers/torchTextClassifiers.py,sha256=E2XVGAky_SMAw6BAMswA3c08rKyOpGEW_dv1BqQlJrU,21141
+torchTextClassifiers/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+torchTextClassifiers/utilities/plot_explainability.py,sha256=8YhyiMupdiIZp4jT7uvlcJNf69Fyr9HXfjUiNyMSYxE,6931
+torchtextclassifiers-0.1.0.dist-info/WHEEL,sha256=ELhySV62sOro8I5wRaLaF3TWxhBpkcDkdZUdAYLy_Hk,78
+torchtextclassifiers-0.1.0.dist-info/METADATA,sha256=fvPTUIS-M4LgURVzC1CUTb8IrKyZiBzWRAE1heTafEE,2988
+torchtextclassifiers-0.1.0.dist-info/RECORD,,
torchTextClassifiers/classifiers/base.py
@@ -1,83 +0,0 @@
-from typing import Optional, Union, Type, List, Dict, Any
-from dataclasses import dataclass, field, asdict
-from abc import ABC, abstractmethod
-import numpy as np
-
-class BaseClassifierConfig(ABC):
-    """Abstract base class for classifier configurations."""
-
-    @abstractmethod
-    def to_dict(self) -> Dict[str, Any]:
-        """Convert configuration to dictionary."""
-        pass
-
-    @classmethod
-    @abstractmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "BaseClassifierConfig":
-        """Create configuration from dictionary."""
-        pass
-
-class BaseClassifierWrapper(ABC):
-    """Abstract base class for classifier wrappers.
-
-    Each classifier wrapper is responsible for its own text processing approach.
-    Some may use tokenizers, others may use different preprocessing methods.
-    """
-
-    def __init__(self, config: BaseClassifierConfig):
-        self.config = config
-        self.pytorch_model = None
-        self.lightning_module = None
-        self.trained: bool = False
-        self.device = None
-        # Remove tokenizer from base class - it's now wrapper-specific
-
-    @abstractmethod
-    def prepare_text_features(self, training_text: np.ndarray) -> None:
-        """Prepare text features for the classifier.
-
-        This could involve tokenization, vectorization, or other preprocessing.
-        Each classifier wrapper implements this according to its needs.
-        """
-        pass
-
-    @abstractmethod
-    def _build_pytorch_model(self) -> None:
-        """Build the PyTorch model."""
-        pass
-
-    @abstractmethod
-    def _check_and_init_lightning(self, **kwargs) -> None:
-        """Initialize Lightning module."""
-        pass
-
-    @abstractmethod
-    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
-        """Make predictions."""
-        pass
-
-    @abstractmethod
-    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-        """Validate the model."""
-        pass
-
-    @abstractmethod
-    def create_dataset(self, texts: np.ndarray, labels: np.ndarray, categorical_variables: Optional[np.ndarray] = None):
-        """Create dataset for training/validation."""
-        pass
-
-    @abstractmethod
-    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
-        """Create dataloader from dataset."""
-        pass
-
-    @abstractmethod
-    def load_best_model(self, checkpoint_path: str) -> None:
-        """Load best model from checkpoint."""
-        pass
-
-    @classmethod
-    @abstractmethod
-    def get_config_class(cls) -> Type[BaseClassifierConfig]:
-        """Return the configuration class for this wrapper."""
-        pass
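For readers of the removed 0.0.1 interface above, here is a minimal sketch of what a concrete config/wrapper pair could have looked like. `WhitespaceConfig` and `WhitespaceWrapper` are invented names; only the base classes, the import path, and the abstract method names come from `base.py`, and the import only resolves against the 0.0.1 wheel.

```python
# Hypothetical subclass of the removed 0.0.1 base interface (illustration only).
from dataclasses import asdict, dataclass
from typing import Any, Dict, Type

import numpy as np

# Only importable from the 0.0.1 wheel, where classifiers/base.py still exists.
from torchTextClassifiers.classifiers.base import BaseClassifierConfig, BaseClassifierWrapper


@dataclass
class WhitespaceConfig(BaseClassifierConfig):
    max_features: int = 1000

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WhitespaceConfig":
        return cls(**data)


class WhitespaceWrapper(BaseClassifierWrapper):
    """Toy wrapper whose 'text processing approach' is plain whitespace splitting."""

    def prepare_text_features(self, training_text: np.ndarray) -> None:
        # Build a vocabulary from whitespace-split tokens (stand-in for a tokenizer).
        tokens = {tok for text in training_text for tok in str(text).split()}
        self.vocab = {tok: i for i, tok in enumerate(sorted(tokens))}

    # Remaining abstract methods stubbed out to keep the sketch short.
    def _build_pytorch_model(self) -> None: ...
    def _check_and_init_lightning(self, **kwargs) -> None: ...
    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray: ...
    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: ...
    def create_dataset(self, texts, labels, categorical_variables=None): ...
    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True): ...
    def load_best_model(self, checkpoint_path: str) -> None: ...

    @classmethod
    def get_config_class(cls) -> Type[BaseClassifierConfig]:
        return WhitespaceConfig


wrapper = WhitespaceWrapper(WhitespaceConfig(max_features=500))
wrapper.prepare_text_features(np.array(["hello world", "hello there"]))
print(wrapper.vocab)  # {'hello': 0, 'there': 1, 'world': 2}
```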
torchTextClassifiers/classifiers/fasttext/__init__.py
@@ -1,25 +0,0 @@
-"""FastText classifier package.
-
-Provides FastText text classification with PyTorch Lightning integration.
-This folder contains 4 main files:
-- core.py: Configuration, losses, and factory methods
-- tokenizer.py: NGramTokenizer implementation
-- model.py: PyTorch model, Lightning module, and dataset
-- wrapper.py: High-level wrapper interface
-"""
-
-from .core import FastTextConfig, OneVsAllLoss, FastTextFactory
-from .tokenizer import NGramTokenizer
-from .model import FastTextModel, FastTextModule, FastTextModelDataset
-from .wrapper import FastTextWrapper
-
-__all__ = [
-    "FastTextConfig",
-    "OneVsAllLoss",
-    "FastTextFactory",
-    "NGramTokenizer",
-    "FastTextModel",
-    "FastTextModule",
-    "FastTextModelDataset",
-    "FastTextWrapper",
-]
torchTextClassifiers/classifiers/fasttext/core.py
@@ -1,269 +0,0 @@
-"""FastText classifier core components.
-
-This module contains the core components for FastText classification:
-- Configuration dataclass
-- Loss functions
-- Factory methods for creating classifiers
-
-Consolidates what was previously in config.py, losses.py, and factory.py.
-"""
-
-from dataclasses import dataclass, field, asdict
-from abc import ABC, abstractmethod
-from ..base import BaseClassifierConfig
-from typing import Optional, List, TYPE_CHECKING, Union, Dict, Any
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-if TYPE_CHECKING:
-    from ...torchTextClassifiers import torchTextClassifiers
-
-
-# ============================================================================
-# Configuration
-# ============================================================================
-
-@dataclass
-class FastTextConfig(BaseClassifierConfig):
-    """Configuration for FastText classifier."""
-    # Embedding matrix
-    embedding_dim: int
-    sparse: bool
-
-    # Tokenizer-related
-    num_tokens: int
-    min_count: int
-    min_n: int
-    max_n: int
-    len_word_ngrams: int
-
-    # Optional parameters
-    num_classes: Optional[int] = None
-    num_rows: Optional[int] = None
-
-    # Categorical variables
-    categorical_vocabulary_sizes: Optional[List[int]] = None
-    categorical_embedding_dims: Optional[Union[List[int], int]] = None
-    num_categorical_features: Optional[int] = None
-
-    # Model-specific parameters
-    direct_bagging: Optional[bool] = True
-
-    # Training parameters
-    learning_rate: float = 4e-3
-
-    def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "FastTextConfig":
-        return cls(**data)
-
-
-# ============================================================================
-# Loss Functions
-# ============================================================================
-
-class OneVsAllLoss(nn.Module):
-    def __init__(self):
-        super(OneVsAllLoss, self).__init__()
-
-    def forward(self, logits, targets):
-        """
-        Compute One-vs-All loss
-
-        Args:
-            logits: Tensor of shape (batch_size, num_classes) containing classification scores
-            targets: Tensor of shape (batch_size) containing true class indices
-
-        Returns:
-            loss: Mean loss value across the batch
-        """
-
-        num_classes = logits.size(1)
-
-        # Convert targets to one-hot encoding
-        targets_one_hot = F.one_hot(targets, num_classes=num_classes).float()
-
-        # For each sample, treat the true class as positive and all others as negative
-        # Using binary cross entropy for each class
-        loss = F.binary_cross_entropy_with_logits(
-            logits,  # Raw logits
-            targets_one_hot,  # Target probabilities
-            reduction="none",  # Don't reduce yet to allow for custom weighting if needed
-        )
-
-        # Sum losses across all classes for each sample, then take mean across batch
-        return loss.sum(dim=1).mean()
-
-
-# ============================================================================
-# Factory Methods
-# ============================================================================
-
-class FastTextFactory:
-    """Factory class for creating FastText classifiers with convenience methods.
-
-    This factory provides static methods for creating FastText classifiers with
-    common configurations. It handles the complexities of configuration creation
-    and classifier initialization, offering a simplified API for users.
-
-    All methods return fully initialized torchTextClassifiers instances that are
-    ready for building and training.
-    """
-
-    @staticmethod
-    def create_fasttext(
-        embedding_dim: int,
-        sparse: bool,
-        num_tokens: int,
-        min_count: int,
-        min_n: int,
-        max_n: int,
-        len_word_ngrams: int,
-        **kwargs
-    ) -> "torchTextClassifiers":
-        """Create a FastText classifier with the specified configuration.
-
-        This is the primary method for creating FastText classifiers. It creates
-        a configuration object with the provided parameters and initializes a
-        complete classifier instance.
-
-        Args:
-            embedding_dim: Dimension of word embeddings
-            sparse: Whether to use sparse embeddings
-            num_tokens: Maximum number of tokens in vocabulary
-            min_count: Minimum count for tokens to be included in vocabulary
-            min_n: Minimum length of character n-grams
-            max_n: Maximum length of character n-grams
-            len_word_ngrams: Length of word n-grams to use
-            **kwargs: Additional configuration parameters (e.g., num_classes,
-                categorical_vocabulary_sizes, etc.)
-
-        Returns:
-            torchTextClassifiers: Initialized FastText classifier instance
-
-        Example:
-            >>> from torchTextClassifiers.classifiers.fasttext.core import FastTextFactory
-            >>> classifier = FastTextFactory.create_fasttext(
-            ...     embedding_dim=100,
-            ...     sparse=False,
-            ...     num_tokens=10000,
-            ...     min_count=2,
-            ...     min_n=3,
-            ...     max_n=6,
-            ...     len_word_ngrams=2,
-            ...     num_classes=3
-            ... )
-        """
-        from ...torchTextClassifiers import torchTextClassifiers
-        from .wrapper import FastTextWrapper
-
-        config = FastTextConfig(
-            embedding_dim=embedding_dim,
-            sparse=sparse,
-            num_tokens=num_tokens,
-            min_count=min_count,
-            min_n=min_n,
-            max_n=max_n,
-            len_word_ngrams=len_word_ngrams,
-            **kwargs
-        )
-        wrapper = FastTextWrapper(config)
-        return torchTextClassifiers(wrapper)
-
-    @staticmethod
-    def build_from_tokenizer(
-        tokenizer,  # NGramTokenizer
-        embedding_dim: int,
-        num_classes: Optional[int],
-        categorical_vocabulary_sizes: Optional[List[int]] = None,
-        sparse: bool = False,
-        **kwargs
-    ) -> "torchTextClassifiers":
-        """Create FastText classifier from an existing trained tokenizer.
-
-        This method is useful when you have a pre-trained tokenizer and want to
-        create a classifier that uses the same vocabulary and tokenization scheme.
-        The resulting classifier will have its tokenizer and model architecture
-        pre-built.
-
-        Args:
-            tokenizer: Pre-trained NGramTokenizer instance
-            embedding_dim: Dimension of word embeddings
-            num_classes: Number of output classes
-            categorical_vocabulary_sizes: Sizes of categorical feature vocabularies
-            sparse: Whether to use sparse embeddings
-            **kwargs: Additional configuration parameters
-
-        Returns:
-            torchTextClassifiers: Classifier with pre-built tokenizer and model
-
-        Raises:
-            ValueError: If the tokenizer is missing required attributes
-
-        Example:
-            >>> # Assume you have a pre-trained tokenizer
-            >>> classifier = FastTextFactory.build_from_tokenizer(
-            ...     tokenizer=my_tokenizer,
-            ...     embedding_dim=100,
-            ...     num_classes=2,
-            ...     sparse=False
-            ... )
-            >>> # The classifier is ready for training without building
-            >>> classifier.train(X_train, y_train, X_val, y_val, ...)
-        """
-        from ...torchTextClassifiers import torchTextClassifiers
-        from .wrapper import FastTextWrapper
-
-        # Ensure the tokenizer has required attributes
-        required_attrs = ["min_count", "min_n", "max_n", "num_tokens", "word_ngrams"]
-        if not all(hasattr(tokenizer, attr) for attr in required_attrs):
-            missing_attrs = [attr for attr in required_attrs if not hasattr(tokenizer, attr)]
-            raise ValueError(f"Missing attributes in tokenizer: {missing_attrs}")
-
-        config = FastTextConfig(
-            num_tokens=tokenizer.num_tokens,
-            embedding_dim=embedding_dim,
-            min_count=tokenizer.min_count,
-            min_n=tokenizer.min_n,
-            max_n=tokenizer.max_n,
-            len_word_ngrams=tokenizer.word_ngrams,
-            sparse=sparse,
-            num_classes=num_classes,
-            categorical_vocabulary_sizes=categorical_vocabulary_sizes,
-            **kwargs
-        )
-
-        wrapper = FastTextWrapper(config)
-        classifier = torchTextClassifiers(wrapper)
-        classifier.classifier.tokenizer = tokenizer
-        classifier.classifier._build_pytorch_model()
-
-        return classifier
-
-    @staticmethod
-    def from_dict(config_dict: dict) -> FastTextConfig:
-        """Create FastText configuration from dictionary.
-
-        This method is used internally by the configuration factory system
-        to recreate FastText configurations from serialized data.
-
-        Args:
-            config_dict: Dictionary containing configuration parameters
-
-        Returns:
-            FastTextConfig: Reconstructed configuration object
-
-        Example:
-            >>> config_dict = {
-            ...     'embedding_dim': 100,
-            ...     'num_tokens': 5000,
-            ...     'min_count': 1,
-            ...     # ... other parameters
-            ... }
-            >>> config = FastTextFactory.from_dict(config_dict)
-        """
-        return FastTextConfig.from_dict(config_dict)
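As its docstring and comments explain, the removed `OneVsAllLoss` is just summed per-class binary cross-entropy on one-hot targets, averaged over the batch. A small standalone check of that equivalence is sketched below; the logits and targets are arbitrary, and on the 0.0.1 wheel the class itself is importable from `torchTextClassifiers.classifiers.fasttext.core`.

```python
# Numerical illustration of the removed OneVsAllLoss (values are arbitrary).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # (batch_size=4, num_classes=3)
targets = torch.tensor([0, 2, 1, 2])  # true class indices

# Same computation as OneVsAllLoss.forward above.
one_hot = F.one_hot(targets, num_classes=3).float()
ova_loss = (
    F.binary_cross_entropy_with_logits(logits, one_hot, reduction="none").sum(dim=1).mean()
)

# Equivalent "one binary problem per class" view, summed class by class.
per_class = torch.stack(
    [F.binary_cross_entropy_with_logits(logits[:, c], one_hot[:, c], reduction="none")
     for c in range(3)],
    dim=1,
)
assert torch.allclose(ova_loss, per_class.sum(dim=1).mean())
print(float(ova_loss))
```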