torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. torchTextClassifiers/__init__.py +12 -48
  2. torchTextClassifiers/dataset/__init__.py +1 -0
  3. torchTextClassifiers/dataset/dataset.py +114 -0
  4. torchTextClassifiers/model/__init__.py +2 -0
  5. torchTextClassifiers/model/components/__init__.py +12 -0
  6. torchTextClassifiers/model/components/attention.py +126 -0
  7. torchTextClassifiers/model/components/categorical_var_net.py +128 -0
  8. torchTextClassifiers/model/components/classification_head.py +43 -0
  9. torchTextClassifiers/model/components/text_embedder.py +220 -0
  10. torchTextClassifiers/model/lightning.py +166 -0
  11. torchTextClassifiers/model/model.py +151 -0
  12. torchTextClassifiers/tokenizers/WordPiece.py +92 -0
  13. torchTextClassifiers/tokenizers/__init__.py +10 -0
  14. torchTextClassifiers/tokenizers/base.py +205 -0
  15. torchTextClassifiers/tokenizers/ngram.py +472 -0
  16. torchTextClassifiers/torchTextClassifiers.py +463 -405
  17. torchTextClassifiers/utilities/__init__.py +0 -3
  18. torchTextClassifiers/utilities/plot_explainability.py +184 -0
  19. torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
  20. torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
  21. {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
  22. torchTextClassifiers/classifiers/base.py +0 -83
  23. torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
  24. torchTextClassifiers/classifiers/fasttext/core.py +0 -269
  25. torchTextClassifiers/classifiers/fasttext/model.py +0 -752
  26. torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
  27. torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
  28. torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
  29. torchTextClassifiers/factories.py +0 -34
  30. torchTextClassifiers/utilities/checkers.py +0 -108
  31. torchTextClassifiers/utilities/preprocess.py +0 -82
  32. torchTextClassifiers/utilities/utils.py +0 -346
  33. torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
  34. torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
torchTextClassifiers/utilities/__init__.py
@@ -1,3 +0,0 @@
- """
- Init script.
- """
torchTextClassifiers/utilities/plot_explainability.py
@@ -0,0 +1,184 @@
+ from typing import List, Optional
+
+ import numpy as np
+ import torch
+
+ try:
+     from matplotlib import pyplot as plt
+
+     HAS_PYPLOT = True
+ except ImportError:
+     HAS_PYPLOT = False
+
+
+ def map_attributions_to_char(attributions, offsets, text):
+     """
+     Maps token-level attributions to character-level attributions based on token offsets.
+     Args:
+         attributions (np.ndarray): Array of shape (top_k, seq_len) or (seq_len,) containing token-level attributions.
+             Output from:
+             >>> ttc.predict(X, top_k=top_k, explain=True)["attributions"]
+         offsets (list of tuples): List of (start, end) offsets for each token in the original text.
+             Output from:
+             >>> ttc.predict(X, top_k=top_k, explain=True)["offset_mapping"]
+             Also from:
+             >>> ttc.tokenizer.tokenize(text, return_offsets_mapping=True)["offset_mapping"]
+         text (str): The original input text.
+
+     Returns:
+         np.ndarray: Array of shape (top_k, text_len) containing character-level attributions.
+             text_len is the number of characters in the original text.
+
+     """
+
+     if isinstance(text, list):
+         raise ValueError("text must be a single string, not a list of strings.")
+
+     assert isinstance(text, str), "text must be a string."
+
+     if isinstance(attributions, torch.Tensor):
+         attributions = attributions.cpu().numpy()
+
+     if attributions.ndim == 1:
+         attributions = attributions[None, :]
+
+     attributions_per_char = np.zeros((attributions.shape[0], len(text)))  # top_k, text_len
+
+     for token_idx, (start, end) in enumerate(offsets):
+         if start == end:  # skip special tokens
+             continue
+         attributions_per_char[:, start:end] = attributions[:, token_idx][:, None]
+
+     return np.exp(attributions_per_char) / np.sum(
+         np.exp(attributions_per_char), axis=1, keepdims=True
+     )  # softmax normalization
+
+
+ def map_attributions_to_word(attributions, word_ids):
+     """
+     Maps token-level attributions to word-level attributions based on word IDs.
+     Args:
+         attributions (np.ndarray): Array of shape (top_k, seq_len) or (seq_len,) containing token-level attributions.
+             Output from:
+             >>> ttc.predict(X, top_k=top_k, explain=True)["attributions"]
+         word_ids (list of int or None): List of word IDs for each token in the original text.
+             Output from:
+             >>> ttc.predict(X, top_k=top_k, explain=True)["word_ids"]
+
+     Returns:
+         np.ndarray: Array of shape (top_k, num_words) containing word-level attributions.
+             num_words is the number of unique words in the original text.
+     """
+
+     word_ids = np.array(word_ids)
+
+     # Convert None to -1 for easier processing (PAD tokens)
+     word_ids_int = np.array([x if x is not None else -1 for x in word_ids], dtype=int)
+
+     # Filter out PAD tokens from attributions and word_ids
+     attributions = attributions[
+         torch.arange(attributions.shape[0])[:, None],
+         torch.tensor(np.where(word_ids_int != -1)[0])[None, :],
+     ]
+     word_ids_int = word_ids_int[word_ids_int != -1]
+     unique_word_ids = np.unique(word_ids_int)
+     num_unique_words = len(unique_word_ids)
+
+     top_k = attributions.shape[0]
+     attr_with_word_id = np.concat(
+         (attributions[:, :, None], np.tile(word_ids_int[None, :], reps=(top_k, 1))[:, :, None]),
+         axis=-1,
+     )  # top_k, seq_len, 2
+     # last dim is 2: 0 is the attribution of the token, 1 is the word_id the token is associated to
+
+     word_attributions = np.zeros((top_k, num_unique_words))
+     for word_id in unique_word_ids:
+         mask = attr_with_word_id[:, :, 1] == word_id  # top_k, seq_len
+         word_attributions[:, word_id] = (attr_with_word_id[:, :, 0] * mask).sum(
+             axis=1
+         )  # zero-out non-matching tokens and sum attributions for all tokens belonging to the same word
+
+     # assert word_attributions.sum(axis=1) == attributions.sum(axis=1), "Sum of word attributions per top_k must equal sum of token attributions per top_k."
+     return np.exp(word_attributions) / np.sum(
+         np.exp(word_attributions), axis=1, keepdims=True
+     )  # softmax normalization
+
+
+ def plot_attributions_at_char(
+     text: str,
+     attributions_per_char: np.ndarray,
+     figsize=(10, 2),
+     titles: Optional[List[str]] = None,
+ ):
+     """
+     Plots character-level attributions as a bar chart, one figure per prediction.
+     Args:
+         text (str): The original input text.
+         attributions_per_char (np.ndarray): Array of shape (top_k, text_len) containing character-level attributions.
+             Output from map_attributions_to_char function.
+         titles (list of str, optional): Titles for the top_k plots.
+         figsize (tuple): Figure size for the plot.
+     """
+
+     if not HAS_PYPLOT:
+         raise ImportError(
+             "matplotlib is required for plotting. Please install it to use this function."
+         )
+     top_k = attributions_per_char.shape[0]
+
+     all_plots = []
+     for i in range(top_k):
+         fig, ax = plt.subplots(figsize=figsize)
+         ax.bar(range(len(text)), attributions_per_char[i])
+         ax.set_xticks(np.arange(len(text)))
+         ax.set_xticklabels(list(text), rotation=90)
+         title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
+         ax.set_title(title)
+         ax.set_xlabel("Characters in Text")
+         ax.set_ylabel("Attributions")
+         all_plots.append(fig)
+
+     return all_plots
+
+
+ def plot_attributions_at_word(
+     text, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
+ ):
+     """
+     Plots word-level attributions as a bar chart, one figure per prediction.
+     Args:
+         text (str): The original input text.
+         attributions_per_word (np.ndarray): Array of shape (top_k, num_words) containing word-level attributions.
+             Output from map_attributions_to_word function.
+         titles (list of str, optional): Titles for the top_k plots.
+         figsize (tuple): Figure size for the plot.
+     """
+
+     if not HAS_PYPLOT:
+         raise ImportError(
+             "matplotlib is required for plotting. Please install it to use this function."
+         )
+
+     words = text.split()
+     top_k = attributions_per_word.shape[0]
+     all_plots = []
+     for i in range(top_k):
+         fig, ax = plt.subplots(figsize=figsize)
+         ax.bar(range(len(words)), attributions_per_word[i])
+         ax.set_xticks(np.arange(len(words)))
+         ax.set_xticklabels(words, rotation=90)
+         title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
+         ax.set_title(title)
+         ax.set_xlabel("Words in Text")
+         ax.set_ylabel("Attributions")
+         all_plots.append(fig)
+
+     return all_plots
+
+
+ def figshow(figure):
+     # https://stackoverflow.com/questions/53088212/create-multiple-figures-in-pyplot-but-only-show-one
+     for i in plt.get_fignums():
+         if figure != plt.figure(i):
+             plt.close(plt.figure(i))
+     plt.show()
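For orientation, here is a minimal usage sketch of the new explainability helpers (not part of the diff). It assumes a trained `torchTextClassifiers` instance `ttc` whose `predict(..., explain=True)` returns the `attributions` and `offset_mapping` fields referenced in the docstrings above; the exact batching behaviour of `predict` is an assumption.

```python
# Hypothetical usage of the helpers added in plot_explainability.py.
# `ttc` (a trained classifier) and the shape of predict()'s output are assumptions.
from torchTextClassifiers.utilities.plot_explainability import (
    figshow,
    map_attributions_to_char,
    plot_attributions_at_char,
)

text = "the quick brown fox"
out = ttc.predict([text], top_k=2, explain=True)  # per the docstrings above

# Token-level attributions -> softmax-normalized character-level attributions
attributions_per_char = map_attributions_to_char(
    out["attributions"], out["offset_mapping"], text
)

figs = plot_attributions_at_char(text, attributions_per_char, titles=["Top 1", "Top 2"])
figshow(figs[0])  # display the first figure and close the rest
```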
torchtextclassifiers-0.1.0.dist-info/METADATA
@@ -0,0 +1,73 @@
+ Metadata-Version: 2.3
+ Name: torchtextclassifiers
+ Version: 0.1.0
+ Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
+ Keywords: fastText,text classification,NLP,automatic coding,deep learning
+ Author: Cédric Couralet, Meilame Tayebjee
+ Author-email: Cédric Couralet <cedric.couralet@insee.fr>, Meilame Tayebjee <meilame.tayebjee@insee.fr>
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Dist: numpy>=1.26.4
+ Requires-Dist: pytorch-lightning>=2.4.0
+ Requires-Dist: unidecode ; extra == 'explainability'
+ Requires-Dist: nltk ; extra == 'explainability'
+ Requires-Dist: captum ; extra == 'explainability'
+ Requires-Dist: tokenizers>=0.22.1 ; extra == 'huggingface'
+ Requires-Dist: transformers>=4.57.1 ; extra == 'huggingface'
+ Requires-Dist: datasets>=4.3.0 ; extra == 'huggingface'
+ Requires-Dist: unidecode ; extra == 'preprocess'
+ Requires-Dist: nltk ; extra == 'preprocess'
+ Requires-Python: >=3.11
+ Provides-Extra: explainability
+ Provides-Extra: huggingface
+ Provides-Extra: preprocess
+ Description-Content-Type: text/markdown
+
+ # torchTextClassifiers
+
+ A unified, extensible framework for text classification with categorical variables, built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
+ ## 🚀 Features
+
+ - **Mixed input support**: Handle text data alongside categorical variables seamlessly.
+ - **Unified yet highly customizable**:
+   - Use any Hugging Face tokenizer or the original fastText n-gram tokenizer.
+   - Combine the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. All of them are `torch.nn.Module`s!
+   - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+ - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
+ - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
+   - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you.
+ - **Additional features**: explainability using Captum.
+
+
+ ## 📦 Installation
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/InseeFrLab/torchTextClassifiers.git
+ cd torchTextClassifiers
+
+ # Install with uv (recommended)
+ uv sync
+
+ # Or install with pip
+ pip install -e .
+ ```
+
+ ## 📝 Usage
+
+ Check out the [notebook](notebooks/example.ipynb) for a quick start.
+
+ ## 📚 Examples
+
+ See the [examples/](examples/) directory for:
+ - Basic text classification
+ - Multi-class classification
+ - Mixed features (text + categorical)
+ - Advanced training configurations
+ - Prediction and explainability
+
+ ## 📄 License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
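The component-based design the README describes can be pictured with a plain PyTorch sketch. This is illustrative only: the real `TextEmbedder`, `CategoricalVariableNet`, and `ClassificationHead` live under `torchTextClassifiers/model/components/`, and their actual constructor signatures are not shown in this diff.

```python
import torch
from torch import nn


class ToyTextClassifier(nn.Module):
    """Toy stand-in for the package's TextClassificationModel: embed text, embed
    categorical variables, concatenate, classify. Names and shapes are illustrative,
    not the package's actual API."""

    def __init__(self, vocab_size, embed_dim, cat_cardinalities, cat_dim, num_classes):
        super().__init__()
        self.text_embedder = nn.EmbeddingBag(vocab_size, embed_dim)  # ~ TextEmbedder
        self.cat_nets = nn.ModuleList(                               # ~ CategoricalVariableNet
            [nn.Embedding(card, cat_dim) for card in cat_cardinalities]
        )
        in_dim = embed_dim + cat_dim * len(cat_cardinalities)
        self.head = nn.Linear(in_dim, num_classes)                   # ~ ClassificationHead

    def forward(self, token_ids, categorical_vars):
        text_vec = self.text_embedder(token_ids)  # (batch, embed_dim)
        cat_vecs = [net(categorical_vars[:, i]) for i, net in enumerate(self.cat_nets)]
        features = torch.cat([text_vec, *cat_vecs], dim=-1)
        return self.head(features)  # (batch, num_classes)


# Quick smoke test with random token ids and two categorical columns
model = ToyTextClassifier(vocab_size=100, embed_dim=8, cat_cardinalities=[4, 3], cat_dim=2, num_classes=5)
logits = model(torch.randint(0, 100, (2, 6)), torch.tensor([[1, 2], [3, 0]]))
```

In the package itself, `TextClassificationModel` plays roughly the role of `ToyTextClassifier` here, wiring the same three kinds of components together.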
torchtextclassifiers-0.1.0.dist-info/RECORD
@@ -0,0 +1,21 @@
+ torchTextClassifiers/__init__.py,sha256=TM2AjZ4KDqpgwMKiT0X5daNZvLDj6WECz_OFf8M4lgA,906
+ torchTextClassifiers/dataset/__init__.py,sha256=dyCz48pO6zRC-2qh4753Hj70W2MZGXdX3RbgutvyOng,76
+ torchTextClassifiers/dataset/dataset.py,sha256=n7V4JNtcuqb2ugx7hxkAohEPHqEuxv46jYU47KiUbno,3295
+ torchTextClassifiers/model/__init__.py,sha256=lFY1Mb1J0tFhe4_PsDOEHhnVl3dXj59K4Zxnwy2KkS4,146
+ torchTextClassifiers/model/components/__init__.py,sha256=-IT_6fCHZkRw6Hu7GdVeCt685P4PuGaY6VdYQV5M8mE,447
+ torchTextClassifiers/model/components/attention.py,sha256=hhSMh_CvpR-hiP8hoCg4Fr_TovGlJpC_RHs3iW-Pnpc,4199
+ torchTextClassifiers/model/components/categorical_var_net.py,sha256=no0QDidKCw1rlbJzD7S-Srhzn5P6vETGRT5Er-gzMnM,5699
+ torchTextClassifiers/model/components/classification_head.py,sha256=lPndu5FPC-bOZ2H4Yq0EnzWrOzPFJdBb_KUx5wyZBb4,1445
+ torchTextClassifiers/model/components/text_embedder.py,sha256=tY2pXAt4IvayyvRpjiKGg5vGz_Q2-p_TOL6Jg2p8hYE,9058
+ torchTextClassifiers/model/lightning.py,sha256=z5mq10_hNp-UK66Aqpcablg3BDYnjF9Gch0HaGoJ6cM,5265
+ torchTextClassifiers/model/model.py,sha256=jjGjvK7C2Wly0e4S6gTC8Ty8y-o8reU-aniBqYS73Cc,6100
+ torchTextClassifiers/tokenizers/WordPiece.py,sha256=HMHYV2SiwShlhWMQ6LXH4MtZE5GSsaNA2DlD340ABGE,3289
+ torchTextClassifiers/tokenizers/__init__.py,sha256=I8IQ2-t85RVlZFwLjDFF_Te2S9uiwlymQDWx-3GeF-Y,334
+ torchTextClassifiers/tokenizers/base.py,sha256=OY6GIhI4KTdvvKq3VZowf64H7lAmdQyq4scZ10HxP3A,7570
+ torchTextClassifiers/tokenizers/ngram.py,sha256=lHI8dtuCGWh0o7V58TJx_mTVIHm8udl6XuWccxgJPew,16375
+ torchTextClassifiers/torchTextClassifiers.py,sha256=E2XVGAky_SMAw6BAMswA3c08rKyOpGEW_dv1BqQlJrU,21141
+ torchTextClassifiers/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ torchTextClassifiers/utilities/plot_explainability.py,sha256=8YhyiMupdiIZp4jT7uvlcJNf69Fyr9HXfjUiNyMSYxE,6931
+ torchtextclassifiers-0.1.0.dist-info/WHEEL,sha256=ELhySV62sOro8I5wRaLaF3TWxhBpkcDkdZUdAYLy_Hk,78
+ torchtextclassifiers-0.1.0.dist-info/METADATA,sha256=fvPTUIS-M4LgURVzC1CUTb8IrKyZiBzWRAE1heTafEE,2988
+ torchtextclassifiers-0.1.0.dist-info/RECORD,,
{torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: uv 0.8.3
+ Generator: uv 0.9.3
  Root-Is-Purelib: true
  Tag: py3-none-any
torchTextClassifiers/classifiers/base.py
@@ -1,83 +0,0 @@
- from typing import Optional, Union, Type, List, Dict, Any
- from dataclasses import dataclass, field, asdict
- from abc import ABC, abstractmethod
- import numpy as np
-
- class BaseClassifierConfig(ABC):
-     """Abstract base class for classifier configurations."""
-
-     @abstractmethod
-     def to_dict(self) -> Dict[str, Any]:
-         """Convert configuration to dictionary."""
-         pass
-
-     @classmethod
-     @abstractmethod
-     def from_dict(cls, data: Dict[str, Any]) -> "BaseClassifierConfig":
-         """Create configuration from dictionary."""
-         pass
-
- class BaseClassifierWrapper(ABC):
-     """Abstract base class for classifier wrappers.
-
-     Each classifier wrapper is responsible for its own text processing approach.
-     Some may use tokenizers, others may use different preprocessing methods.
-     """
-
-     def __init__(self, config: BaseClassifierConfig):
-         self.config = config
-         self.pytorch_model = None
-         self.lightning_module = None
-         self.trained: bool = False
-         self.device = None
-         # Remove tokenizer from base class - it's now wrapper-specific
-
-     @abstractmethod
-     def prepare_text_features(self, training_text: np.ndarray) -> None:
-         """Prepare text features for the classifier.
-
-         This could involve tokenization, vectorization, or other preprocessing.
-         Each classifier wrapper implements this according to its needs.
-         """
-         pass
-
-     @abstractmethod
-     def _build_pytorch_model(self) -> None:
-         """Build the PyTorch model."""
-         pass
-
-     @abstractmethod
-     def _check_and_init_lightning(self, **kwargs) -> None:
-         """Initialize Lightning module."""
-         pass
-
-     @abstractmethod
-     def predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
-         """Make predictions."""
-         pass
-
-     @abstractmethod
-     def validate(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
-         """Validate the model."""
-         pass
-
-     @abstractmethod
-     def create_dataset(self, texts: np.ndarray, labels: np.ndarray, categorical_variables: Optional[np.ndarray] = None):
-         """Create dataset for training/validation."""
-         pass
-
-     @abstractmethod
-     def create_dataloader(self, dataset, batch_size: int, num_workers: int = 0, shuffle: bool = True):
-         """Create dataloader from dataset."""
-         pass
-
-     @abstractmethod
-     def load_best_model(self, checkpoint_path: str) -> None:
-         """Load best model from checkpoint."""
-         pass
-
-     @classmethod
-     @abstractmethod
-     def get_config_class(cls) -> Type[BaseClassifierConfig]:
-         """Return the configuration class for this wrapper."""
-         pass
torchTextClassifiers/classifiers/fasttext/__init__.py
@@ -1,25 +0,0 @@
- """FastText classifier package.
-
- Provides FastText text classification with PyTorch Lightning integration.
- This folder contains 4 main files:
- - core.py: Configuration, losses, and factory methods
- - tokenizer.py: NGramTokenizer implementation
- - model.py: PyTorch model, Lightning module, and dataset
- - wrapper.py: High-level wrapper interface
- """
-
- from .core import FastTextConfig, OneVsAllLoss, FastTextFactory
- from .tokenizer import NGramTokenizer
- from .model import FastTextModel, FastTextModule, FastTextModelDataset
- from .wrapper import FastTextWrapper
-
- __all__ = [
-     "FastTextConfig",
-     "OneVsAllLoss",
-     "FastTextFactory",
-     "NGramTokenizer",
-     "FastTextModel",
-     "FastTextModule",
-     "FastTextModelDataset",
-     "FastTextWrapper",
- ]
torchTextClassifiers/classifiers/fasttext/core.py
@@ -1,269 +0,0 @@
- """FastText classifier core components.
-
- This module contains the core components for FastText classification:
- - Configuration dataclass
- - Loss functions
- - Factory methods for creating classifiers
-
- Consolidates what was previously in config.py, losses.py, and factory.py.
- """
-
- from dataclasses import dataclass, field, asdict
- from abc import ABC, abstractmethod
- from ..base import BaseClassifierConfig
- from typing import Optional, List, TYPE_CHECKING, Union, Dict, Any
- import numpy as np
- import torch
- import torch.nn.functional as F
- from torch import nn
-
- if TYPE_CHECKING:
-     from ...torchTextClassifiers import torchTextClassifiers
-
-
- # ============================================================================
- # Configuration
- # ============================================================================
-
- @dataclass
- class FastTextConfig(BaseClassifierConfig):
-     """Configuration for FastText classifier."""
-     # Embedding matrix
-     embedding_dim: int
-     sparse: bool
-
-     # Tokenizer-related
-     num_tokens: int
-     min_count: int
-     min_n: int
-     max_n: int
-     len_word_ngrams: int
-
-     # Optional parameters
-     num_classes: Optional[int] = None
-     num_rows: Optional[int] = None
-
-     # Categorical variables
-     categorical_vocabulary_sizes: Optional[List[int]] = None
-     categorical_embedding_dims: Optional[Union[List[int], int]] = None
-     num_categorical_features: Optional[int] = None
-
-     # Model-specific parameters
-     direct_bagging: Optional[bool] = True
-
-     # Training parameters
-     learning_rate: float = 4e-3
-
-     def to_dict(self) -> Dict[str, Any]:
-         return asdict(self)
-
-     @classmethod
-     def from_dict(cls, data: Dict[str, Any]) -> "FastTextConfig":
-         return cls(**data)
-
-
- # ============================================================================
- # Loss Functions
- # ============================================================================
-
- class OneVsAllLoss(nn.Module):
-     def __init__(self):
-         super(OneVsAllLoss, self).__init__()
-
-     def forward(self, logits, targets):
-         """
-         Compute One-vs-All loss
-
-         Args:
-             logits: Tensor of shape (batch_size, num_classes) containing classification scores
-             targets: Tensor of shape (batch_size) containing true class indices
-
-         Returns:
-             loss: Mean loss value across the batch
-         """
-
-         num_classes = logits.size(1)
-
-         # Convert targets to one-hot encoding
-         targets_one_hot = F.one_hot(targets, num_classes=num_classes).float()
-
-         # For each sample, treat the true class as positive and all others as negative
-         # Using binary cross entropy for each class
-         loss = F.binary_cross_entropy_with_logits(
-             logits,  # Raw logits
-             targets_one_hot,  # Target probabilities
-             reduction="none",  # Don't reduce yet to allow for custom weighting if needed
-         )
-
-         # Sum losses across all classes for each sample, then take mean across batch
-         return loss.sum(dim=1).mean()
-
-
- # ============================================================================
- # Factory Methods
- # ============================================================================
-
- class FastTextFactory:
-     """Factory class for creating FastText classifiers with convenience methods.
-
-     This factory provides static methods for creating FastText classifiers with
-     common configurations. It handles the complexities of configuration creation
-     and classifier initialization, offering a simplified API for users.
-
-     All methods return fully initialized torchTextClassifiers instances that are
-     ready for building and training.
-     """
-
-     @staticmethod
-     def create_fasttext(
-         embedding_dim: int,
-         sparse: bool,
-         num_tokens: int,
-         min_count: int,
-         min_n: int,
-         max_n: int,
-         len_word_ngrams: int,
-         **kwargs
-     ) -> "torchTextClassifiers":
-         """Create a FastText classifier with the specified configuration.
-
-         This is the primary method for creating FastText classifiers. It creates
-         a configuration object with the provided parameters and initializes a
-         complete classifier instance.
-
-         Args:
-             embedding_dim: Dimension of word embeddings
-             sparse: Whether to use sparse embeddings
-             num_tokens: Maximum number of tokens in vocabulary
-             min_count: Minimum count for tokens to be included in vocabulary
-             min_n: Minimum length of character n-grams
-             max_n: Maximum length of character n-grams
-             len_word_ngrams: Length of word n-grams to use
-             **kwargs: Additional configuration parameters (e.g., num_classes,
-                 categorical_vocabulary_sizes, etc.)
-
-         Returns:
-             torchTextClassifiers: Initialized FastText classifier instance
-
-         Example:
-             >>> from torchTextClassifiers.classifiers.fasttext.core import FastTextFactory
-             >>> classifier = FastTextFactory.create_fasttext(
-             ...     embedding_dim=100,
-             ...     sparse=False,
-             ...     num_tokens=10000,
-             ...     min_count=2,
-             ...     min_n=3,
-             ...     max_n=6,
-             ...     len_word_ngrams=2,
-             ...     num_classes=3
-             ... )
-         """
-         from ...torchTextClassifiers import torchTextClassifiers
-         from .wrapper import FastTextWrapper
-
-         config = FastTextConfig(
-             embedding_dim=embedding_dim,
-             sparse=sparse,
-             num_tokens=num_tokens,
-             min_count=min_count,
-             min_n=min_n,
-             max_n=max_n,
-             len_word_ngrams=len_word_ngrams,
-             **kwargs
-         )
-         wrapper = FastTextWrapper(config)
-         return torchTextClassifiers(wrapper)
-
-     @staticmethod
-     def build_from_tokenizer(
-         tokenizer,  # NGramTokenizer
-         embedding_dim: int,
-         num_classes: Optional[int],
-         categorical_vocabulary_sizes: Optional[List[int]] = None,
-         sparse: bool = False,
-         **kwargs
-     ) -> "torchTextClassifiers":
-         """Create FastText classifier from an existing trained tokenizer.
-
-         This method is useful when you have a pre-trained tokenizer and want to
-         create a classifier that uses the same vocabulary and tokenization scheme.
-         The resulting classifier will have its tokenizer and model architecture
-         pre-built.
-
-         Args:
-             tokenizer: Pre-trained NGramTokenizer instance
-             embedding_dim: Dimension of word embeddings
-             num_classes: Number of output classes
-             categorical_vocabulary_sizes: Sizes of categorical feature vocabularies
-             sparse: Whether to use sparse embeddings
-             **kwargs: Additional configuration parameters
-
-         Returns:
-             torchTextClassifiers: Classifier with pre-built tokenizer and model
-
-         Raises:
-             ValueError: If the tokenizer is missing required attributes
-
-         Example:
-             >>> # Assume you have a pre-trained tokenizer
-             >>> classifier = FastTextFactory.build_from_tokenizer(
-             ...     tokenizer=my_tokenizer,
-             ...     embedding_dim=100,
-             ...     num_classes=2,
-             ...     sparse=False
-             ... )
-             >>> # The classifier is ready for training without building
-             >>> classifier.train(X_train, y_train, X_val, y_val, ...)
-         """
-         from ...torchTextClassifiers import torchTextClassifiers
-         from .wrapper import FastTextWrapper
-
-         # Ensure the tokenizer has required attributes
-         required_attrs = ["min_count", "min_n", "max_n", "num_tokens", "word_ngrams"]
-         if not all(hasattr(tokenizer, attr) for attr in required_attrs):
-             missing_attrs = [attr for attr in required_attrs if not hasattr(tokenizer, attr)]
-             raise ValueError(f"Missing attributes in tokenizer: {missing_attrs}")
-
-         config = FastTextConfig(
-             num_tokens=tokenizer.num_tokens,
-             embedding_dim=embedding_dim,
-             min_count=tokenizer.min_count,
-             min_n=tokenizer.min_n,
-             max_n=tokenizer.max_n,
-             len_word_ngrams=tokenizer.word_ngrams,
-             sparse=sparse,
-             num_classes=num_classes,
-             categorical_vocabulary_sizes=categorical_vocabulary_sizes,
-             **kwargs
-         )
-
-         wrapper = FastTextWrapper(config)
-         classifier = torchTextClassifiers(wrapper)
-         classifier.classifier.tokenizer = tokenizer
-         classifier.classifier._build_pytorch_model()
-
-         return classifier
-
-     @staticmethod
-     def from_dict(config_dict: dict) -> FastTextConfig:
-         """Create FastText configuration from dictionary.
-
-         This method is used internally by the configuration factory system
-         to recreate FastText configurations from serialized data.
-
-         Args:
-             config_dict: Dictionary containing configuration parameters
-
-         Returns:
-             FastTextConfig: Reconstructed configuration object
-
-         Example:
-             >>> config_dict = {
-             ...     'embedding_dim': 100,
-             ...     'num_tokens': 5000,
-             ...     'min_count': 1,
-             ...     # ... other parameters
-             ... }
-             >>> config = FastTextFactory.from_dict(config_dict)
-         """
-         return FastTextConfig.from_dict(config_dict)
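For readers tracing the removal, here is a self-contained numeric sketch of what the deleted `OneVsAllLoss` computed: each sample's true class is treated as positive and every other class as negative, per-class binary cross-entropy is summed over classes, and the result is averaged over the batch. The class body is inlined rather than imported, since 0.1.0 no longer ships it.

```python
# Worked example of the removed One-vs-All loss (inlined, not imported from the package).
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5],
                       [0.1,  0.3, -0.2]])  # (batch_size=2, num_classes=3)
targets = torch.tensor([0, 2])              # true class indices

# One-hot targets, per-class BCE with logits, sum over classes, mean over batch
targets_one_hot = F.one_hot(targets, num_classes=3).float()
per_class = F.binary_cross_entropy_with_logits(logits, targets_one_hot, reduction="none")
loss = per_class.sum(dim=1).mean()          # scalar: per-sample sums averaged over the batch
print(loss)
```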