torchtextclassifiers-0.0.1.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchtextclassifiers-0.0.1/PKG-INFO +187 -0
- torchtextclassifiers-0.0.1/README.md +165 -0
- torchtextclassifiers-0.0.1/pyproject.toml +63 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/__init__.py +68 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/base.py +83 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/__init__.py +25 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/core.py +269 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/model.py +752 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/tokenizer.py +346 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/wrapper.py +216 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/simple_text_classifier.py +191 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/factories.py +34 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/torchTextClassifiers.py +509 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/__init__.py +3 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/checkers.py +108 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/preprocess.py +82 -0
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/utils.py +346 -0
@@ -0,0 +1,187 @@
+Metadata-Version: 2.3
+Name: torchtextclassifiers
+Version: 0.0.1
+Summary: An implementation of the https://github.com/facebookresearch/fastText supervised learning algorithm for text classification using PyTorch.
+Keywords: fastText,text classification,NLP,automatic coding,deep learning
+Author: Tom Seimandi, Julien Pramil, Meilame Tayebjee, Cédric Couralet
+Author-email: Tom Seimandi <tom.seimandi@gmail.com>, Julien Pramil <julien.pramil@insee.fr>, Meilame Tayebjee <meilame.tayebjee@insee.fr>, Cédric Couralet <cedric.couralet@insee.fr>
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Requires-Dist: numpy>=1.26.4
+Requires-Dist: pytorch-lightning>=2.4.0
+Requires-Dist: unidecode ; extra == 'explainability'
+Requires-Dist: nltk ; extra == 'explainability'
+Requires-Dist: captum ; extra == 'explainability'
+Requires-Dist: unidecode ; extra == 'preprocess'
+Requires-Dist: nltk ; extra == 'preprocess'
+Requires-Python: >=3.11
+Provides-Extra: explainability
+Provides-Extra: preprocess
+Description-Content-Type: text/markdown
+
+[Lines 23-187 of PKG-INFO are the long description: a verbatim copy of README.md, reproduced in full in the next file.]
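The metadata above declares two optional extras, `explainability` and `preprocess`. A quick sketch of installing the published package with those extras enabled, assuming it is installable from the registry this diff was taken from:

```bash
# Base install of this release
pip install torchtextclassifiers==0.0.1

# With optional extras (quotes keep the shell from globbing the brackets)
pip install "torchtextclassifiers[preprocess]==0.0.1"
pip install "torchtextclassifiers[explainability,preprocess]==0.0.1"
```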
@@ -0,0 +1,165 @@
+# torchTextClassifiers
+
+A unified, extensible framework for text classification built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
+## 🚀 Features
+
+- **Unified API**: Consistent interface for different classifier wrappers
+- **Extensible**: Easy to add new classifier implementations through a wrapper pattern
+- **FastText Support**: Built-in FastText classifier with n-gram tokenization
+- **Flexible Preprocessing**: Each classifier can implement its own text preprocessing approach
+- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
+
+
+## 📦 Installation
+
+```bash
+# Clone the repository
+git clone https://github.com/InseeFrLab/torchTextClassifiers.git
+cd torchTextClassifiers
+
+# Install with uv (recommended)
+uv sync
+
+# Or install with pip
+pip install -e .
+```
+
+## 🎯 Quick Start
+
+### Basic FastText Classification
+
+```python
+import numpy as np
+from torchTextClassifiers import create_fasttext
+
+# Create a FastText classifier
+classifier = create_fasttext(
+    embedding_dim=100,
+    sparse=False,
+    num_tokens=10000,
+    min_count=2,
+    min_n=3,
+    max_n=6,
+    len_word_ngrams=2,
+    num_classes=2
+)
+
+# Prepare your data
+X_train = np.array([
+    "This is a positive example",
+    "This is a negative example",
+    "Another positive case",
+    "Another negative case"
+])
+y_train = np.array([1, 0, 1, 0])
+
+X_val = np.array([
+    "Validation positive",
+    "Validation negative"
+])
+y_val = np.array([1, 0])
+
+# Build the model
+classifier.build(X_train, y_train)
+
+# Train the model
+classifier.train(
+    X_train, y_train, X_val, y_val,
+    num_epochs=50,
+    batch_size=32,
+    patience_train=5,
+    verbose=True
+)
+
+# Make predictions
+X_test = np.array(["This is a test sentence"])
+predictions = classifier.predict(X_test)
+print(f"Predictions: {predictions}")
+
+# Validate on a test set
+accuracy = classifier.validate(X_test, np.array([1]))
+print(f"Accuracy: {accuracy:.3f}")
+```
+
+### Custom Classifier Implementation
+
+```python
+import numpy as np
+from torchTextClassifiers import torchTextClassifiers
+from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextWrapper, SimpleTextConfig
+
+# Example: a TF-IDF based classifier (an alternative to n-gram tokenization)
+config = SimpleTextConfig(
+    hidden_dim=128,
+    num_classes=2,
+    max_features=5000,
+    learning_rate=1e-3,
+    dropout_rate=0.2
+)
+
+# Create a classifier with TF-IDF preprocessing
+wrapper = SimpleTextWrapper(config)
+classifier = torchTextClassifiers(wrapper)
+
+# Text data
+X_train = np.array(["Great product!", "Terrible service", "Love it!"])
+y_train = np.array([1, 0, 1])
+
+# Build and train
+classifier.build(X_train, y_train)
+# ... continue with training
+```
+
+
+### Training Customization
+
+```python
+# Custom PyTorch Lightning trainer parameters
+trainer_params = {
+    'accelerator': 'gpu',
+    'devices': 1,
+    'precision': 16,  # Mixed precision training
+    'gradient_clip_val': 1.0,
+}
+
+classifier.train(
+    X_train, y_train, X_val, y_val,
+    num_epochs=100,
+    batch_size=64,
+    patience_train=10,
+    trainer_params=trainer_params,
+    verbose=True
+)
+```
+
+## 🔬 Testing
+
+Run the test suite:
+
+```bash
+# Run all tests
+uv run pytest
+
+# Run with coverage
+uv run pytest --cov=torchTextClassifiers
+
+# Run a specific test file
+uv run pytest tests/test_torchTextClassifiers.py -v
+```
+
+
+## 📚 Examples
+
+See the [examples/](examples/) directory for:
+- Basic text classification
+- Multi-class classification
+- Mixed features (text + categorical)
+- Custom classifier implementation
+- Advanced training configurations
+
+
+
+## 📄 License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
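The README's Custom Classifier snippet stops at `# ... continue with training`. A minimal sketch of finishing that flow, assuming the SimpleTextWrapper-backed classifier exposes the same `train`/`predict` signatures shown in the FastText Quick Start (the validation split below is illustrative, not from the package):

```python
import numpy as np
from torchTextClassifiers import torchTextClassifiers
from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextWrapper, SimpleTextConfig

config = SimpleTextConfig(hidden_dim=128, num_classes=2, max_features=5000,
                          learning_rate=1e-3, dropout_rate=0.2)
classifier = torchTextClassifiers(SimpleTextWrapper(config))

X_train = np.array(["Great product!", "Terrible service", "Love it!"])
y_train = np.array([1, 0, 1])
X_val = np.array(["Awful experience", "Fantastic quality"])  # hypothetical validation data
y_val = np.array([0, 1])

classifier.build(X_train, y_train)
# Assumed to mirror the train() signature from the Quick Start above
classifier.train(X_train, y_train, X_val, y_val,
                 num_epochs=20, batch_size=2, patience_train=3, verbose=True)
print(classifier.predict(np.array(["Really enjoyed using it"])))
```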
@@ -0,0 +1,63 @@
+[project]
+name = "torchtextclassifiers"
+description = "An implementation of the https://github.com/facebookresearch/fastText supervised learning algorithm for text classification using PyTorch."
+authors = [
+    { name = "Tom Seimandi", email = "tom.seimandi@gmail.com" },
+    { name = "Julien Pramil", email = "julien.pramil@insee.fr" },
+    { name = "Meilame Tayebjee", email = "meilame.tayebjee@insee.fr" },
+    { name = "Cédric Couralet", email = "cedric.couralet@insee.fr" },
+]
+readme = "README.md"
+repository = "https://github.com/InseeFrLab/torchTextClassifiers"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+]
+keywords = ["fastText", "text classification", "NLP", "automatic coding", "deep learning"]
+dependencies = [
+    "numpy>=1.26.4",
+    "pytorch-lightning>=2.4.0",
+]
+requires-python = ">=3.11"
+version = "0.0.1"
+
+
+[dependency-groups]
+dev = [
+    "pytest >=8.1.1,<9",
+    "pandas",
+    "scikit-learn",
+    "nltk",
+    "unidecode",
+    "captum",
+    "pyarrow",
+]
+docs = [
+    "sphinx>=5.0.0",
+    "sphinx-rtd-theme>=1.2.0",
+    "sphinx-autodoc-typehints>=1.19.0",
+    "sphinxcontrib-napoleon>=0.7",
+    "sphinx-copybutton>=0.5.0",
+    "myst-parser>=0.18.0",
+    "sphinx-design>=0.3.0",
+]
+
+[project.optional-dependencies]
+explainability = ["unidecode", "nltk", "captum"]
+preprocess = ["unidecode", "nltk"]
+
+[build-system]
+requires = ["uv_build>=0.8.3,<0.9.0"]
+build-backend = "uv_build"
+
+[tool.ruff]
+line-length = 100
+
+
+[tool.uv.build-backend]
+module-name = "torchTextClassifiers"
+module-root = ""
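The pyproject declares `dev` and `docs` dependency groups alongside the two optional extras. A sketch of the matching uv invocations, using uv's `--group` and `--extra` flags:

```bash
# Sync the project with the dev dependency group
uv sync --group dev

# Add the docs group as well
uv sync --group dev --group docs

# Include an optional extra from [project.optional-dependencies]
uv sync --extra preprocess
```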
@@ -0,0 +1,68 @@
+"""torchTextClassifiers: A unified framework for text classification.
+
+This package provides a generic, extensible framework for building and training
+different types of text classifiers. It currently supports FastText classifiers
+with a clean API for building, training, and inference.
+
+Key Features:
+- Unified API for different classifier types
+- Built-in support for FastText classifiers
+- PyTorch Lightning integration for training
+- Extensible architecture for adding new classifier types
+- Support for both text-only and mixed text/categorical features
+
+Quick Start:
+    >>> from torchTextClassifiers import create_fasttext
+    >>> import numpy as np
+    >>>
+    >>> # Create classifier
+    >>> classifier = create_fasttext(
+    ...     embedding_dim=100,
+    ...     sparse=False,
+    ...     num_tokens=10000,
+    ...     min_count=2,
+    ...     min_n=3,
+    ...     max_n=6,
+    ...     len_word_ngrams=2,
+    ...     num_classes=2
+    ... )
+    >>>
+    >>> # Prepare data
+    >>> X_train = np.array(["positive text", "negative text"])
+    >>> y_train = np.array([1, 0])
+    >>> X_val = np.array(["validation text"])
+    >>> y_val = np.array([1])
+    >>>
+    >>> # Build and train
+    >>> classifier.build(X_train, y_train)
+    >>> classifier.train(X_train, y_train, X_val, y_val, num_epochs=10, batch_size=32)
+    >>>
+    >>> # Predict
+    >>> predictions = classifier.predict(np.array(["new text sample"]))
+"""
+
+from .torchTextClassifiers import torchTextClassifiers
+
+# Convenience imports for FastText
+try:
+    from .classifiers.fasttext.core import FastTextFactory
+
+    # Expose FastText convenience methods at package level for easy access
+    create_fasttext = FastTextFactory.create_fasttext
+    build_fasttext_from_tokenizer = FastTextFactory.build_from_tokenizer
+
+except ImportError:
+    # FastText module not available - define placeholder functions
+    def create_fasttext(*args, **kwargs):
+        raise ImportError("FastText module not available")
+
+    def build_fasttext_from_tokenizer(*args, **kwargs):
+        raise ImportError("FastText module not available")
+
+__all__ = [
+    "torchTextClassifiers",
+    "create_fasttext",
+    "build_fasttext_from_tokenizer",
+]
+
+__version__ = "0.0.1"
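The try/except above lets the package import cleanly even when the FastText stack is unavailable, deferring the failure to first use. A small sketch of what a caller sees in that case:

```python
from torchTextClassifiers import create_fasttext

try:
    classifier = create_fasttext(
        embedding_dim=100, sparse=False, num_tokens=10000,
        min_count=2, min_n=3, max_n=6, len_word_ngrams=2, num_classes=2,
    )
except ImportError as err:
    # If .classifiers.fasttext.core failed to import, create_fasttext is the
    # placeholder defined above and raises only when actually called.
    print(f"FastText backend unavailable: {err}")  # "FastText module not available"
```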
@@ -0,0 +1,83 @@
+from typing import Optional, Union, Type, List, Dict, Any
+from dataclasses import dataclass, field, asdict
+from abc import ABC, abstractmethod
+import numpy as np
+
+class BaseClassifierConfig(ABC):
+    """Abstract base class for classifier configurations."""
+
+    @abstractmethod
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert configuration to dictionary."""
+        pass
+
+    @classmethod
+    @abstractmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "BaseClassifierConfig":
+        """Create configuration from dictionary."""
+        pass
+
+class BaseClassifierWrapper(ABC):
+    """Abstract base class for classifier wrappers.
+
+    Each classifier wrapper is responsible for its own text processing approach.
+    Some may use tokenizers, others may use different preprocessing methods.
+    """
+
+    def __init__(self, config: BaseClassifierConfig):
+        self.config = config
+        self.pytorch_model = None
+        self.lightning_module = None
+        self.trained: bool = False
+        self.device = None
+        # No tokenizer attribute here: text processing is wrapper-specific
+
+    @abstractmethod
+    def prepare_text_features(self, training_text: np.ndarray) -> None:
+        """Prepare text features for the classifier.
+
+        This could involve tokenization, vectorization, or other preprocessing.
+        Each classifier wrapper implements this according to its needs.
+        """
+        pass
+
+    @abstractmethod
+    def _build_pytorch_model(self) -> None:
+        """Build the PyTorch model."""
+        pass
+
+    @abstractmethod
+    def _check_and_init_lightning(self, **kwargs) -> None:
+        """Initialize Lightning module."""
+        pass
+
+    @abstractmethod
+    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
+        """Make predictions."""
+        pass
+
+    @abstractmethod
+    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        """Validate the model."""
+        pass
+
+    @abstractmethod
+    def create_dataset(self, texts: np.ndarray, labels: np.ndarray, categorical_variables: Optional[np.ndarray] = None):
+        """Create dataset for training/validation."""
+        pass
+
+    @abstractmethod
+    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
+        """Create dataloader from dataset."""
+        pass
+
+    @abstractmethod
+    def load_best_model(self, checkpoint_path: str) -> None:
+        """Load best model from checkpoint."""
+        pass
+
+    @classmethod
+    @abstractmethod
+    def get_config_class(cls) -> Type[BaseClassifierConfig]:
+        """Return the configuration class for this wrapper."""
+        pass
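This abstract base is the extension point the README's "Extensible" bullet refers to: a new classifier is a config class plus a wrapper that fills in these abstract methods. A minimal sketch with illustrative stub bodies (`MyConfig` and `MyWrapper` are hypothetical names, not part of the package):

```python
from typing import Any, Dict, Type

import numpy as np

from torchTextClassifiers.classifiers.base import BaseClassifierConfig, BaseClassifierWrapper


class MyConfig(BaseClassifierConfig):
    """Hypothetical config: a single vocabulary-size knob."""

    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size

    def to_dict(self) -> Dict[str, Any]:
        return {"vocab_size": self.vocab_size}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MyConfig":
        return cls(**data)


class MyWrapper(BaseClassifierWrapper):
    """Hypothetical wrapper: every abstract method stubbed out."""

    def prepare_text_features(self, training_text: np.ndarray) -> None:
        ...  # e.g. fit a vectorizer or tokenizer on training_text

    def _build_pytorch_model(self) -> None:
        ...  # set self.pytorch_model

    def _check_and_init_lightning(self, **kwargs) -> None:
        ...  # set self.lightning_module

    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        return np.zeros(len(X), dtype=int)  # placeholder predictions

    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        return float((self.predict(X) == Y).mean())

    def create_dataset(self, texts, labels, categorical_variables=None):
        ...  # return a torch Dataset over texts/labels

    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
        ...  # return a torch DataLoader over the dataset

    def load_best_model(self, checkpoint_path: str) -> None:
        ...  # restore weights from checkpoint_path

    @classmethod
    def get_config_class(cls) -> Type[BaseClassifierConfig]:
        return MyConfig
```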
@@ -0,0 +1,25 @@
+"""FastText classifier package.
+
+Provides FastText text classification with PyTorch Lightning integration.
+This folder contains 4 main files:
+- core.py: Configuration, losses, and factory methods
+- tokenizer.py: NGramTokenizer implementation
+- model.py: PyTorch model, Lightning module, and dataset
+- wrapper.py: High-level wrapper interface
+"""
+
+from .core import FastTextConfig, OneVsAllLoss, FastTextFactory
+from .tokenizer import NGramTokenizer
+from .model import FastTextModel, FastTextModule, FastTextModelDataset
+from .wrapper import FastTextWrapper
+
+__all__ = [
+    "FastTextConfig",
+    "OneVsAllLoss",
+    "FastTextFactory",
+    "NGramTokenizer",
+    "FastTextModel",
+    "FastTextModule",
+    "FastTextModelDataset",
+    "FastTextWrapper",
+]
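For callers that want the lower-level pieces rather than the package-level `create_fasttext` helper, the subpackage re-exports them directly. A sketch of that import path; whether `FastTextConfig` accepts the same keyword arguments as `create_fasttext` is an assumption, since core.py is not shown here:

```python
from torchTextClassifiers.classifiers.fasttext import (
    FastTextConfig,
    FastTextWrapper,
    NGramTokenizer,
)

# Assumption: FastTextConfig mirrors the create_fasttext() keyword arguments
config = FastTextConfig(
    embedding_dim=100, sparse=False, num_tokens=10000,
    min_count=2, min_n=3, max_n=6, len_word_ngrams=2, num_classes=2,
)
wrapper = FastTextWrapper(config)  # same wrapper pattern as SimpleTextWrapper in the README
```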