torchtextclassifiers 1.0.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/model/components/text_embedder.py +3 -0
- torchTextClassifiers/model/lightning.py +1 -1
- torchTextClassifiers/tokenizers/__init__.py +3 -1
- torchTextClassifiers/torchTextClassifiers.py +123 -0
- {torchtextclassifiers-1.0.1.dist-info → torchtextclassifiers-1.0.2.dist-info}/METADATA +1 -1
- {torchtextclassifiers-1.0.1.dist-info → torchtextclassifiers-1.0.2.dist-info}/RECORD +7 -7
- {torchtextclassifiers-1.0.1.dist-info → torchtextclassifiers-1.0.2.dist-info}/WHEEL +0 -0
|
@@ -23,6 +23,9 @@ class TextEmbedder(nn.Module):
|
|
|
23
23
|
self.config = text_embedder_config
|
|
24
24
|
|
|
25
25
|
self.attention_config = text_embedder_config.attention_config
|
|
26
|
+
if isinstance(self.attention_config, dict):
|
|
27
|
+
self.attention_config = AttentionConfig(**self.attention_config)
|
|
28
|
+
|
|
26
29
|
if self.attention_config is not None:
|
|
27
30
|
self.attention_config.n_embd = text_embedder_config.embedding_dim
|
|
28
31
|
|
|
@@ -36,7 +36,7 @@ class TextClassificationModule(pl.LightningModule):
|
|
|
36
36
|
scheduler_interval: Scheduler interval.
|
|
37
37
|
"""
|
|
38
38
|
super().__init__()
|
|
39
|
-
self.save_hyperparameters(ignore=["model"
|
|
39
|
+
self.save_hyperparameters(ignore=["model"])
|
|
40
40
|
|
|
41
41
|
self.model = model
|
|
42
42
|
self.loss = loss
|
|
@@ -7,4 +7,6 @@ from .base import (
|
|
|
7
7
|
)
|
|
8
8
|
from .base import TokenizerOutput as TokenizerOutput
|
|
9
9
|
from .ngram import NGramTokenizer as NGramTokenizer
|
|
10
|
-
|
|
10
|
+
|
|
11
|
+
if HAS_HF:
|
|
12
|
+
from .WordPiece import WordPieceTokenizer as WordPieceTokenizer
|
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
import pickle
|
|
2
3
|
import time
|
|
3
4
|
from dataclasses import asdict, dataclass, field
|
|
5
|
+
from pathlib import Path
|
|
4
6
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
|
5
7
|
|
|
6
8
|
try:
|
|
@@ -75,6 +77,7 @@ class TrainingConfig:
|
|
|
75
77
|
trainer_params: Optional[dict] = None
|
|
76
78
|
optimizer_params: Optional[dict] = None
|
|
77
79
|
scheduler_params: Optional[dict] = None
|
|
80
|
+
save_path: Optional[str] = "my_ttc"
|
|
78
81
|
|
|
79
82
|
def to_dict(self) -> Dict[str, Any]:
|
|
80
83
|
data = asdict(self)
|
|
@@ -362,6 +365,7 @@ class torchTextClassifiers:
|
|
|
362
365
|
logger.info(f"Training completed in {end - start:.2f} seconds.")
|
|
363
366
|
|
|
364
367
|
best_model_path = trainer.checkpoint_callback.best_model_path
|
|
368
|
+
self.checkpoint_path = best_model_path
|
|
365
369
|
|
|
366
370
|
self.lightning_module = TextClassificationModule.load_from_checkpoint(
|
|
367
371
|
best_model_path,
|
|
@@ -372,6 +376,9 @@ class torchTextClassifiers:
|
|
|
372
376
|
|
|
373
377
|
self.pytorch_model = self.lightning_module.model.to(self.device)
|
|
374
378
|
|
|
379
|
+
self.save_path = training_config.save_path
|
|
380
|
+
self.save(self.save_path)
|
|
381
|
+
|
|
375
382
|
self.lightning_module.eval()
|
|
376
383
|
|
|
377
384
|
def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
@@ -576,6 +583,122 @@ class torchTextClassifiers:
|
|
|
576
583
|
"confidence": confidence,
|
|
577
584
|
}
|
|
578
585
|
|
|
586
|
+
def save(self, path: Union[str, Path]) -> None:
|
|
587
|
+
"""Save the complete torchTextClassifiers instance to disk.
|
|
588
|
+
|
|
589
|
+
This saves:
|
|
590
|
+
- Model configuration
|
|
591
|
+
- Tokenizer state
|
|
592
|
+
- PyTorch Lightning checkpoint (if trained)
|
|
593
|
+
- All other instance attributes
|
|
594
|
+
|
|
595
|
+
Args:
|
|
596
|
+
path: Directory path where the model will be saved
|
|
597
|
+
|
|
598
|
+
Example:
|
|
599
|
+
>>> ttc = torchTextClassifiers(tokenizer, model_config)
|
|
600
|
+
>>> ttc.train(X_train, y_train, training_config)
|
|
601
|
+
>>> ttc.save("my_model")
|
|
602
|
+
"""
|
|
603
|
+
path = Path(path)
|
|
604
|
+
path.mkdir(parents=True, exist_ok=True)
|
|
605
|
+
|
|
606
|
+
# Save the checkpoint if model has been trained
|
|
607
|
+
checkpoint_path = None
|
|
608
|
+
if hasattr(self, "lightning_module"):
|
|
609
|
+
checkpoint_path = path / "model_checkpoint.ckpt"
|
|
610
|
+
# Save the current state as a checkpoint
|
|
611
|
+
trainer = pl.Trainer()
|
|
612
|
+
trainer.strategy.connect(self.lightning_module)
|
|
613
|
+
trainer.save_checkpoint(checkpoint_path)
|
|
614
|
+
|
|
615
|
+
# Prepare metadata to save
|
|
616
|
+
metadata = {
|
|
617
|
+
"model_config": self.model_config.to_dict(),
|
|
618
|
+
"ragged_multilabel": self.ragged_multilabel,
|
|
619
|
+
"vocab_size": self.vocab_size,
|
|
620
|
+
"embedding_dim": self.embedding_dim,
|
|
621
|
+
"categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
|
|
622
|
+
"num_classes": self.num_classes,
|
|
623
|
+
"checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
|
|
624
|
+
"device": str(self.device) if hasattr(self, "device") else None,
|
|
625
|
+
}
|
|
626
|
+
|
|
627
|
+
# Save metadata
|
|
628
|
+
with open(path / "metadata.pkl", "wb") as f:
|
|
629
|
+
pickle.dump(metadata, f)
|
|
630
|
+
|
|
631
|
+
# Save tokenizer
|
|
632
|
+
tokenizer_path = path / "tokenizer.pkl"
|
|
633
|
+
with open(tokenizer_path, "wb") as f:
|
|
634
|
+
pickle.dump(self.tokenizer, f)
|
|
635
|
+
|
|
636
|
+
logger.info(f"Model saved successfully to {path}")
|
|
637
|
+
|
|
638
|
+
@classmethod
|
|
639
|
+
def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
|
|
640
|
+
"""Load a torchTextClassifiers instance from disk.
|
|
641
|
+
|
|
642
|
+
Args:
|
|
643
|
+
path: Directory path where the model was saved
|
|
644
|
+
device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)
|
|
645
|
+
|
|
646
|
+
Returns:
|
|
647
|
+
Loaded torchTextClassifiers instance
|
|
648
|
+
|
|
649
|
+
Example:
|
|
650
|
+
>>> loaded_ttc = torchTextClassifiers.load("my_model")
|
|
651
|
+
>>> predictions = loaded_ttc.predict(X_test)
|
|
652
|
+
"""
|
|
653
|
+
path = Path(path)
|
|
654
|
+
|
|
655
|
+
if not path.exists():
|
|
656
|
+
raise FileNotFoundError(f"Model directory not found: {path}")
|
|
657
|
+
|
|
658
|
+
# Load metadata
|
|
659
|
+
with open(path / "metadata.pkl", "rb") as f:
|
|
660
|
+
metadata = pickle.load(f)
|
|
661
|
+
|
|
662
|
+
# Load tokenizer
|
|
663
|
+
with open(path / "tokenizer.pkl", "rb") as f:
|
|
664
|
+
tokenizer = pickle.load(f)
|
|
665
|
+
|
|
666
|
+
# Reconstruct model_config
|
|
667
|
+
model_config = ModelConfig.from_dict(metadata["model_config"])
|
|
668
|
+
|
|
669
|
+
# Create instance
|
|
670
|
+
instance = cls(
|
|
671
|
+
tokenizer=tokenizer,
|
|
672
|
+
model_config=model_config,
|
|
673
|
+
ragged_multilabel=metadata["ragged_multilabel"],
|
|
674
|
+
)
|
|
675
|
+
|
|
676
|
+
# Set device
|
|
677
|
+
if device == "auto":
|
|
678
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
679
|
+
else:
|
|
680
|
+
device = torch.device(device)
|
|
681
|
+
instance.device = device
|
|
682
|
+
|
|
683
|
+
# Load checkpoint if it exists
|
|
684
|
+
if metadata["checkpoint_path"]:
|
|
685
|
+
checkpoint_path = path / "model_checkpoint.ckpt"
|
|
686
|
+
if checkpoint_path.exists():
|
|
687
|
+
# Load the checkpoint with weights_only=False since it's our own trusted checkpoint
|
|
688
|
+
instance.lightning_module = TextClassificationModule.load_from_checkpoint(
|
|
689
|
+
str(checkpoint_path),
|
|
690
|
+
model=instance.pytorch_model,
|
|
691
|
+
weights_only=False,
|
|
692
|
+
)
|
|
693
|
+
instance.pytorch_model = instance.lightning_module.model.to(device)
|
|
694
|
+
instance.checkpoint_path = str(checkpoint_path)
|
|
695
|
+
logger.info(f"Model checkpoint loaded from {checkpoint_path}")
|
|
696
|
+
else:
|
|
697
|
+
logger.warning(f"Checkpoint file not found at {checkpoint_path}")
|
|
698
|
+
|
|
699
|
+
logger.info(f"Model loaded successfully from {path}")
|
|
700
|
+
return instance
|
|
701
|
+
|
|
579
702
|
def __repr__(self):
|
|
580
703
|
model_type = (
|
|
581
704
|
self.lightning_module.__repr__()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: torchtextclassifiers
|
|
3
|
-
Version: 1.0.1
|
|
3
|
+
Version: 1.0.2
|
|
4
4
|
Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
|
|
5
5
|
Keywords: fastText,text classification,NLP,automatic coding,deep learning
|
|
6
6
|
Author: Cédric Couralet, Meilame Tayebjee
|
|
@@ -6,16 +6,16 @@ torchTextClassifiers/model/components/__init__.py,sha256=-IT_6fCHZkRw6Hu7GdVeCt6
|
|
|
6
6
|
torchTextClassifiers/model/components/attention.py,sha256=hhSMh_CvpR-hiP8hoCg4Fr_TovGlJpC_RHs3iW-Pnpc,4199
|
|
7
7
|
torchTextClassifiers/model/components/categorical_var_net.py,sha256=no0QDidKCw1rlbJzD7S-Srhzn5P6vETGRT5Er-gzMnM,5699
|
|
8
8
|
torchTextClassifiers/model/components/classification_head.py,sha256=myuEc5wFQ5gw_f519cUZ1Z7AMuQF7Vshq_B3aRt5xRE,2501
|
|
9
|
-
torchTextClassifiers/model/components/text_embedder.py,sha256=
|
|
10
|
-
torchTextClassifiers/model/lightning.py,sha256=
|
|
9
|
+
torchTextClassifiers/model/components/text_embedder.py,sha256=qInHVQfjxN1zBGSNNv_9Ku4EwjntWLazjasoHhFn_yI,9188
|
|
10
|
+
torchTextClassifiers/model/lightning.py,sha256=dJEH_cPPh089v4hwLuyZuXe2QxIwWOqecsXqEYrsIHU,5359
|
|
11
11
|
torchTextClassifiers/model/model.py,sha256=jjGjvK7C2Wly0e4S6gTC8Ty8y-o8reU-aniBqYS73Cc,6100
|
|
12
12
|
torchTextClassifiers/tokenizers/WordPiece.py,sha256=HMHYV2SiwShlhWMQ6LXH4MtZE5GSsaNA2DlD340ABGE,3289
|
|
13
|
-
torchTextClassifiers/tokenizers/__init__.py,sha256=
|
|
13
|
+
torchTextClassifiers/tokenizers/__init__.py,sha256=rWWIDIQnAL9vS33ygNlZju3A6lpzC8zDiL1GBT_2TWc,350
|
|
14
14
|
torchTextClassifiers/tokenizers/base.py,sha256=OY6GIhI4KTdvvKq3VZowf64H7lAmdQyq4scZ10HxP3A,7570
|
|
15
15
|
torchTextClassifiers/tokenizers/ngram.py,sha256=lHI8dtuCGWh0o7V58TJx_mTVIHm8udl6XuWccxgJPew,16375
|
|
16
|
-
torchTextClassifiers/torchTextClassifiers.py,sha256=
|
|
16
|
+
torchTextClassifiers/torchTextClassifiers.py,sha256=_2PpE9OEuNNskwJwMc1Dqu_DP5yp6T-H-C2VOKoKn2I,27683
|
|
17
17
|
torchTextClassifiers/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
18
18
|
torchTextClassifiers/utilities/plot_explainability.py,sha256=uSN6NbbVnnCd7Zy7zCDVM0iBbhx03tXlON6TlNk0tNU,7248
|
|
19
|
-
torchtextclassifiers-1.0.1.dist-info/WHEEL,sha256=xDCZ-UyfvkGuEHPeI7BcJzYKIZzdqN8A8o1M5Om8IyA,79
|
|
20
|
-
torchtextclassifiers-1.0.
|
|
21
|
-
torchtextclassifiers-1.0.1.dist-info/RECORD,,
|
|
19
|
+
torchtextclassifiers-1.0.2.dist-info/WHEEL,sha256=xDCZ-UyfvkGuEHPeI7BcJzYKIZzdqN8A8o1M5Om8IyA,79
|
|
20
|
+
torchtextclassifiers-1.0.2.dist-info/METADATA,sha256=ztc5fj_-smNTKq6j8CeLU39QRdk8Li8CzgxX1snispU,3666
|
|
21
|
+
torchtextclassifiers-1.0.2.dist-info/RECORD,,
|
|
File without changes
|