torchtextclassifiers 0.0.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
@@ -0,0 +1,191 @@
+"""
+Simple text classifier example that doesn't require a tokenizer.
+
+This demonstrates how to create a classifier wrapper that uses
+different text preprocessing approaches.
+"""
+
+from typing import Optional, Dict, Any
+from dataclasses import dataclass, asdict
+import numpy as np
+import torch
+import torch.nn as nn
+from sklearn.feature_extraction.text import TfidfVectorizer
+from torch.utils.data import Dataset, DataLoader
+import pytorch_lightning as pl
+from torch.optim import Adam
+
+from .base import BaseClassifierWrapper, BaseClassifierConfig
+
+
+@dataclass
+class SimpleTextConfig(BaseClassifierConfig):
+    """Configuration for simple text classifier using TF-IDF."""
+
+    hidden_dim: int = 128
+    num_classes: Optional[int] = None
+    max_features: int = 10000
+    learning_rate: float = 1e-3
+    dropout_rate: float = 0.1
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SimpleTextConfig":
+        return cls(**data)
+
+
+class SimpleTextDataset(Dataset):
+    """Dataset for simple text classifier."""
+
+    def __init__(self, features: np.ndarray, labels: np.ndarray):
+        self.features = torch.FloatTensor(features)
+        self.labels = torch.LongTensor(labels)
+
+    def __len__(self):
+        return len(self.features)
+
+    def __getitem__(self, idx):
+        return self.features[idx], self.labels[idx]
+
+
+class SimpleTextModel(nn.Module):
+    """Simple neural network for text classification using TF-IDF features."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout_rate: float = 0.1):
+        super().__init__()
+
+        self.network = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.ReLU(),
+            nn.Dropout(dropout_rate),
+            nn.Linear(hidden_dim // 2, num_classes)
+        )
+
+    def forward(self, x):
+        return self.network(x)
+
+
+class SimpleTextModule(pl.LightningModule):
+    """Lightning module for simple text classifier."""
+
+    def __init__(self, model: nn.Module, learning_rate: float = 1e-3):
+        super().__init__()
+        self.model = model
+        self.learning_rate = learning_rate
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        features, labels = batch
+        logits = self(features)
+        loss = self.loss_fn(logits, labels)
+        self.log('train_loss', loss)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        features, labels = batch
+        logits = self(features)
+        loss = self.loss_fn(logits, labels)
+        self.log('val_loss', loss)
+        return loss
+
+    def configure_optimizers(self):
+        return Adam(self.parameters(), lr=self.learning_rate)
+
+
+class SimpleTextWrapper(BaseClassifierWrapper):
+    """Wrapper for a simple text classifier that uses TF-IDF instead of tokenization."""
+
+    def __init__(self, config: SimpleTextConfig):
+        super().__init__(config)
+        self.config: SimpleTextConfig = config
+        self.vectorizer: Optional[TfidfVectorizer] = None
+
+    def prepare_text_features(self, training_text: np.ndarray) -> None:
+        """Prepare a TF-IDF vectorizer instead of a tokenizer."""
+        self.vectorizer = TfidfVectorizer(
+            max_features=self.config.max_features,
+            lowercase=True,
+            stop_words='english'
+        )
+        # Fit the vectorizer on the training text
+        self.vectorizer.fit(training_text)
+
+    def _build_pytorch_model(self) -> None:
+        """Build the PyTorch model."""
+        if self.vectorizer is None:
+            raise ValueError("Must call prepare_text_features first")
+
+        input_dim = len(self.vectorizer.get_feature_names_out())
+
+        self.pytorch_model = SimpleTextModel(
+            input_dim=input_dim,
+            hidden_dim=self.config.hidden_dim,
+            num_classes=self.config.num_classes,
+            dropout_rate=self.config.dropout_rate
+        )
+
+    def _check_and_init_lightning(self, **kwargs) -> None:
+        """Initialize Lightning module."""
+        self.lightning_module = SimpleTextModule(
+            model=self.pytorch_model,
+            learning_rate=self.config.learning_rate
+        )
+
+    def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
+        """Make predictions."""
+        if not self.trained:
+            raise RuntimeError("Model must be trained first.")
+
+        # Extract text from X (assuming the first column is text)
+        text_data = X[:, 0] if X.ndim > 1 else X
+
+        # Transform text to TF-IDF features
+        features = self.vectorizer.transform(text_data).toarray()
+        features_tensor = torch.FloatTensor(features)
+
+        self.pytorch_model.eval()
+        with torch.no_grad():
+            logits = self.pytorch_model(features_tensor)
+            predictions = torch.argmax(logits, dim=1)
+
+        return predictions.numpy()
+
+    def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
+        """Validate the model; returns accuracy."""
+        predictions = self.predict(X)
+        accuracy = (predictions == Y).mean()
+        return float(accuracy)
+
+    def create_dataset(self, texts: np.ndarray, labels: np.ndarray, categorical_variables: Optional[np.ndarray] = None):
+        """Create a dataset. `categorical_variables` is accepted for interface compatibility but unused here."""
+        # Transform text to TF-IDF features
+        features = self.vectorizer.transform(texts).toarray()
+        return SimpleTextDataset(features, labels)
+
+    def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
+        """Create a dataloader."""
+        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
+
+    def load_best_model(self, checkpoint_path: str) -> None:
+        """Load best model from checkpoint."""
+        self.lightning_module = SimpleTextModule.load_from_checkpoint(
+            checkpoint_path,
+            model=self.pytorch_model,
+            learning_rate=self.config.learning_rate
+        )
+        self.pytorch_model = self.lightning_module.model
+        self.trained = True
+        self.pytorch_model.eval()
+
+    @classmethod
+    def get_config_class(cls):
+        """Return the configuration class."""
+        return SimpleTextConfig
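
For context, here is a minimal usage sketch of the wrapper above. It is illustrative only: the import path is an assumption about the package layout, and the actual training loop lives in BaseClassifierWrapper, which is not part of this diff.

    import numpy as np
    # Hypothetical import path; the real module location inside the package may differ.
    from torchtextclassifiers.classifiers.simple_text import SimpleTextConfig, SimpleTextWrapper

    texts = np.array(["great product", "terrible service", "works fine"])
    labels = np.array([1, 0, 1])

    wrapper = SimpleTextWrapper(SimpleTextConfig(num_classes=2))
    wrapper.prepare_text_features(texts)  # fits the TF-IDF vectorizer on the training text

    # Dataset/dataloader creation as exposed by the wrapper; the base class's
    # training loop (not shown in this diff) would consume these.
    dataset = wrapper.create_dataset(texts, labels)
    loader = wrapper.create_dataloader(dataset, batch_size=2)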
@@ -0,0 +1,34 @@
+"""Generic factories for different classifier types."""
+
+from typing import Dict, Callable
+from .classifiers.base import BaseClassifierConfig
+
+# Registry of config factories for different classifier types
+CONFIG_FACTORIES: Dict[str, Callable[[dict], BaseClassifierConfig]] = {}
+
+
+def register_config_factory(classifier_type: str, factory_func: Callable[[dict], BaseClassifierConfig]):
+    """Register a config factory for a classifier type."""
+    CONFIG_FACTORIES[classifier_type] = factory_func
+
+
+def create_config_from_dict(classifier_type: str, config_dict: dict) -> BaseClassifierConfig:
+    """Create a config object from dictionary based on classifier type."""
+    if classifier_type not in CONFIG_FACTORIES:
+        raise ValueError(f"Unsupported classifier type: {classifier_type}")
+
+    return CONFIG_FACTORIES[classifier_type](config_dict)
+
+
+# Register FastText factory
+def _register_fasttext_factory():
+    """Register FastText config factory."""
+    try:
+        from .classifiers.fasttext.core import FastTextFactory
+        register_config_factory("fasttext", FastTextFactory.from_dict)
+    except ImportError:
+        pass  # FastText module not available
+
+
+# Auto-register available factories
+_register_fasttext_factory()
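
To illustrate the registry pattern, a hedged sketch of registering and using a custom factory (the "simple_text" key and the SimpleTextConfig wiring are hypothetical, not something the package registers itself, and the import paths are assumptions):

    from torchtextclassifiers.factories import register_config_factory, create_config_from_dict  # assumed path
    from torchtextclassifiers.classifiers.simple_text import SimpleTextConfig  # hypothetical module path

    # Any callable mapping a plain dict to a BaseClassifierConfig can be registered.
    register_config_factory("simple_text", SimpleTextConfig.from_dict)

    config = create_config_from_dict("simple_text", {"hidden_dim": 64, "num_classes": 3})
    assert config.num_classes == 3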