torchtextclassifiers 1.0.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,6 +23,9 @@ class TextEmbedder(nn.Module):
         self.config = text_embedder_config

         self.attention_config = text_embedder_config.attention_config
+        if isinstance(self.attention_config, dict):
+            self.attention_config = AttentionConfig(**self.attention_config)
+
         if self.attention_config is not None:
             self.attention_config.n_embd = text_embedder_config.embedding_dim

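The change above lets attention_config arrive either as an AttentionConfig instance or as its plain-dict form (e.g. deserialized from JSON), coercing the dict before the n_embd assignment touches it. A minimal sketch of the pattern, using hypothetical AttentionConfig fields (the real dataclass, presumably in torchTextClassifiers/model/components/attention.py, is not shown in this diff):

    from dataclasses import dataclass
    from typing import Optional, Union

    @dataclass
    class AttentionConfig:
        # Hypothetical fields, for illustration only.
        n_heads: int = 4
        n_embd: Optional[int] = None  # later overwritten with embedding_dim

    def coerce(cfg: Union[AttentionConfig, dict, None]) -> Optional[AttentionConfig]:
        # Same check the hunk adds: unpack a dict into the dataclass.
        if isinstance(cfg, dict):
            cfg = AttentionConfig(**cfg)
        return cfg

    assert coerce({"n_heads": 8}).n_heads == 8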
@@ -36,7 +36,7 @@ class TextClassificationModule(pl.LightningModule):
             scheduler_interval: Scheduler interval.
         """
         super().__init__()
-        self.save_hyperparameters(ignore=["model", "loss"])
+        self.save_hyperparameters(ignore=["model"])

         self.model = model
         self.loss = loss
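Dropping "loss" from the ignore list means the loss object is now recorded in the module's hyperparameters and serialized into every checkpoint, so load_from_checkpoint can rebuild the module without the caller re-passing loss=; only the large model still has to be supplied by hand, as the load() hunk further down does. A simplified sketch of that Lightning behavior, with an assumed module shape:

    import pytorch_lightning as pl
    import torch.nn as nn

    class Demo(pl.LightningModule):
        def __init__(self, model: nn.Module, loss: nn.Module = None):
            super().__init__()
            # Init args not listed in `ignore` land in self.hparams and
            # are written into checkpoints.
            self.save_hyperparameters(ignore=["model"])
            self.model = model
            self.loss = loss or nn.CrossEntropyLoss()

    demo = Demo(nn.Linear(4, 2))
    print("loss" in demo.hparams, "model" in demo.hparams)  # True False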
@@ -7,4 +7,6 @@ from .base import (
 )
 from .base import TokenizerOutput as TokenizerOutput
 from .ngram import NGramTokenizer as NGramTokenizer
-from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
+
+if HAS_HF:
+    from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
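The WordPiece re-export is now guarded so the base install works without the HuggingFace dependency. HAS_HF must be defined by the earlier `from .base import (...)` block or another module-level probe; its exact home, and which package it checks, are assumptions in this sketch of the usual guard shape:

    # Assumed optional-dependency probe behind HAS_HF; location and the
    # imported package are guesses, not confirmed by this diff.
    try:
        import tokenizers  # noqa: F401  HuggingFace backend for WordPiece
        HAS_HF = True
    except ImportError:
        HAS_HF = False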
@@ -1,6 +1,8 @@
 import logging
+import pickle
 import time
 from dataclasses import asdict, dataclass, field
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Type, Union

 try:
@@ -75,6 +77,7 @@ class TrainingConfig:
     trainer_params: Optional[dict] = None
     optimizer_params: Optional[dict] = None
     scheduler_params: Optional[dict] = None
+    save_path: Optional[str] = "my_ttc"

     def to_dict(self) -> Dict[str, Any]:
         data = asdict(self)
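TrainingConfig thereby gains a destination directory, defaulting to "my_ttc"; the training hunks below wire it into train(), which persists the fitted classifier there automatically. A hedged construction sketch (only save_path is new; the trainer_params value is an assumed example, not a documented shape):

    config = TrainingConfig(
        trainer_params={"max_epochs": 5},  # assumed example value
        save_path="runs/my_classifier",    # new in 1.0.2; defaults to "my_ttc"
    )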
@@ -362,6 +365,7 @@ class torchTextClassifiers:
         logger.info(f"Training completed in {end - start:.2f} seconds.")

         best_model_path = trainer.checkpoint_callback.best_model_path
+        self.checkpoint_path = best_model_path

         self.lightning_module = TextClassificationModule.load_from_checkpoint(
             best_model_path,
@@ -372,6 +376,9 @@ class torchTextClassifiers:

         self.pytorch_model = self.lightning_module.model.to(self.device)

+        self.save_path = training_config.save_path
+        self.save(self.save_path)
+
         self.lightning_module.eval()

     def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
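Taken together, the two training hunks record the best-checkpoint path on the instance and then persist the whole classifier as the final training step, so a directory appears at save_path even when the caller never invokes save() directly. Usage sketch with illustrative variable names:

    ttc.train(X_train, y_train, training_config)  # auto-saves to save_path
    restored = torchTextClassifiers.load(training_config.save_path)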
@@ -576,6 +583,122 @@ class torchTextClassifiers:
             "confidence": confidence,
         }

+    def save(self, path: Union[str, Path]) -> None:
+        """Save the complete torchTextClassifiers instance to disk.
+
+        This saves:
+        - Model configuration
+        - Tokenizer state
+        - PyTorch Lightning checkpoint (if trained)
+        - All other instance attributes
+
+        Args:
+            path: Directory path where the model will be saved
+
+        Example:
+            >>> ttc = torchTextClassifiers(tokenizer, model_config)
+            >>> ttc.train(X_train, y_train, training_config)
+            >>> ttc.save("my_model")
+        """
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        # Save the checkpoint if model has been trained
+        checkpoint_path = None
+        if hasattr(self, "lightning_module"):
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            # Save the current state as a checkpoint
+            trainer = pl.Trainer()
+            trainer.strategy.connect(self.lightning_module)
+            trainer.save_checkpoint(checkpoint_path)
+
+        # Prepare metadata to save
+        metadata = {
+            "model_config": self.model_config.to_dict(),
+            "ragged_multilabel": self.ragged_multilabel,
+            "vocab_size": self.vocab_size,
+            "embedding_dim": self.embedding_dim,
+            "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
+            "num_classes": self.num_classes,
+            "checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
+            "device": str(self.device) if hasattr(self, "device") else None,
+        }
+
+        # Save metadata
+        with open(path / "metadata.pkl", "wb") as f:
+            pickle.dump(metadata, f)
+
+        # Save tokenizer
+        tokenizer_path = path / "tokenizer.pkl"
+        with open(tokenizer_path, "wb") as f:
+            pickle.dump(self.tokenizer, f)
+
+        logger.info(f"Model saved successfully to {path}")
+
+    @classmethod
+    def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
+        """Load a torchTextClassifiers instance from disk.
+
+        Args:
+            path: Directory path where the model was saved
+            device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)
+
+        Returns:
+            Loaded torchTextClassifiers instance
+
+        Example:
+            >>> loaded_ttc = torchTextClassifiers.load("my_model")
+            >>> predictions = loaded_ttc.predict(X_test)
+        """
+        path = Path(path)
+
+        if not path.exists():
+            raise FileNotFoundError(f"Model directory not found: {path}")
+
+        # Load metadata
+        with open(path / "metadata.pkl", "rb") as f:
+            metadata = pickle.load(f)
+
+        # Load tokenizer
+        with open(path / "tokenizer.pkl", "rb") as f:
+            tokenizer = pickle.load(f)
+
+        # Reconstruct model_config
+        model_config = ModelConfig.from_dict(metadata["model_config"])
+
+        # Create instance
+        instance = cls(
+            tokenizer=tokenizer,
+            model_config=model_config,
+            ragged_multilabel=metadata["ragged_multilabel"],
+        )
+
+        # Set device
+        if device == "auto":
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            device = torch.device(device)
+        instance.device = device
+
+        # Load checkpoint if it exists
+        if metadata["checkpoint_path"]:
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            if checkpoint_path.exists():
+                # Load the checkpoint with weights_only=False since it's our own trusted checkpoint
+                instance.lightning_module = TextClassificationModule.load_from_checkpoint(
+                    str(checkpoint_path),
+                    model=instance.pytorch_model,
+                    weights_only=False,
+                )
+                instance.pytorch_model = instance.lightning_module.model.to(device)
+                instance.checkpoint_path = str(checkpoint_path)
+                logger.info(f"Model checkpoint loaded from {checkpoint_path}")
+            else:
+                logger.warning(f"Checkpoint file not found at {checkpoint_path}")
+
+        logger.info(f"Model loaded successfully from {path}")
+        return instance
+
     def __repr__(self):
         model_type = (
             self.lightning_module.__repr__()
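Per the code above, save() lays out a directory with up to three artifacts, and load() consumes the same layout:

    my_model/
        metadata.pkl            # config, vocab/label sizes, checkpoint pointer
        tokenizer.pkl           # pickled tokenizer instance
        model_checkpoint.ckpt   # Lightning checkpoint (only if trained)

Round-trip usage, following the docstring examples (variable names illustrative):

    ttc.save("my_model")
    loaded_ttc = torchTextClassifiers.load("my_model", device="cpu")
    predictions = loaded_ttc.predict(X_test)

Because metadata.pkl and tokenizer.pkl are plain pickles, saved directories should only be loaded from trusted sources.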
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: torchtextclassifiers
-Version: 1.0.1
+Version: 1.0.2
 Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
 Keywords: fastText,text classification,NLP,automatic coding,deep learning
 Author: Cédric Couralet, Meilame Tayebjee
@@ -6,16 +6,16 @@ torchTextClassifiers/model/components/__init__.py,sha256=-IT_6fCHZkRw6Hu7GdVeCt6
 torchTextClassifiers/model/components/attention.py,sha256=hhSMh_CvpR-hiP8hoCg4Fr_TovGlJpC_RHs3iW-Pnpc,4199
 torchTextClassifiers/model/components/categorical_var_net.py,sha256=no0QDidKCw1rlbJzD7S-Srhzn5P6vETGRT5Er-gzMnM,5699
 torchTextClassifiers/model/components/classification_head.py,sha256=myuEc5wFQ5gw_f519cUZ1Z7AMuQF7Vshq_B3aRt5xRE,2501
-torchTextClassifiers/model/components/text_embedder.py,sha256=tY2pXAt4IvayyvRpjiKGg5vGz_Q2-p_TOL6Jg2p8hYE,9058
-torchTextClassifiers/model/lightning.py,sha256=dOJzyGbqwFxriAtrIjC14E1f107YMtpiR65-OJy_Pc4,5367
+torchTextClassifiers/model/components/text_embedder.py,sha256=qInHVQfjxN1zBGSNNv_9Ku4EwjntWLazjasoHhFn_yI,9188
+torchTextClassifiers/model/lightning.py,sha256=dJEH_cPPh089v4hwLuyZuXe2QxIwWOqecsXqEYrsIHU,5359
 torchTextClassifiers/model/model.py,sha256=jjGjvK7C2Wly0e4S6gTC8Ty8y-o8reU-aniBqYS73Cc,6100
 torchTextClassifiers/tokenizers/WordPiece.py,sha256=HMHYV2SiwShlhWMQ6LXH4MtZE5GSsaNA2DlD340ABGE,3289
-torchTextClassifiers/tokenizers/__init__.py,sha256=I8IQ2-t85RVlZFwLjDFF_Te2S9uiwlymQDWx-3GeF-Y,334
+torchTextClassifiers/tokenizers/__init__.py,sha256=rWWIDIQnAL9vS33ygNlZju3A6lpzC8zDiL1GBT_2TWc,350
 torchTextClassifiers/tokenizers/base.py,sha256=OY6GIhI4KTdvvKq3VZowf64H7lAmdQyq4scZ10HxP3A,7570
 torchTextClassifiers/tokenizers/ngram.py,sha256=lHI8dtuCGWh0o7V58TJx_mTVIHm8udl6XuWccxgJPew,16375
-torchTextClassifiers/torchTextClassifiers.py,sha256=ru1gAp3IaNNiV1aMzU_TYxfm81buJLu-NkvrRwUGbEU,23053
+torchTextClassifiers/torchTextClassifiers.py,sha256=_2PpE9OEuNNskwJwMc1Dqu_DP5yp6T-H-C2VOKoKn2I,27683
 torchTextClassifiers/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 torchTextClassifiers/utilities/plot_explainability.py,sha256=uSN6NbbVnnCd7Zy7zCDVM0iBbhx03tXlON6TlNk0tNU,7248
-torchtextclassifiers-1.0.1.dist-info/WHEEL,sha256=xDCZ-UyfvkGuEHPeI7BcJzYKIZzdqN8A8o1M5Om8IyA,79
-torchtextclassifiers-1.0.1.dist-info/METADATA,sha256=Nwp2MD_jexz6zQdwPXIsiLO7GDwTL3qVYK6D57aYMF4,3666
-torchtextclassifiers-1.0.1.dist-info/RECORD,,
+torchtextclassifiers-1.0.2.dist-info/WHEEL,sha256=xDCZ-UyfvkGuEHPeI7BcJzYKIZzdqN8A8o1M5Om8IyA,79
+torchtextclassifiers-1.0.2.dist-info/METADATA,sha256=ztc5fj_-smNTKq6j8CeLU39QRdk8Li8CzgxX1snispU,3666
+torchtextclassifiers-1.0.2.dist-info/RECORD,,