torchtextclassifiers 0.0.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +152 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +61 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +170 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +500 -413
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-1.0.0.dist-info/METADATA +87 -0
- torchtextclassifiers-1.0.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-1.0.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
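
Before the file diff below, here is a minimal usage sketch of the reworked 1.0.0 API, assembled from the docstring examples that appear in torchTextClassifiers/torchTextClassifiers.py further down. It is illustrative only: the tokenizer construction is left abstract (the diff lists the new tokenizer modules but does not show their constructors here), and the toy arrays are made up.

import numpy as np

from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers

# X: shape (N, d) — first column is text, remaining columns are integer-coded
# categorical variables; Y: integer labels in [0, num_classes - 1].
X = np.array(
    [
        ["the service was excellent", "3", "7"],
        ["terrible quality, never again", "1", "2"],
    ],
    dtype=object,
)
Y = np.array([1, 0])

# Placeholder: any trained tokenizer exposing the BaseTokenizer interface
# (tokenizer.trained must be True before the classifier is built).
tokenizer = ...

model_config = ModelConfig(
    embedding_dim=10,
    categorical_vocabulary_sizes=[30, 25],
    categorical_embedding_dims=[10, 5],
    num_classes=10,
)
ttc = torchTextClassifiers(tokenizer=tokenizer, model_config=model_config)

training_config = TrainingConfig(num_epochs=1, batch_size=4, lr=1e-3)
ttc.train(X_train=X, y_train=Y, X_val=X, y_val=Y, training_config=training_config)

result = ttc.predict(X, top_k=2)  # {"prediction": ..., "confidence": ...}
# explain=True additionally returns "attributions" and requires captum
# (pip install/uv add torchFastText[explainability]).

Compared with 0.0.1, the classifier is built from a trained tokenizer plus a ModelConfig instead of a FastTextWrapper, and training and prediction options move into TrainingConfig and predict(top_k, explain).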
--- torchTextClassifiers/torchTextClassifiers.py (0.0.1)
+++ torchTextClassifiers/torchTextClassifiers.py (1.0.0)
@@ -1,7 +1,15 @@
 import logging
 import time
-import
-from typing import
+from dataclasses import asdict, dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+try:
+    from captum.attr import LayerIntegratedGradients
+
+    HAS_CAPTUM = True
+except ImportError:
+    HAS_CAPTUM = False
+
 
 import numpy as np
 import pytorch_lightning as pl
@@ -12,9 +20,17 @@ from pytorch_lightning.callbacks import (
     ModelCheckpoint,
 )
 
-from .
-from .
-
+from torchTextClassifiers.dataset import TextClassificationDataset
+from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
+from torchTextClassifiers.model.components import (
+    AttentionConfig,
+    CategoricalForwardType,
+    CategoricalVariableNet,
+    ClassificationHead,
+    TextEmbedder,
+    TextEmbedderConfig,
+)
+from torchTextClassifiers.tokenizers import BaseTokenizer, TokenizerOutput
 
 logger = logging.getLogger(__name__)
 
@@ -26,484 +42,555 @@ logging.basicConfig(
 )
 
 
+@dataclass
+class ModelConfig:
+    """Base configuration class for text classifiers."""
+
+    embedding_dim: int
+    categorical_vocabulary_sizes: Optional[List[int]] = None
+    categorical_embedding_dims: Optional[Union[List[int], int]] = None
+    num_classes: Optional[int] = None
+    attention_config: Optional[AttentionConfig] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ModelConfig":
+        return cls(**data)
+
+
+@dataclass
+class TrainingConfig:
+    num_epochs: int
+    batch_size: int
+    lr: float
+    loss: torch.nn.Module = field(default_factory=lambda: torch.nn.CrossEntropyLoss())
+    optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
+    scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
+    accelerator: str = "auto"
+    num_workers: int = 12
+    patience_early_stopping: int = 3
+    dataloader_params: Optional[dict] = None
+    trainer_params: Optional[dict] = None
+    optimizer_params: Optional[dict] = None
+    scheduler_params: Optional[dict] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        data = asdict(self)
+        # Serialize loss and scheduler as their class names
+        data["loss"] = self.loss.__class__.__name__
+        if self.scheduler is not None:
+            data["scheduler"] = self.scheduler.__name__
+        return data
 
 
 class torchTextClassifiers:
     """Generic text classifier framework supporting multiple architectures.
-
-
-
-
-
-    The
-
-    - Model training with validation
-    - Prediction and evaluation
-    - Model serialization and loading
-
-    Attributes:
-        config: Configuration object specific to the classifier type
-        classifier: The underlying classifier implementation
-
-    Example:
-        >>> from torchTextClassifiers import torchTextClassifiers
-        >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-        >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-        >>>
-        >>> # Create configuration
-        >>> config = FastTextConfig(
-        ...     embedding_dim=100,
-        ...     num_tokens=10000,
-        ...     min_count=1,
-        ...     min_n=3,
-        ...     max_n=6,
-        ...     len_word_ngrams=2,
-        ...     num_classes=2
-        ... )
-        >>>
-        >>> # Initialize classifier with wrapper
-        >>> wrapper = FastTextWrapper(config)
-        >>> classifier = torchTextClassifiers(wrapper)
-        >>>
-        >>> # Build and train
-        >>> classifier.build(X_train, y_train)
-        >>> classifier.train(X_train, y_train, X_val, y_val, num_epochs=10, batch_size=32)
-        >>>
-        >>> # Predict
-        >>> predictions = classifier.predict(X_test)
+
+    Given a tokenizer and model configuration, this class initializes:
+    - Text embedding layer (if needed)
+    - Categorical variable embedding network (if categorical variables are provided)
+    - Classification head
+    The resulting model can be trained using PyTorch Lightning and used for predictions.
+
     """
-
-    def __init__(
-        """Initialize the torchTextClassifiers instance.
-
-        Args:
-            classifier: An instance of a classifier wrapper that implements BaseClassifierWrapper
-
-        Example:
-            >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-            >>> from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-            >>> config = FastTextConfig(embedding_dim=50, num_tokens=5000)
-            >>> wrapper = FastTextWrapper(config)
-            >>> classifier = torchTextClassifiers(wrapper)
-        """
-        self.classifier = classifier
-        self.config = classifier.config
-
-
-    def build_tokenizer(self, training_text: np.ndarray) -> None:
-        """Build tokenizer from training text data.
-
-        This method is kept for backward compatibility. It delegates to
-        prepare_text_features which handles the actual text preprocessing.
-
-        Args:
-            training_text: Array of text strings to build the tokenizer from
-
-        Example:
-            >>> import numpy as np
-            >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-            >>> classifier.build_tokenizer(texts)
-        """
-        self.classifier.prepare_text_features(training_text)
-
-    def prepare_text_features(self, training_text: np.ndarray) -> None:
-        """Prepare text features for the classifier.
-
-        This method handles text preprocessing which could involve tokenization,
-        vectorization, or other approaches depending on the classifier type.
-
-        Args:
-            training_text: Array of text strings to prepare features from
-
-        Example:
-            >>> import numpy as np
-            >>> texts = np.array(["Hello world", "This is a test", "Another example"])
-            >>> classifier.prepare_text_features(texts)
-        """
-        self.classifier.prepare_text_features(training_text)
-
-    def build(
+
+    def __init__(
         self,
-
-
-
-
-
-
-        This method handles the full model building process including:
-        - Input validation and preprocessing
-        - Tokenizer creation from training text
-        - Model architecture initialization
-        - Lightning module setup (if enabled)
-
+        tokenizer: BaseTokenizer,
+        model_config: ModelConfig,
+        ragged_multilabel: bool = False,
+    ):
+        """Initialize the torchTextClassifiers instance.
+
         Args:
-
-
-
-            **kwargs: Additional arguments passed to Lightning initialization
-
-        Raises:
-            ValueError: If y_train is None and num_classes is not set in config
-            ValueError: If label values are outside expected range
-
+            tokenizer: A tokenizer instance for text preprocessing
+            model_config: Configuration parameters for the text classification model
+
         Example:
-            >>>
-            >>>
-            >>>
+            >>> from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
+            >>> # Assume tokenizer is a trained BaseTokenizer instance
+            >>> model_config = ModelConfig(
+            ...     embedding_dim=10,
+            ...     categorical_vocabulary_sizes=[30, 25],
+            ...     categorical_embedding_dims=[10, 5],
+            ...     num_classes=10,
+            ... )
+            >>> ttc = torchTextClassifiers(
+            ...     tokenizer=tokenizer,
+            ...     model_config=model_config,
+            ... )
         """
-
-
-
-
-
-
-
-
-
-        y_train = check_Y(y_train)
-        self.config.num_classes = len(np.unique(y_train))
-
-        if np.max(y_train) >= self.config.num_classes:
-            raise ValueError(
-                "y_train must contain values between 0 and num_classes-1"
+
+        self.model_config = model_config
+        self.tokenizer = tokenizer
+        self.ragged_multilabel = ragged_multilabel
+
+        if hasattr(self.tokenizer, "trained"):
+            if not self.tokenizer.trained:
+                raise RuntimeError(
+                    f"Tokenizer {type(self.tokenizer)} must be trained before initializing the classifier."
                 )
+
+        self.vocab_size = tokenizer.vocab_size
+        self.embedding_dim = model_config.embedding_dim
+        self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
+        self.num_classes = model_config.num_classes
+
+        if self.tokenizer.output_vectorized:
+            self.text_embedder = None
+            logger.info(
+                "Tokenizer outputs vectorized tokens; skipping TextEmbedder initialization."
+            )
+            self.embedding_dim = self.tokenizer.output_dim
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if
-
-
-
-
-
-
-
-
-
-
-        self.
-
+            text_embedder_config = TextEmbedderConfig(
+                vocab_size=self.vocab_size,
+                embedding_dim=self.embedding_dim,
+                padding_idx=tokenizer.padding_idx,
+                attention_config=model_config.attention_config,
+            )
+            self.text_embedder = TextEmbedder(
+                text_embedder_config=text_embedder_config,
+            )
+
+        classif_head_input_dim = self.embedding_dim
+        if self.categorical_vocabulary_sizes:
+            self.categorical_var_net = CategoricalVariableNet(
+                categorical_vocabulary_sizes=self.categorical_vocabulary_sizes,
+                categorical_embedding_dims=model_config.categorical_embedding_dims,
+                text_embedding_dim=self.embedding_dim,
+            )
+
+            if self.categorical_var_net.forward_type != CategoricalForwardType.SUM_TO_TEXT:
+                classif_head_input_dim += self.categorical_var_net.output_dim
+
+        else:
+            self.categorical_var_net = None
+
+        self.classification_head = ClassificationHead(
+            input_dim=classif_head_input_dim,
+            num_classes=model_config.num_classes,
+        )
+
+        self.pytorch_model = TextClassificationModel(
+            text_embedder=self.text_embedder,
+            categorical_variable_net=self.categorical_var_net,
+            classification_head=self.classification_head,
+        )
+
     def train(
         self,
         X_train: np.ndarray,
         y_train: np.ndarray,
-
-
-
-        batch_size: int,
-        cpu_run: bool = False,
-        num_workers: int = 12,
-        patience_train: int = 3,
+        training_config: TrainingConfig,
+        X_val: Optional[np.ndarray] = None,
+        y_val: Optional[np.ndarray] = None,
         verbose: bool = False,
-        trainer_params: Optional[dict] = None,
-        **kwargs
     ) -> None:
         """Train the classifier using PyTorch Lightning.
-
+
         This method handles the complete training process including:
         - Data validation and preprocessing
         - Dataset and DataLoader creation
         - PyTorch Lightning trainer setup with callbacks
         - Model training with early stopping
         - Best model loading after training
-
+
         Args:
             X_train: Training input data
             y_train: Training labels
             X_val: Validation input data
             y_val: Validation labels
-
-
-
-
-            patience_train: Number of epochs to wait for improvement before early stopping
-            verbose: If True, print detailed training progress
-            trainer_params: Additional parameters to pass to PyTorch Lightning Trainer
-            **kwargs: Additional arguments passed to the build method
-
+            training_config: Configuration parameters for training
+            verbose: Whether to print training progress information
+
+
         Example:
-
-
-
-
-
-
-
+
+            >>> training_config = TrainingConfig(
+            ...     lr=1e-3,
+            ...     batch_size=4,
+            ...     num_epochs=1,
+            ... )
+            >>> ttc.train(
+            ...     X_train=X,
+            ...     y_train=Y,
+            ...     X_val=X,
+            ...     y_val=Y,
+            ...     training_config=training_config,
+            ... )
         """
         # Input validation
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        X_train, y_train = self._check_XY(X_train, y_train)
+
+        if X_val is not None:
+            assert y_val is not None, "y_val must be provided if X_val is provided."
+        if y_val is not None:
+            assert X_val is not None, "X_val must be provided if y_val is provided."
+
+        if X_val is not None and y_val is not None:
+            X_val, y_val = self._check_XY(X_val, y_val)
+
+            if (
+                X_train["categorical_variables"] is not None
+                and X_val["categorical_variables"] is not None
+            ):
+                assert (
+                    X_train["categorical_variables"].ndim > 1
+                    and X_train["categorical_variables"].shape[1]
+                    == X_val["categorical_variables"].shape[1]
+                    or X_val["categorical_variables"].ndim == 1
+                ), "X_train and X_val must have the same number of columns."
+
         if verbose:
             logger.info("Starting training process...")
-
-
-        if cpu_run:
-            device = torch.device("cpu")
-        else:
+
+        if training_config.accelerator == "auto":
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-
+        else:
+            device = torch.device(training_config.accelerator)
+
+        self.device = device
+
+        optimizer_params = {"lr": training_config.lr}
+        if training_config.optimizer_params is not None:
+            optimizer_params.update(training_config.optimizer_params)
+
+        if training_config.loss is torch.nn.CrossEntropyLoss and self.ragged_multilabel:
+            logger.warning(
+                "⚠️ You have set ragged_multilabel to True but are using CrossEntropyLoss. We would recommend to use torch.nn.BCEWithLogitsLoss for multilabel classification tasks."
+            )
+
+        self.lightning_module = TextClassificationModule(
+            model=self.pytorch_model,
+            loss=training_config.loss,
+            optimizer=training_config.optimizer,
+            optimizer_params=optimizer_params,
+            scheduler=training_config.scheduler,
+            scheduler_params=training_config.scheduler_params
+            if training_config.scheduler_params
+            else {},
+            scheduler_interval="epoch",
+        )
+
+        self.pytorch_model.to(self.device)
+
         if verbose:
             logger.info(f"Running on: {device}")
-
-
-
-        if
-
-
-        self.
-        if verbose:
-            end = time.time()
-            logger.info(f"Model built in {end - start:.2f} seconds.")
-
-        self.classifier.pytorch_model = self.classifier.pytorch_model.to(device)
-
-        # Create datasets and dataloaders using wrapper methods
-        train_dataset = self.classifier.create_dataset(
-            texts=training_text,
-            labels=y_train,
-            categorical_variables=train_categorical_variables,
-        )
-        val_dataset = self.classifier.create_dataset(
-            texts=val_text,
-            labels=y_val,
-            categorical_variables=val_categorical_variables,
-        )
-
-        train_dataloader = self.classifier.create_dataloader(
-            dataset=train_dataset,
-            batch_size=batch_size,
-            num_workers=num_workers,
-            shuffle=True
+
+        train_dataset = TextClassificationDataset(
+            texts=X_train["text"],
+            categorical_variables=X_train["categorical_variables"],  # None if no cat vars
+            tokenizer=self.tokenizer,
+            labels=y_train.tolist(),
+            ragged_multilabel=self.ragged_multilabel,
         )
-
-
-
-
-
+        train_dataloader = train_dataset.create_dataloader(
+            batch_size=training_config.batch_size,
+            num_workers=training_config.num_workers,
+            shuffle=True,
+            **training_config.dataloader_params if training_config.dataloader_params else {},
        )
-
+
+        if X_val is not None and y_val is not None:
+            val_dataset = TextClassificationDataset(
+                texts=X_val["text"],
+                categorical_variables=X_val["categorical_variables"],  # None if no cat vars
+                tokenizer=self.tokenizer,
+                labels=y_val,
+                ragged_multilabel=self.ragged_multilabel,
+            )
+            val_dataloader = val_dataset.create_dataloader(
+                batch_size=training_config.batch_size,
+                num_workers=training_config.num_workers,
+                shuffle=False,
+                **training_config.dataloader_params if training_config.dataloader_params else {},
+            )
+        else:
+            val_dataloader = None
+
         # Setup trainer
         callbacks = [
             ModelCheckpoint(
-                monitor="val_loss",
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
                 save_top_k=1,
                 save_last=False,
                 mode="min",
             ),
             EarlyStopping(
-                monitor="val_loss",
-                patience=
+                monitor="val_loss" if val_dataloader is not None else "train_loss",
+                patience=training_config.patience_early_stopping,
                 mode="min",
             ),
            LearningRateMonitor(logging_interval="step"),
         ]
-
-
+
+        trainer_params = {
+            "accelerator": training_config.accelerator,
             "callbacks": callbacks,
-            "max_epochs": num_epochs,
+            "max_epochs": training_config.num_epochs,
             "num_sanity_val_steps": 2,
             "strategy": "auto",
             "log_every_n_steps": 1,
             "enable_progress_bar": True,
         }
-
-        if trainer_params is not None:
-
-
-        trainer = pl.Trainer(**
-
+
+        if training_config.trainer_params is not None:
+            trainer_params.update(training_config.trainer_params)
+
+        trainer = pl.Trainer(**trainer_params)
+
         torch.cuda.empty_cache()
         torch.set_float32_matmul_precision("medium")
-
+
         if verbose:
             logger.info("Launching training...")
             start = time.time()
-
-        trainer.fit(self.
-
+
+        trainer.fit(self.lightning_module, train_dataloader, val_dataloader)
+
         if verbose:
             end = time.time()
             logger.info(f"Training completed in {end - start:.2f} seconds.")
-
-        # Load best model using wrapper method
+
         best_model_path = trainer.checkpoint_callback.best_model_path
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-X
-
-
-
-
-
-
-
-
-
-
-        return
-
-    def
-        """
-
-        This method provides both predictions and explanations for the model's
-        decisions. Availability depends on the specific classifier implementation.
-
+
+        self.lightning_module = TextClassificationModule.load_from_checkpoint(
+            best_model_path,
+            model=self.pytorch_model,
+            loss=training_config.loss,
+        )
+
+        self.pytorch_model = self.lightning_module.model.to(self.device)
+
+        self.lightning_module.eval()
+
+    def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        X = self._check_X(X)
+        Y = self._check_Y(Y)
+
+        if X["text"].shape[0] != len(Y):
+            raise ValueError("X_train and y_train must have the same number of observations.")
+
+        return X, Y
+
+    @staticmethod
+    def _check_text_col(X):
+        assert isinstance(
+            X, np.ndarray
+        ), "X must be a numpy array of shape (N,d), with the first column being the text and the rest being the categorical variables."
+
+        try:
+            if X.ndim > 1:
+                text = X[:, 0].astype(str)
+            else:
+                text = X[:].astype(str)
+        except ValueError:
+            logger.error("The first column of X must be castable in string format.")
+
+        return text
+
+    def _check_categorical_variables(self, X: np.ndarray) -> None:
+        """Check if categorical variables in X match training configuration.
+
         Args:
-            X: Input data
-
-
-        Returns:
-            tuple: (predictions, explanations) where explanations format depends
-                on the classifier type
-
+            X: Input data to check
+
         Raises:
-
-
-        Example:
-            >>> predictions, explanations = classifier.predict_and_explain(X_test)
-            >>> print(f"Predictions: {predictions}")
-            >>> print(f"Explanations: {explanations}")
+            ValueError: If the number of categorical variables does not match
+                the training configuration
         """
-
-
+
+        assert self.categorical_var_net is not None
+
+        if X.ndim > 1:
+            num_cat_vars = X.shape[1] - 1
         else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            num_cat_vars = 0
+
+        if num_cat_vars != self.categorical_var_net.num_categorical_features:
+            raise ValueError(
+                f"X must have the same number of categorical variables as the number of embedding layers in the categorical net: ({self.categorical_var_net.num_categorical_features})."
+            )
+
+        try:
+            categorical_variables = X[:, 1:].astype(int)
+        except ValueError:
+            logger.error(
+                f"Columns {1} to {X.shape[1] - 1} of X_train must be castable in integer format."
+            )
+
+        for j in range(X.shape[1] - 1):
+            max_cat_value = categorical_variables[:, j].max()
+            if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]:
+                raise ValueError(
+                    f"Categorical variable at index {j} has value {max_cat_value} which exceeds the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}."
+                )
+
+        return categorical_variables
+
+    def _check_X(self, X: np.ndarray) -> np.ndarray:
+        text = self._check_text_col(X)
+
+        categorical_variables = None
+        if self.categorical_var_net is not None:
+            categorical_variables = self._check_categorical_variables(X)
+
+        return {"text": text, "categorical_variables": categorical_variables}
+
+    def _check_Y(self, Y):
+        if self.ragged_multilabel:
+            assert isinstance(
+                Y, list
+            ), "Y must be a list of lists for ragged multilabel classification."
+            for row in Y:
+                assert isinstance(row, list), "Each element of Y must be a list of labels."
+
+            return Y
+
+        else:
+            assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
+            assert (
+                len(Y.shape) == 1 or len(Y.shape) == 2
+            ), "Y must be a numpy array of shape (N,) or (N, num_labels)."
+
+            try:
+                Y = Y.astype(int)
+            except ValueError:
+                logger.error("Y must be castable in integer format.")
+
+            if Y.max() >= self.num_classes or Y.min() < 0:
+                raise ValueError(
+                    f"Y contains class labels outside the range [0, {self.num_classes - 1}]."
+                )
+
+            return Y
+
+    def predict(
+        self,
+        X_test: np.ndarray,
+        top_k=1,
+        explain=False,
+    ):
         """
-        with open(filepath, "w") as f:
-            data = {
-                "config": self.config.to_dict(),
-            }
-
-            # Try to get wrapper class info for reconstruction
-            if hasattr(self.classifier.__class__, 'get_wrapper_class_info'):
-                data["wrapper_class_info"] = self.classifier.__class__.get_wrapper_class_info()
-            else:
-                # Fallback: store module and class name
-                data["wrapper_class_info"] = {
-                    "module": self.classifier.__class__.__module__,
-                    "class_name": self.classifier.__class__.__name__
-                }
-
-            json.dump(data, f, cls=NumpyJSONEncoder, indent=4)
-
-    @classmethod
-    def from_json(cls, filepath: str, wrapper_class: Optional[Type[BaseClassifierWrapper]] = None) -> "torchTextClassifiers":
-        """Load classifier configuration from JSON file.
-
-        This method creates a new classifier instance from a previously saved
-        configuration file. The classifier will need to be built and trained again.
-
         Args:
-
-
-
-
-        Returns:
-
-
-
-
-            FileNotFoundError: If the configuration file doesn't exist
-
-        Example:
-            >>> # Using saved wrapper class info
-            >>> classifier = torchTextClassifiers.from_json('my_classifier_config.json')
-            >>>
-            >>> # Or providing wrapper class explicitly
-            >>> from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-            >>> classifier = torchTextClassifiers.from_json('config.json', FastTextWrapper)
+            X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
+            top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
+            explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
+
+        Returns: A dictionary containing the following fields:
+            - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
+            - confidence (torch.Tensor, shape (len(text), top_k)): A tensor array containing the corresponding confidence scores.
+            - if explain is True:
+                - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if explain:
+            return_offsets_mapping = True  # to be passed to the tokenizer
+            return_word_ids = True
+            if self.pytorch_model.text_embedder is None:
+                raise RuntimeError(
+                    "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
+                )
+            else:
+                if not HAS_CAPTUM:
+                    raise ImportError(
+                        "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
+                    )
+                lig = LayerIntegratedGradients(
+                    self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
+                )  # initialize a Captum layer gradient integrator
+        else:
+            return_offsets_mapping = False
+            return_word_ids = False
+
+        X_test = self._check_X(X_test)
+        text = X_test["text"]
+        categorical_variables = X_test["categorical_variables"]
+
+        self.pytorch_model.eval().cpu()
+
+        tokenize_output = self.tokenizer.tokenize(
+            text.tolist(),
+            return_offsets_mapping=return_offsets_mapping,
+            return_word_ids=return_word_ids,
+        )
+
+        if not isinstance(tokenize_output, TokenizerOutput):
+            raise TypeError(
+                f"Expected TokenizerOutput, got {type(tokenize_output)} from tokenizer.tokenize method."
+            )
+
+        encoded_text = tokenize_output.input_ids  # (batch_size, seq_len)
+        attention_mask = tokenize_output.attention_mask  # (batch_size, seq_len)
+
+        if categorical_variables is not None:
+            categorical_vars = torch.tensor(
+                categorical_variables, dtype=torch.float32
+            )  # (batch_size, num_categorical_features)
+        else:
+            categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32)
+
+        pred = self.pytorch_model(
+            encoded_text, attention_mask, categorical_vars
+        )  # forward pass, contains the prediction scores (len(text), num_classes)
+
+        label_scores = pred.detach().cpu().softmax(dim=1)  # convert to probabilities
+
+        label_scores_topk = torch.topk(label_scores, k=top_k, dim=1)
+
+        predictions = label_scores_topk.indices  # get the top_k most likely predictions
+        confidence = torch.round(label_scores_topk.values, decimals=2)  # and their scores
+
+        if explain:
+            all_attributions = []
+            for k in range(top_k):
+                attributions = lig.attribute(
+                    (encoded_text, attention_mask, categorical_vars),
+                    target=torch.Tensor(predictions[:, k]).long(),
+                )  # (batch_size, seq_len)
+                attributions = attributions.sum(dim=-1)
+                all_attributions.append(attributions.detach().cpu())
+
+            all_attributions = torch.stack(all_attributions, dim=1)  # (batch_size, top_k, seq_len)
+
+            return {
+                "prediction": predictions,
+                "confidence": confidence,
+                "attributions": all_attributions,
+                "offset_mapping": tokenize_output.offset_mapping,
+                "word_ids": tokenize_output.word_ids,
+            }
+        else:
+            return {
+                "prediction": predictions,
+                "confidence": confidence,
+            }
+
+    def __repr__(self):
+        model_type = (
+            self.lightning_module.__repr__()
+            if hasattr(self, "lightning_module")
+            else self.pytorch_model.__repr__()
+        )
+
+        tokenizer_info = self.tokenizer.__repr__()
+
+        cat_forward_type = (
+            self.categorical_var_net.forward_type.name
+            if self.categorical_var_net is not None
+            else "None"
+        )
+
+        lines = [
+            "torchTextClassifiers(",
+            f"  tokenizer = {tokenizer_info},",
+            f"  model = {model_type},",
+            f"  categorical_forward_type = {cat_forward_type},",
+            f"  num_classes = {self.model_config.num_classes},",
+            f"  embedding_dim = {self.embedding_dim},",
+            ")",
+        ]
+        return "\n".join(lines)