torchtextclassifiers 1.0.0.tar.gz → 1.0.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/PKG-INFO +1 -1
  2. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/pyproject.toml +1 -1
  3. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/text_embedder.py +3 -0
  4. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/lightning.py +1 -1
  5. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/__init__.py +3 -1
  6. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/torchTextClassifiers.py +130 -0
  7. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/utilities/plot_explainability.py +17 -7
  8. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/README.md +0 -0
  9. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/__init__.py +0 -0
  10. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/dataset/__init__.py +0 -0
  11. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/dataset/dataset.py +0 -0
  12. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/__init__.py +0 -0
  13. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/__init__.py +0 -0
  14. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/attention.py +0 -0
  15. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/categorical_var_net.py +0 -0
  16. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/classification_head.py +0 -0
  17. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/model.py +0 -0
  18. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/WordPiece.py +0 -0
  19. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/base.py +0 -0
  20. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/ngram.py +0 -0
  21. {torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/utilities/__init__.py +0 -0
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: torchtextclassifiers
-Version: 1.0.0
+Version: 1.0.2
 Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
 Keywords: fastText,text classification,NLP,automatic coding,deep learning
 Author: Cédric Couralet, Meilame Tayebjee
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
     "pytorch-lightning>=2.4.0",
 ]
 requires-python = ">=3.11"
-version="1.0.0"
+version="1.0.2"
 
 
 [dependency-groups]
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/components/text_embedder.py
@@ -23,6 +23,9 @@ class TextEmbedder(nn.Module):
         self.config = text_embedder_config
 
         self.attention_config = text_embedder_config.attention_config
+        if isinstance(self.attention_config, dict):
+            self.attention_config = AttentionConfig(**self.attention_config)
+
         if self.attention_config is not None:
             self.attention_config.n_embd = text_embedder_config.embedding_dim
 
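With this change TextEmbedder accepts attention_config either as an AttentionConfig instance or as a plain dict that it coerces itself, which presumably matters once a config has been round-tripped through to_dict()/from_dict() as the new save()/load() path below does. A minimal sketch of the two call patterns, not part of the diff; the config class name TextEmbedderConfig and the n_heads field are assumptions:

# Hypothetical illustration: both forms should behave identically in 1.0.2,
# since TextEmbedder now coerces a dict into an AttentionConfig internally.
embedder_a = TextEmbedder(TextEmbedderConfig(embedding_dim=96, attention_config=AttentionConfig(n_heads=4)))
embedder_b = TextEmbedder(TextEmbedderConfig(embedding_dim=96, attention_config={"n_heads": 4}))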
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/model/lightning.py
@@ -36,7 +36,7 @@ class TextClassificationModule(pl.LightningModule):
             scheduler_interval: Scheduler interval.
         """
         super().__init__()
-        self.save_hyperparameters(ignore=["model", "loss"])
+        self.save_hyperparameters(ignore=["model"])
 
         self.model = model
         self.loss = loss
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/tokenizers/__init__.py
@@ -7,4 +7,6 @@ from .base import (
 )
 from .base import TokenizerOutput as TokenizerOutput
 from .ngram import NGramTokenizer as NGramTokenizer
-from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
+
+if HAS_HF:
+    from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
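WordPieceTokenizer is now exported only when the optional Hugging Face dependency is available (the HAS_HF flag presumably comes from the guarded import alongside the base module). Downstream code that wants to keep working without that extra can guard its own import; a hedged sketch, not part of the diff:

# Fall back to the always-exported n-gram tokenizer when the HF-backed
# WordPiece tokenizer is unavailable (i.e. HAS_HF is False).
try:
    from torchTextClassifiers.tokenizers import WordPieceTokenizer as TokenizerCls
except ImportError:
    from torchTextClassifiers.tokenizers import NGramTokenizer as TokenizerCls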
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/torchTextClassifiers.py
@@ -1,6 +1,8 @@
 import logging
+import pickle
 import time
 from dataclasses import asdict, dataclass, field
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 try:
@@ -75,6 +77,7 @@ class TrainingConfig:
     trainer_params: Optional[dict] = None
     optimizer_params: Optional[dict] = None
     scheduler_params: Optional[dict] = None
+    save_path: Optional[str] = "my_ttc"
 
     def to_dict(self) -> Dict[str, Any]:
         data = asdict(self)
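TrainingConfig gains a save_path field (default "my_ttc"), and train() now writes the whole classifier there after reloading the best checkpoint (see the train() hunk below). A minimal sketch of configuring it, not part of the diff; the keys inside the parameter dicts are assumptions about how they are consumed:

training_config = TrainingConfig(
    loss=torch.nn.CrossEntropyLoss(),
    trainer_params={"max_epochs": 10},   # assumed to be forwarded to the Lightning Trainer
    optimizer_params={"lr": 1e-3},       # assumed to be forwarded to the optimizer
    save_path="artifacts/my_ttc",        # directory that ttc.save() will create and populate
)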
@@ -198,6 +201,12 @@ class torchTextClassifiers:
            - Model training with early stopping
            - Best model loading after training
 
+        Note on Checkpoints:
+            After training, the best model checkpoint is automatically loaded.
+            This checkpoint contains the full training state (model weights,
+            optimizer, and scheduler state). Loading uses weights_only=False
+            as the checkpoint is self-generated and trusted.
+
         Args:
             X_train: Training input data
             y_train: Training labels
@@ -356,15 +365,20 @@ class torchTextClassifiers:
         logger.info(f"Training completed in {end - start:.2f} seconds.")
 
         best_model_path = trainer.checkpoint_callback.best_model_path
+        self.checkpoint_path = best_model_path
 
         self.lightning_module = TextClassificationModule.load_from_checkpoint(
             best_model_path,
             model=self.pytorch_model,
             loss=training_config.loss,
+            weights_only=False,  # Required: checkpoint contains optimizer/scheduler state
         )
 
         self.pytorch_model = self.lightning_module.model.to(self.device)
 
+        self.save_path = training_config.save_path
+        self.save(self.save_path)
+
         self.lightning_module.eval()
 
     def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
@@ -569,6 +583,122 @@ class torchTextClassifiers:
             "confidence": confidence,
         }
 
+    def save(self, path: Union[str, Path]) -> None:
+        """Save the complete torchTextClassifiers instance to disk.
+
+        This saves:
+        - Model configuration
+        - Tokenizer state
+        - PyTorch Lightning checkpoint (if trained)
+        - All other instance attributes
+
+        Args:
+            path: Directory path where the model will be saved
+
+        Example:
+            >>> ttc = torchTextClassifiers(tokenizer, model_config)
+            >>> ttc.train(X_train, y_train, training_config)
+            >>> ttc.save("my_model")
+        """
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        # Save the checkpoint if model has been trained
+        checkpoint_path = None
+        if hasattr(self, "lightning_module"):
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            # Save the current state as a checkpoint
+            trainer = pl.Trainer()
+            trainer.strategy.connect(self.lightning_module)
+            trainer.save_checkpoint(checkpoint_path)
+
+        # Prepare metadata to save
+        metadata = {
+            "model_config": self.model_config.to_dict(),
+            "ragged_multilabel": self.ragged_multilabel,
+            "vocab_size": self.vocab_size,
+            "embedding_dim": self.embedding_dim,
+            "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
+            "num_classes": self.num_classes,
+            "checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
+            "device": str(self.device) if hasattr(self, "device") else None,
+        }
+
+        # Save metadata
+        with open(path / "metadata.pkl", "wb") as f:
+            pickle.dump(metadata, f)
+
+        # Save tokenizer
+        tokenizer_path = path / "tokenizer.pkl"
+        with open(tokenizer_path, "wb") as f:
+            pickle.dump(self.tokenizer, f)
+
+        logger.info(f"Model saved successfully to {path}")
+
+    @classmethod
+    def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
+        """Load a torchTextClassifiers instance from disk.
+
+        Args:
+            path: Directory path where the model was saved
+            device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)
+
+        Returns:
+            Loaded torchTextClassifiers instance
+
+        Example:
+            >>> loaded_ttc = torchTextClassifiers.load("my_model")
+            >>> predictions = loaded_ttc.predict(X_test)
+        """
+        path = Path(path)
+
+        if not path.exists():
+            raise FileNotFoundError(f"Model directory not found: {path}")
+
+        # Load metadata
+        with open(path / "metadata.pkl", "rb") as f:
+            metadata = pickle.load(f)
+
+        # Load tokenizer
+        with open(path / "tokenizer.pkl", "rb") as f:
+            tokenizer = pickle.load(f)
+
+        # Reconstruct model_config
+        model_config = ModelConfig.from_dict(metadata["model_config"])
+
+        # Create instance
+        instance = cls(
+            tokenizer=tokenizer,
+            model_config=model_config,
+            ragged_multilabel=metadata["ragged_multilabel"],
+        )
+
+        # Set device
+        if device == "auto":
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            device = torch.device(device)
+        instance.device = device
+
+        # Load checkpoint if it exists
+        if metadata["checkpoint_path"]:
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            if checkpoint_path.exists():
+                # Load the checkpoint with weights_only=False since it's our own trusted checkpoint
+                instance.lightning_module = TextClassificationModule.load_from_checkpoint(
+                    str(checkpoint_path),
+                    model=instance.pytorch_model,
+                    weights_only=False,
+                )
+                instance.pytorch_model = instance.lightning_module.model.to(device)
+                instance.checkpoint_path = str(checkpoint_path)
+                logger.info(f"Model checkpoint loaded from {checkpoint_path}")
+            else:
+                logger.warning(f"Checkpoint file not found at {checkpoint_path}")
+
+        logger.info(f"Model loaded successfully from {path}")
+        return instance
+
     def __repr__(self):
         model_type = (
             self.lightning_module.__repr__()
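Together with the train() change above, 1.0.2 persists the classifier automatically at the end of training and lets you restore it later. A hedged end-to-end sketch (the data variables are placeholders), not part of the diff:

# Training reloads the best checkpoint and then calls self.save(training_config.save_path).
ttc = torchTextClassifiers(tokenizer, model_config)
ttc.train(X_train, y_train, training_config)

# Later, possibly in another process: rebuild tokenizer, config and weights from disk.
restored = torchTextClassifiers.load(training_config.save_path, device="cpu")
predictions = restored.predict(X_test)

Note that the metadata and the tokenizer are stored with pickle and the checkpoint is loaded with weights_only=False, so load() should only be pointed at directories you produced and trust.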
{torchtextclassifiers-1.0.0 → torchtextclassifiers-1.0.2}/torchTextClassifiers/utilities/plot_explainability.py
@@ -53,8 +53,18 @@ def map_attributions_to_char(attributions, offsets, text):
         np.exp(attributions_per_char), axis=1, keepdims=True
     )  # softmax normalization
 
+def get_id_to_word(text, word_ids, offsets):
+    words = {}
+    for idx, word_id in enumerate(word_ids):
+        if word_id is None:
+            continue
+        start, end = offsets[idx]
+        words[int(word_id)] = text[start:end]
+
+    return words
+
 
-def map_attributions_to_word(attributions, word_ids):
+def map_attributions_to_word(attributions, text, word_ids, offsets):
     """
     Maps token-level attributions to word-level attributions based on word IDs.
     Args:
@@ -69,8 +79,9 @@ def map_attributions_to_word(attributions, word_ids):
         np.ndarray: Array of shape (top_k, num_words) containing word-level attributions.
             num_words is the number of unique words in the original text.
     """
-
+
     word_ids = np.array(word_ids)
+    words = get_id_to_word(text, word_ids, offsets)
 
     # Convert None to -1 for easier processing (PAD tokens)
     word_ids_int = np.array([x if x is not None else -1 for x in word_ids], dtype=int)
@@ -99,7 +110,7 @@ def map_attributions_to_word(attributions, word_ids):
     )  # zero-out non-matching tokens and sum attributions for all tokens belonging to the same word
 
     # assert word_attributions.sum(axis=1) == attributions.sum(axis=1), "Sum of word attributions per top_k must equal sum of token attributions per top_k."
-    return np.exp(word_attributions) / np.sum(
+    return words, np.exp(word_attributions) / np.sum(
         np.exp(word_attributions), axis=1, keepdims=True
     )  # softmax normalization
 
@@ -131,7 +142,7 @@ def plot_attributions_at_char(
         fig, ax = plt.subplots(figsize=figsize)
         ax.bar(range(len(text)), attributions_per_char[i])
         ax.set_xticks(np.arange(len(text)))
-        ax.set_xticklabels(list(text), rotation=90)
+        ax.set_xticklabels(list(text), rotation=45)
         title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
         ax.set_title(title)
         ax.set_xlabel("Characters in Text")
@@ -142,7 +153,7 @@ def plot_attributions_at_char(
 
 
 def plot_attributions_at_word(
-    text, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
+    text, words, attributions_per_word, figsize=(10, 2), titles: Optional[List[str]] = None
 ):
     """
     Plots word-level attributions as a heatmap.
@@ -159,14 +170,13 @@ def plot_attributions_at_word(
             "matplotlib is required for plotting. Please install it to use this function."
         )
 
-    words = text.split()
     top_k = attributions_per_word.shape[0]
    all_plots = []
    for i in range(top_k):
         fig, ax = plt.subplots(figsize=figsize)
         ax.bar(range(len(words)), attributions_per_word[i])
         ax.set_xticks(np.arange(len(words)))
-        ax.set_xticklabels(words, rotation=90)
+        ax.set_xticklabels(words, rotation=45)
         title = titles[i] if titles is not None else f"Attributions for Top {i+1} Prediction"
         ax.set_title(title)
         ax.set_xlabel("Words in Text")