torchtextclassifiers 1.0.0.tar.gz → 1.0.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/PKG-INFO +1 -1
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/pyproject.toml +1 -1
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/text_embedder.py +3 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/lightning.py +1 -1
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/__init__.py +3 -1
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/torchTextClassifiers.py +130 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/utilities/plot_explainability.py +17 -7
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/README.md +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/__init__.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/dataset/__init__.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/dataset/dataset.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/__init__.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/__init__.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/attention.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/categorical_var_net.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/classification_head.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/model.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/WordPiece.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/base.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/ngram.py +0 -0
- {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/utilities/__init__.py +0 -0
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: torchtextclassifiers
-Version: 1.0.0
+Version: 1.0.2
 Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
 Keywords: fastText,text classification,NLP,automatic coding,deep learning
 Author: Cédric Couralet, Meilame Tayebjee
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/text_embedder.py
RENAMED
@@ -23,6 +23,9 @@ class TextEmbedder(nn.Module):
         self.config = text_embedder_config
 
         self.attention_config = text_embedder_config.attention_config
+        if isinstance(self.attention_config, dict):
+            self.attention_config = AttentionConfig(**self.attention_config)
+
         if self.attention_config is not None:
             self.attention_config.n_embd = text_embedder_config.embedding_dim
 
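For context, the new branch lets attention_config be supplied as a plain dict and coerces it into the config dataclass. A minimal sketch of that pattern, using a reduced stand-in for AttentionConfig (the real class has more fields):

    from dataclasses import dataclass

    @dataclass
    class AttentionConfig:          # reduced stand-in, not the library's full class
        n_embd: int = 0
        n_head: int = 4

    attention_config = {"n_head": 8}            # config passed as a plain dict
    if isinstance(attention_config, dict):      # the branch added in 1.0.2
        attention_config = AttentionConfig(**attention_config)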
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/lightning.py
RENAMED
@@ -36,7 +36,7 @@ class TextClassificationModule(pl.LightningModule):
             scheduler_interval: Scheduler interval.
         """
         super().__init__()
-        self.save_hyperparameters(ignore=["model"
+        self.save_hyperparameters(ignore=["model"])
 
         self.model = model
         self.loss = loss
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/__init__.py
RENAMED
@@ -7,4 +7,6 @@ from .base import (
 )
 from .base import TokenizerOutput as TokenizerOutput
 from .ngram import NGramTokenizer as NGramTokenizer
-
+
+if HAS_HF:
+    from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
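The WordPiece tokenizer is now exported only when the optional Hugging Face dependency is available. A hedged sketch of the optional-import pattern the HAS_HF flag implies (the flag itself is defined elsewhere in the package; the guarded dependency name here is an assumption):

    try:
        import tokenizers  # assumed Hugging Face dependency backing WordPiece
        HAS_HF = True
    except ImportError:
        HAS_HF = False

    if HAS_HF:
        from .WordPiece import WordPieceTokenizer as WordPieceTokenizer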
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/torchTextClassifiers.py
RENAMED
@@ -1,6 +1,8 @@
 import logging
+import pickle
 import time
 from dataclasses import asdict, dataclass, field
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 try:
@@ -75,6 +77,7 @@ class TrainingConfig:
     trainer_params: Optional[dict] = None
     optimizer_params: Optional[dict] = None
     scheduler_params: Optional[dict] = None
+    save_path: Optional[str] = "my_ttc"
 
     def to_dict(self) -> Dict[str, Any]:
         data = asdict(self)
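For illustration, a runnable stand-in that keeps only the fields visible in this hunk (the real TrainingConfig defines more and its to_dict does additional processing), showing that the new save_path default survives the dataclass round trip:

    from dataclasses import asdict, dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class TrainingConfig:                        # reduced stand-in, not the full class
        trainer_params: Optional[dict] = None
        optimizer_params: Optional[dict] = None
        scheduler_params: Optional[dict] = None
        save_path: Optional[str] = "my_ttc"      # new in 1.0.2: where train() saves the model

        def to_dict(self) -> Dict[str, Any]:
            return asdict(self)

    print(TrainingConfig().to_dict()["save_path"])   # -> my_ttc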
@@ -198,6 +201,12 @@ class torchTextClassifiers:
         - Model training with early stopping
         - Best model loading after training
 
+        Note on Checkpoints:
+            After training, the best model checkpoint is automatically loaded.
+            This checkpoint contains the full training state (model weights,
+            optimizer, and scheduler state). Loading uses weights_only=False
+            as the checkpoint is self-generated and trusted.
+
         Args:
             X_train: Training input data
             y_train: Training labels
@@ -356,15 +365,20 @@ class torchTextClassifiers:
         logger.info(f"Training completed in {end - start:.2f} seconds.")
 
         best_model_path = trainer.checkpoint_callback.best_model_path
+        self.checkpoint_path = best_model_path
 
         self.lightning_module = TextClassificationModule.load_from_checkpoint(
             best_model_path,
             model=self.pytorch_model,
             loss=training_config.loss,
+            weights_only=False, # Required: checkpoint contains optimizer/scheduler state
         )
 
         self.pytorch_model = self.lightning_module.model.to(self.device)
 
+        self.save_path = training_config.save_path
+        self.save(self.save_path)
+
         self.lightning_module.eval()
 
     def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
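In practice this means a training run now persists itself without an explicit save call. A hedged usage sketch (the train() signature and the training_config keyword are abbreviated/assumed from the arguments named in this diff):

    config = TrainingConfig(save_path="runs/my_ttc")     # other fields left at their defaults
    ttc.train(X_train, y_train, training_config=config)
    print(ttc.checkpoint_path)   # best Lightning checkpoint recorded during training
    print(ttc.save_path)         # directory written by the automatic self.save(...)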
@@ -569,6 +583,122 @@ class torchTextClassifiers:
             "confidence": confidence,
         }
 
+    def save(self, path: Union[str, Path]) -> None:
+        """Save the complete torchTextClassifiers instance to disk.
+
+        This saves:
+        - Model configuration
+        - Tokenizer state
+        - PyTorch Lightning checkpoint (if trained)
+        - All other instance attributes
+
+        Args:
+            path: Directory path where the model will be saved
+
+        Example:
+            >>> ttc = torchTextClassifiers(tokenizer, model_config)
+            >>> ttc.train(X_train, y_train, training_config)
+            >>> ttc.save("my_model")
+        """
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        # Save the checkpoint if model has been trained
+        checkpoint_path = None
+        if hasattr(self, "lightning_module"):
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            # Save the current state as a checkpoint
+            trainer = pl.Trainer()
+            trainer.strategy.connect(self.lightning_module)
+            trainer.save_checkpoint(checkpoint_path)
+
+        # Prepare metadata to save
+        metadata = {
+            "model_config": self.model_config.to_dict(),
+            "ragged_multilabel": self.ragged_multilabel,
+            "vocab_size": self.vocab_size,
+            "embedding_dim": self.embedding_dim,
+            "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
+            "num_classes": self.num_classes,
+            "checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
+            "device": str(self.device) if hasattr(self, "device") else None,
+        }
+
+        # Save metadata
+        with open(path / "metadata.pkl", "wb") as f:
+            pickle.dump(metadata, f)
+
+        # Save tokenizer
+        tokenizer_path = path / "tokenizer.pkl"
+        with open(tokenizer_path, "wb") as f:
+            pickle.dump(self.tokenizer, f)
+
+        logger.info(f"Model saved successfully to {path}")
+
+    @classmethod
+    def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
+        """Load a torchTextClassifiers instance from disk.
+
+        Args:
+            path: Directory path where the model was saved
+            device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)
+
+        Returns:
+            Loaded torchTextClassifiers instance
+
+        Example:
+            >>> loaded_ttc = torchTextClassifiers.load("my_model")
+            >>> predictions = loaded_ttc.predict(X_test)
+        """
+        path = Path(path)
+
+        if not path.exists():
+            raise FileNotFoundError(f"Model directory not found: {path}")
+
+        # Load metadata
+        with open(path / "metadata.pkl", "rb") as f:
+            metadata = pickle.load(f)
+
+        # Load tokenizer
+        with open(path / "tokenizer.pkl", "rb") as f:
+            tokenizer = pickle.load(f)
+
+        # Reconstruct model_config
+        model_config = ModelConfig.from_dict(metadata["model_config"])
+
+        # Create instance
+        instance = cls(
+            tokenizer=tokenizer,
+            model_config=model_config,
+            ragged_multilabel=metadata["ragged_multilabel"],
+        )
+
+        # Set device
+        if device == "auto":
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            device = torch.device(device)
+        instance.device = device
+
+        # Load checkpoint if it exists
+        if metadata["checkpoint_path"]:
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            if checkpoint_path.exists():
+                # Load the checkpoint with weights_only=False since it's our own trusted checkpoint
+                instance.lightning_module = TextClassificationModule.load_from_checkpoint(
+                    str(checkpoint_path),
+                    model=instance.pytorch_model,
+                    weights_only=False,
+                )
+                instance.pytorch_model = instance.lightning_module.model.to(device)
+                instance.checkpoint_path = str(checkpoint_path)
+                logger.info(f"Model checkpoint loaded from {checkpoint_path}")
+            else:
+                logger.warning(f"Checkpoint file not found at {checkpoint_path}")
+
+        logger.info(f"Model loaded successfully from {path}")
+        return instance
+
     def __repr__(self):
         model_type = (
             self.lightning_module.__repr__()
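The on-disk layout implied by save() is therefore:

    my_model/
    ├── model_checkpoint.ckpt   # Lightning checkpoint (written only if the model was trained)
    ├── metadata.pkl            # model_config, num_classes, checkpoint_path, device, ...
    └── tokenizer.pkl           # pickled tokenizer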
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/utilities/plot_explainability.py
RENAMED
@@ -53,8 +53,18 @@ def map_attributions_to_char(attributions, offsets, text):
         np.exp(attributions_per_char), axis=1, keepdims=True
     ) # softmax normalization
 
+def get_id_to_word(text, word_ids, offsets):
+    words = {}
+    for idx, word_id in enumerate(word_ids):
+        if word_id is None:
+            continue
+        start, end = offsets[idx]
+        words[int(word_id)] = text[start:end]
+
+    return words
+
 
-def map_attributions_to_word(attributions, word_ids):
+def map_attributions_to_word(attributions, text, word_ids, offsets):
     """
     Maps token-level attributions to word-level attributions based on word IDs.
     Args:
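A quick sanity check of what the new helper returns, with toy tokenizer output (the word_ids/offsets values are illustrative):

    from torchTextClassifiers.utilities.plot_explainability import get_id_to_word

    text = "hello world"
    word_ids = [None, 0, 1, None]                    # token -> word id (None for special tokens)
    offsets = [(0, 0), (0, 5), (6, 11), (0, 0)]      # token -> character span in text
    print(get_id_to_word(text, word_ids, offsets))   # {0: 'hello', 1: 'world'}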
@@ -69,8 +79,9 @@ def map_attributions_to_word(attributions, word_ids):
         np.ndarray: Array of shape (top_k, num_words) containing word-level attributions.
             num_words is the number of unique words in the original text.
     """
-
+
     word_ids = np.array(word_ids)
+    words = get_id_to_word(text, word_ids, offsets)
 
     # Convert None to -1 for easier processing (PAD tokens)
     word_ids_int = np.array([x if x is not None else -1 for x in word_ids], dtype=int)
@@ -99,7 +110,7 @@ def map_attributions_to_word(attributions, word_ids):
     ) # zero-out non-matching tokens and sum attributions for all tokens belonging to the same word
 
     # assert word_attributions.sum(axis=1) == attributions.sum(axis=1), "Sum of word attributions per top_k must equal sum of token attributions per top_k."
-    return np.exp(word_attributions) / np.sum(
+    return words, np.exp(word_attributions) / np.sum(
         np.exp(word_attributions), axis=1, keepdims=True
     ) # softmax normalization
 
@@ -131,7 +142,7 @@ def plot_attributions_at_char(
         fig, ax = plt.subplots(figsize=figsize)
         ax.bar(range(len(text)), attributions_per_char[i])
         ax.set_xticks(np.arange(len(text)))
-        ax.set_xticklabels(list(text), rotation=
+        ax.set_xticklabels(list(text), rotation=45)
         title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
         ax.set_title(title)
         ax.set_xlabel("Characters in Text")
@@ -142,7 +153,7 @@ def plot_attributions_at_char(
 
 
 def plot_attributions_at_word(
-    text, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
+    text, words, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
 ):
     """
     Plots word-level attributions as a heatmap.
@@ -159,14 +170,13 @@ def plot_attributions_at_word(
             "matplotlib is required for plotting. Please install it to use this function."
         )
 
-    words = text.split()
     top_k = attributions_per_word.shape[0]
    all_plots = []
    for i in range(top_k):
        fig, ax = plt.subplots(figsize=figsize)
        ax.bar(range(len(words)), attributions_per_word[i])
        ax.set_xticks(np.arange(len(words)))
-       ax.set_xticklabels(words, rotation=
+       ax.set_xticklabels(words, rotation=45)
        title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
        ax.set_title(title)
        ax.set_xlabel("Words in Text")
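Together these changes mean map_attributions_to_word now also returns the id-to-word mapping, which plot_attributions_at_word consumes instead of splitting the text itself. A hedged end-to-end sketch with toy data (attribution values and tokenizer output are illustrative; matplotlib is required for the plotting call):

    import numpy as np
    from torchTextClassifiers.utilities.plot_explainability import (
        map_attributions_to_word,
        plot_attributions_at_word,
    )

    text = "hello world"
    word_ids = [None, 0, 1, None]                    # as produced by the tokenizer
    offsets = [(0, 0), (0, 5), (6, 11), (0, 0)]
    attributions = np.array([[0.0, 1.2, 0.3, 0.0]])  # shape (top_k, num_tokens)

    words, attributions_per_word = map_attributions_to_word(
        attributions, text, word_ids, offsets
    )
    plots = plot_attributions_at_word(text, words, attributions_per_word)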
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/README.md
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/__init__.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/dataset/__init__.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/dataset/dataset.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/__init__.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/__init__.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/attention.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/categorical_var_net.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/classification_head.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/model.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/WordPiece.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/base.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/ngram.py
RENAMED
File without changes
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/utilities/__init__.py
RENAMED
File without changes