torchtextclassifiers-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +68 -0
- torchTextClassifiers/classifiers/base.py +83 -0
- torchTextClassifiers/classifiers/fasttext/__init__.py +25 -0
- torchTextClassifiers/classifiers/fasttext/core.py +269 -0
- torchTextClassifiers/classifiers/fasttext/model.py +752 -0
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +346 -0
- torchTextClassifiers/classifiers/fasttext/wrapper.py +216 -0
- torchTextClassifiers/classifiers/simple_text_classifier.py +191 -0
- torchTextClassifiers/factories.py +34 -0
- torchTextClassifiers/torchTextClassifiers.py +509 -0
- torchTextClassifiers/utilities/__init__.py +3 -0
- torchTextClassifiers/utilities/checkers.py +108 -0
- torchTextClassifiers/utilities/preprocess.py +82 -0
- torchTextClassifiers/utilities/utils.py +346 -0
- torchtextclassifiers-0.0.1.dist-info/METADATA +187 -0
- torchtextclassifiers-0.0.1.dist-info/RECORD +17 -0
- torchtextclassifiers-0.0.1.dist-info/WHEEL +4 -0
torchTextClassifiers/__init__.py
@@ -0,0 +1,68 @@
"""torchTextClassifiers: A unified framework for text classification.

This package provides a generic, extensible framework for building and training
different types of text classifiers. It currently supports FastText classifiers
with a clean API for building, training, and inference.

Key Features:
- Unified API for different classifier types
- Built-in support for FastText classifiers
- PyTorch Lightning integration for training
- Extensible architecture for adding new classifier types
- Support for both text-only and mixed text/categorical features

Quick Start:
    >>> from torchTextClassifiers import create_fasttext
    >>> import numpy as np
    >>>
    >>> # Create classifier
    >>> classifier = create_fasttext(
    ...     embedding_dim=100,
    ...     sparse=False,
    ...     num_tokens=10000,
    ...     min_count=2,
    ...     min_n=3,
    ...     max_n=6,
    ...     len_word_ngrams=2,
    ...     num_classes=2
    ... )
    >>>
    >>> # Prepare data
    >>> X_train = np.array(["positive text", "negative text"])
    >>> y_train = np.array([1, 0])
    >>> X_val = np.array(["validation text"])
    >>> y_val = np.array([1])
    >>>
    >>> # Build and train
    >>> classifier.build(X_train, y_train)
    >>> classifier.train(X_train, y_train, X_val, y_val, num_epochs=10, batch_size=32)
    >>>
    >>> # Predict
    >>> predictions = classifier.predict(np.array(["new text sample"]))
"""

from .torchTextClassifiers import torchTextClassifiers

# Convenience imports for FastText
try:
    from .classifiers.fasttext.core import FastTextFactory

    # Expose FastText convenience methods at package level for easy access
    create_fasttext = FastTextFactory.create_fasttext
    build_fasttext_from_tokenizer = FastTextFactory.build_from_tokenizer

except ImportError:
    # FastText module not available - define placeholder functions
    def create_fasttext(*args, **kwargs):
        raise ImportError("FastText module not available")

    def build_fasttext_from_tokenizer(*args, **kwargs):
        raise ImportError("FastText module not available")

__all__ = [
    "torchTextClassifiers",
    "create_fasttext",
    "build_fasttext_from_tokenizer",
]

__version__ = "1.0.0"
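The try/except above only aliases the factory's static methods, so the package-level helpers and FastTextFactory are interchangeable. A minimal sanity check of that wiring (not part of the package; it assumes the wheel and its torch/lightning dependencies are installed):

from torchTextClassifiers import create_fasttext, build_fasttext_from_tokenizer
from torchTextClassifiers.classifiers.fasttext.core import FastTextFactory

# With the FastText subpackage importable, the package-level names are plain
# aliases of the factory's static methods.
assert create_fasttext is FastTextFactory.create_fasttext
assert build_fasttext_from_tokenizer is FastTextFactory.build_from_tokenizer

# Had that import failed, the placeholder functions defined in the
# except-branch would instead raise ImportError("FastText module not available").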
torchTextClassifiers/classifiers/base.py
@@ -0,0 +1,83 @@
from typing import Optional, Union, Type, List, Dict, Any
from dataclasses import dataclass, field, asdict
from abc import ABC, abstractmethod
import numpy as np

class BaseClassifierConfig(ABC):
    """Abstract base class for classifier configurations."""

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Convert configuration to dictionary."""
        pass

    @classmethod
    @abstractmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BaseClassifierConfig":
        """Create configuration from dictionary."""
        pass

class BaseClassifierWrapper(ABC):
    """Abstract base class for classifier wrappers.

    Each classifier wrapper is responsible for its own text processing approach.
    Some may use tokenizers, others may use different preprocessing methods.
    """

    def __init__(self, config: BaseClassifierConfig):
        self.config = config
        self.pytorch_model = None
        self.lightning_module = None
        self.trained: bool = False
        self.device = None
        # Remove tokenizer from base class - it's now wrapper-specific

    @abstractmethod
    def prepare_text_features(self, training_text: np.ndarray) -> None:
        """Prepare text features for the classifier.

        This could involve tokenization, vectorization, or other preprocessing.
        Each classifier wrapper implements this according to its needs.
        """
        pass

    @abstractmethod
    def _build_pytorch_model(self) -> None:
        """Build the PyTorch model."""
        pass

    @abstractmethod
    def _check_and_init_lightning(self, **kwargs) -> None:
        """Initialize Lightning module."""
        pass

    @abstractmethod
    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Make predictions."""
        pass

    @abstractmethod
    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
        """Validate the model."""
        pass

    @abstractmethod
    def create_dataset(self, texts: np.ndarray, labels: np.ndarray, categorical_variables: Optional[np.ndarray] = None):
        """Create dataset for training/validation."""
        pass

    @abstractmethod
    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
        """Create dataloader from dataset."""
        pass

    @abstractmethod
    def load_best_model(self, checkpoint_path: str) -> None:
        """Load best model from checkpoint."""
        pass

    @classmethod
    @abstractmethod
    def get_config_class(cls) -> Type[BaseClassifierConfig]:
        """Return the configuration class for this wrapper."""
        pass
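For orientation, here is a minimal sketch of what a new backend would have to provide against the abstract interface above. The MyConfig/MyWrapper names and the stubbed bodies are invented for illustration; only the method signatures come from base.py:

from dataclasses import dataclass, asdict
from typing import Any, Dict, Type

import numpy as np

from torchTextClassifiers.classifiers.base import (
    BaseClassifierConfig,
    BaseClassifierWrapper,
)


@dataclass
class MyConfig(BaseClassifierConfig):
    """Hypothetical configuration: a vocabulary size and a class count."""
    vocab_size: int
    num_classes: int

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MyConfig":
        return cls(**data)


class MyWrapper(BaseClassifierWrapper):
    """Hypothetical wrapper: every abstract method must be overridden."""

    def prepare_text_features(self, training_text: np.ndarray) -> None:
        # e.g. fit a tokenizer or vectorizer on the training text
        ...

    def _build_pytorch_model(self) -> None: ...
    def _check_and_init_lightning(self, **kwargs) -> None: ...
    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray: ...
    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: ...
    def create_dataset(self, texts, labels, categorical_variables=None): ...
    def create_dataloader(self, dataset, batch_size, num_workers=0, shuffle=True): ...
    def load_best_model(self, checkpoint_path: str) -> None: ...

    @classmethod
    def get_config_class(cls) -> Type[BaseClassifierConfig]:
        return MyConfig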
torchTextClassifiers/classifiers/fasttext/__init__.py
@@ -0,0 +1,25 @@
"""FastText classifier package.

Provides FastText text classification with PyTorch Lightning integration.
This folder contains 4 main files:
- core.py: Configuration, losses, and factory methods
- tokenizer.py: NGramTokenizer implementation
- model.py: PyTorch model, Lightning module, and dataset
- wrapper.py: High-level wrapper interface
"""

from .core import FastTextConfig, OneVsAllLoss, FastTextFactory
from .tokenizer import NGramTokenizer
from .model import FastTextModel, FastTextModule, FastTextModelDataset
from .wrapper import FastTextWrapper

__all__ = [
    "FastTextConfig",
    "OneVsAllLoss",
    "FastTextFactory",
    "NGramTokenizer",
    "FastTextModel",
    "FastTextModule",
    "FastTextModelDataset",
    "FastTextWrapper",
]
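Because this subpackage __init__ simply re-exports the symbols, either import path resolves to the same objects; for example (assuming the package is installed):

from torchTextClassifiers.classifiers.fasttext import FastTextConfig
from torchTextClassifiers.classifiers.fasttext.core import FastTextConfig as CoreFastTextConfig

# The re-export and the defining module hand back the identical class object.
assert FastTextConfig is CoreFastTextConfig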
torchTextClassifiers/classifiers/fasttext/core.py
@@ -0,0 +1,269 @@
"""FastText classifier core components.

This module contains the core components for FastText classification:
- Configuration dataclass
- Loss functions
- Factory methods for creating classifiers

Consolidates what was previously in config.py, losses.py, and factory.py.
"""

from dataclasses import dataclass, field, asdict
from abc import ABC, abstractmethod
from ..base import BaseClassifierConfig
from typing import Optional, List, TYPE_CHECKING, Union, Dict, Any
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

if TYPE_CHECKING:
    from ...torchTextClassifiers import torchTextClassifiers


# ============================================================================
# Configuration
# ============================================================================

@dataclass
class FastTextConfig(BaseClassifierConfig):
    """Configuration for FastText classifier."""
    # Embedding matrix
    embedding_dim: int
    sparse: bool

    # Tokenizer-related
    num_tokens: int
    min_count: int
    min_n: int
    max_n: int
    len_word_ngrams: int

    # Optional parameters
    num_classes: Optional[int] = None
    num_rows: Optional[int] = None

    # Categorical variables
    categorical_vocabulary_sizes: Optional[List[int]] = None
    categorical_embedding_dims: Optional[Union[List[int], int]] = None
    num_categorical_features: Optional[int] = None

    # Model-specific parameters
    direct_bagging: Optional[bool] = True

    # Training parameters
    learning_rate: float = 4e-3

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FastTextConfig":
        return cls(**data)


# ============================================================================
# Loss Functions
# ============================================================================

class OneVsAllLoss(nn.Module):
    def __init__(self):
        super(OneVsAllLoss, self).__init__()

    def forward(self, logits, targets):
        """
        Compute One-vs-All loss

        Args:
            logits: Tensor of shape (batch_size, num_classes) containing classification scores
            targets: Tensor of shape (batch_size) containing true class indices

        Returns:
            loss: Mean loss value across the batch
        """

        num_classes = logits.size(1)

        # Convert targets to one-hot encoding
        targets_one_hot = F.one_hot(targets, num_classes=num_classes).float()

        # For each sample, treat the true class as positive and all others as negative
        # Using binary cross entropy for each class
        loss = F.binary_cross_entropy_with_logits(
            logits,  # Raw logits
            targets_one_hot,  # Target probabilities
            reduction="none",  # Don't reduce yet to allow for custom weighting if needed
        )

        # Sum losses across all classes for each sample, then take mean across batch
        return loss.sum(dim=1).mean()


# ============================================================================
# Factory Methods
# ============================================================================

class FastTextFactory:
    """Factory class for creating FastText classifiers with convenience methods.

    This factory provides static methods for creating FastText classifiers with
    common configurations. It handles the complexities of configuration creation
    and classifier initialization, offering a simplified API for users.

    All methods return fully initialized torchTextClassifiers instances that are
    ready for building and training.
    """

    @staticmethod
    def create_fasttext(
        embedding_dim: int,
        sparse: bool,
        num_tokens: int,
        min_count: int,
        min_n: int,
        max_n: int,
        len_word_ngrams: int,
        **kwargs
    ) -> "torchTextClassifiers":
        """Create a FastText classifier with the specified configuration.

        This is the primary method for creating FastText classifiers. It creates
        a configuration object with the provided parameters and initializes a
        complete classifier instance.

        Args:
            embedding_dim: Dimension of word embeddings
            sparse: Whether to use sparse embeddings
            num_tokens: Maximum number of tokens in vocabulary
            min_count: Minimum count for tokens to be included in vocabulary
            min_n: Minimum length of character n-grams
            max_n: Maximum length of character n-grams
            len_word_ngrams: Length of word n-grams to use
            **kwargs: Additional configuration parameters (e.g., num_classes,
                categorical_vocabulary_sizes, etc.)

        Returns:
            torchTextClassifiers: Initialized FastText classifier instance

        Example:
            >>> from torchTextClassifiers.classifiers.fasttext.core import FastTextFactory
            >>> classifier = FastTextFactory.create_fasttext(
            ...     embedding_dim=100,
            ...     sparse=False,
            ...     num_tokens=10000,
            ...     min_count=2,
            ...     min_n=3,
            ...     max_n=6,
            ...     len_word_ngrams=2,
            ...     num_classes=3
            ... )
        """
        from ...torchTextClassifiers import torchTextClassifiers
        from .wrapper import FastTextWrapper

        config = FastTextConfig(
            embedding_dim=embedding_dim,
            sparse=sparse,
            num_tokens=num_tokens,
            min_count=min_count,
            min_n=min_n,
            max_n=max_n,
            len_word_ngrams=len_word_ngrams,
            **kwargs
        )
        wrapper = FastTextWrapper(config)
        return torchTextClassifiers(wrapper)

    @staticmethod
    def build_from_tokenizer(
        tokenizer,  # NGramTokenizer
        embedding_dim: int,
        num_classes: Optional[int],
        categorical_vocabulary_sizes: Optional[List[int]] = None,
        sparse: bool = False,
        **kwargs
    ) -> "torchTextClassifiers":
        """Create FastText classifier from an existing trained tokenizer.

        This method is useful when you have a pre-trained tokenizer and want to
        create a classifier that uses the same vocabulary and tokenization scheme.
        The resulting classifier will have its tokenizer and model architecture
        pre-built.

        Args:
            tokenizer: Pre-trained NGramTokenizer instance
            embedding_dim: Dimension of word embeddings
            num_classes: Number of output classes
            categorical_vocabulary_sizes: Sizes of categorical feature vocabularies
            sparse: Whether to use sparse embeddings
            **kwargs: Additional configuration parameters

        Returns:
            torchTextClassifiers: Classifier with pre-built tokenizer and model

        Raises:
            ValueError: If the tokenizer is missing required attributes

        Example:
            >>> # Assume you have a pre-trained tokenizer
            >>> classifier = FastTextFactory.build_from_tokenizer(
            ...     tokenizer=my_tokenizer,
            ...     embedding_dim=100,
            ...     num_classes=2,
            ...     sparse=False
            ... )
            >>> # The classifier is ready for training without building
            >>> classifier.train(X_train, y_train, X_val, y_val, ...)
        """
        from ...torchTextClassifiers import torchTextClassifiers
        from .wrapper import FastTextWrapper

        # Ensure the tokenizer has required attributes
        required_attrs = ["min_count", "min_n", "max_n", "num_tokens", "word_ngrams"]
        if not all(hasattr(tokenizer, attr) for attr in required_attrs):
            missing_attrs = [attr for attr in required_attrs if not hasattr(tokenizer, attr)]
            raise ValueError(f"Missing attributes in tokenizer: {missing_attrs}")

        config = FastTextConfig(
            num_tokens=tokenizer.num_tokens,
            embedding_dim=embedding_dim,
            min_count=tokenizer.min_count,
            min_n=tokenizer.min_n,
            max_n=tokenizer.max_n,
            len_word_ngrams=tokenizer.word_ngrams,
            sparse=sparse,
            num_classes=num_classes,
            categorical_vocabulary_sizes=categorical_vocabulary_sizes,
            **kwargs
        )

        wrapper = FastTextWrapper(config)
        classifier = torchTextClassifiers(wrapper)
        classifier.classifier.tokenizer = tokenizer
        classifier.classifier._build_pytorch_model()

        return classifier

    @staticmethod
    def from_dict(config_dict: dict) -> FastTextConfig:
        """Create FastText configuration from dictionary.

        This method is used internally by the configuration factory system
        to recreate FastText configurations from serialized data.

        Args:
            config_dict: Dictionary containing configuration parameters

        Returns:
            FastTextConfig: Reconstructed configuration object

        Example:
            >>> config_dict = {
            ...     'embedding_dim': 100,
            ...     'num_tokens': 5000,
            ...     'min_count': 1,
            ...     # ... other parameters
            ... }
            >>> config = FastTextFactory.from_dict(config_dict)
        """
        return FastTextConfig.from_dict(config_dict)