torchtextclassifiers 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- torchTextClassifiers/__init__.py +12 -48
- torchTextClassifiers/dataset/__init__.py +1 -0
- torchTextClassifiers/dataset/dataset.py +114 -0
- torchTextClassifiers/model/__init__.py +2 -0
- torchTextClassifiers/model/components/__init__.py +12 -0
- torchTextClassifiers/model/components/attention.py +126 -0
- torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchTextClassifiers/model/components/classification_head.py +43 -0
- torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchTextClassifiers/model/lightning.py +166 -0
- torchTextClassifiers/model/model.py +151 -0
- torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchTextClassifiers/tokenizers/base.py +205 -0
- torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchTextClassifiers/torchTextClassifiers.py +463 -405
- torchTextClassifiers/utilities/__init__.py +0 -3
- torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-0.1.0.dist-info/METADATA +73 -0
- torchtextclassifiers-0.1.0.dist-info/RECORD +21 -0
- {torchtextclassifiers-0.0.1.dist-info → torchtextclassifiers-0.1.0.dist-info}/WHEEL +1 -1
- torchTextClassifiers/classifiers/base.py +0 -83
- torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchTextClassifiers/factories.py +0 -34
- torchTextClassifiers/utilities/checkers.py +0 -108
- torchTextClassifiers/utilities/preprocess.py +0 -82
- torchTextClassifiers/utilities/utils.py +0 -346
- torchtextclassifiers-0.0.1.dist-info/METADATA +0 -187
- torchtextclassifiers-0.0.1.dist-info/RECORD +0 -17
torchTextClassifiers/model/components/text_embedder.py
@@ -0,0 +1,220 @@
import math
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm


@dataclass
class TextEmbedderConfig:
    vocab_size: int
    embedding_dim: int
    padding_idx: int
    attention_config: Optional[AttentionConfig] = None


class TextEmbedder(nn.Module):
    def __init__(self, text_embedder_config: TextEmbedderConfig):
        super().__init__()

        self.config = text_embedder_config

        self.attention_config = text_embedder_config.attention_config
        if self.attention_config is not None:
            self.attention_config.n_embd = text_embedder_config.embedding_dim

        self.vocab_size = text_embedder_config.vocab_size
        self.embedding_dim = text_embedder_config.embedding_dim
        self.padding_idx = text_embedder_config.padding_idx

        self.embedding_layer = nn.Embedding(
            embedding_dim=self.embedding_dim,
            num_embeddings=self.vocab_size,
            padding_idx=self.padding_idx,
        )

        if self.attention_config is not None:
            self.transformer = nn.ModuleDict(
                {
                    "h": nn.ModuleList(
                        [
                            Block(self.attention_config, layer_idx)
                            for layer_idx in range(self.attention_config.n_layers)
                        ]
                    ),
                }
            )

            head_dim = self.attention_config.n_embd // self.attention_config.n_head

            if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
                raise ValueError("embedding_dim must be divisible by n_head.")

            if self.attention_config.positional_encoding:
                if head_dim % 2 != 0:
                    raise ValueError(
                        "embedding_dim / n_head must be even for rotary positional embeddings."
                    )

                if self.attention_config.sequence_len is None:
                    raise ValueError(
                        "sequence_len must be specified in AttentionConfig when positional_encoding is True."
                    )

                self.rotary_seq_len = self.attention_config.sequence_len * 10
                cos, sin = self._precompute_rotary_embeddings(
                    seq_len=self.rotary_seq_len, head_dim=head_dim
                )

                self.register_buffer(
                    "cos", cos, persistent=False
                )  # persistent=False means it's not saved to the checkpoint
                self.register_buffer("sin", sin, persistent=False)

    def init_weights(self):
        self.apply(self._init_weights)

        # zero out c_proj weights in all blocks
        if self.attention_config is not None:
            for block in self.transformer.h:
                torch.nn.init.zeros_(block.mlp.c_proj.weight)
                torch.nn.init.zeros_(block.attn.c_proj.weight)
            # init the rotary embeddings
            head_dim = self.attention_config.n_embd // self.attention_config.n_head
            cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
            self.cos, self.sin = cos, sin
        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory, both in the model and the activations
        if self.embedding_layer.weight.device.type == "cuda":
            self.embedding_layer.to(dtype=torch.bfloat16)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # https://arxiv.org/pdf/2310.17813
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Converts input token IDs to their corresponding embeddings."""

        encoded_text = input_ids  # clearer name
        if encoded_text.dtype != torch.long:
            encoded_text = encoded_text.to(torch.long)

        batch_size, seq_len = encoded_text.shape
        batch_size_check, seq_len_check = attention_mask.shape

        if batch_size != batch_size_check or seq_len != seq_len_check:
            raise ValueError(
                f"Input IDs and attention mask must have the same batch size and sequence length. "
                f"Got input_ids shape {encoded_text.shape} and attention_mask shape {attention_mask.shape}."
            )

        token_embeddings = self.embedding_layer(
            encoded_text
        )  # (batch_size, seq_len, embedding_dim)

        token_embeddings = norm(token_embeddings)

        if self.attention_config is not None:
            if self.attention_config.positional_encoding:
                cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
            else:
                cos_sin = None

            for block in self.transformer.h:
                token_embeddings = block(token_embeddings, cos_sin)

            token_embeddings = norm(token_embeddings)

        text_embedding = self._get_sentence_embedding(
            token_embeddings=token_embeddings, attention_mask=attention_mask
        )

        return text_embedding

    def _get_sentence_embedding(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute sentence embedding from embedded tokens - "remove" second dimension.

        Args:
            token_embeddings (torch.Tensor[float]), shape (batch_size, seq_len, embedding_dim): Embedded (and padded) tokens
            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens (output from dataset collate_fn)
        Returns:
            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
        """

        # average over non-pad token embeddings
        # attention mask has 1 for non-pad tokens and 0 for pad token positions

        # mask pad-tokens

        if self.attention_config is not None:
            if self.attention_config.aggregation_method is not None:
                if self.attention_config.aggregation_method == "first":
                    return token_embeddings[:, 0, :]
                elif self.attention_config.aggregation_method == "last":
                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
                    return token_embeddings[
                        torch.arange(token_embeddings.size(0)),
                        lengths - 1,
                        :,
                    ]
                else:
                    if self.attention_config.aggregation_method != "mean":
                        raise ValueError(
                            f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
                        )

        assert self.attention_config is None or self.attention_config.aggregation_method == "mean"

        mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
        masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)

        sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
            min=1.0
        )  # avoid division by zero

        sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)

        return sentence_embedding

    def __call__(self, *args, **kwargs):
        out = super().__call__(*args, **kwargs)
        if out.dim() != 2:
            raise ValueError(
                f"Output of {self.__class__.__name__}.forward must be 2D "
                f"(got shape {tuple(out.shape)})"
            )
        return out

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # autodetect the device from model embeddings
        if device is None:
            device = next(self.parameters()).device

        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        # stride the time steps
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
        cos, sin = (
            cos[None, :, None, :],
            sin[None, :, None, :],
        )  # add batch and head dims for later broadcasting

        return cos, sin
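For reference, a minimal usage sketch of the new TextEmbedder. The vocabulary size, embedding dimension, and token ids below are illustrative assumptions, not taken from the package; with no AttentionConfig the module simply mean-pools the non-pad token embeddings into one vector per sentence.

import torch

from torchTextClassifiers.model.components.text_embedder import TextEmbedder, TextEmbedderConfig

# Illustrative sizes only.
config = TextEmbedderConfig(vocab_size=1000, embedding_dim=64, padding_idx=0)
embedder = TextEmbedder(config)

# Two already-tokenized, padded sequences (padding_idx = 0) and their attention mask.
input_ids = torch.tensor([[5, 12, 7, 0], [3, 9, 0, 0]])
attention_mask = (input_ids != 0).long()

sentence_embeddings = embedder(input_ids, attention_mask)
print(sentence_embeddings.shape)  # torch.Size([2, 64]) -> (batch_size, embedding_dim)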
torchTextClassifiers/model/lightning.py
@@ -0,0 +1,166 @@
import pytorch_lightning as pl
import torch
from torchmetrics import Accuracy

from .model import TextClassificationModel

# ============================================================================
# PyTorch Lightning Module
# ============================================================================


class TextClassificationModule(pl.LightningModule):
    """PyTorch Lightning module for TextClassificationModel."""

    def __init__(
        self,
        model: TextClassificationModel,
        loss,
        optimizer,
        optimizer_params,
        scheduler,
        scheduler_params,
        scheduler_interval="epoch",
        **kwargs,
    ):
        """
        Initialize TextClassificationModule.

        Args:
            model: Model.
            loss: Loss.
            optimizer: Optimizer.
            optimizer_params: Optimizer parameters.
            scheduler: Scheduler.
            scheduler_params: Scheduler parameters.
            scheduler_interval: Scheduler interval.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["model", "loss"])

        self.model = model
        self.loss = loss
        self.accuracy_fn = Accuracy(task="multiclass", num_classes=self.model.num_classes)
        self.optimizer = optimizer
        self.optimizer_params = optimizer_params
        self.scheduler = scheduler
        self.scheduler_params = scheduler_params
        self.scheduler_interval = scheduler_interval

    def forward(self, batch) -> torch.Tensor:
        """
        Perform forward-pass.

        Args:
            batch (dict): Batch to perform forward-pass on.

        Returns (torch.Tensor): Prediction.
        """
        return self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            categorical_vars=batch.get("categorical_vars", None),
        )

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """
        Training step.

        Args:
            batch (dict): Training batch.
            batch_idx (int): Batch index.

        Returns (torch.Tensor): Loss tensor.
        """

        targets = batch["labels"]

        outputs = self.forward(batch)
        loss = self.loss(outputs, targets)
        self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        accuracy = self.accuracy_fn(outputs, targets)
        self.log("train_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)

        torch.cuda.empty_cache()

        return loss

    def validation_step(self, batch, batch_idx: int):
        """
        Validation step.

        Args:
            batch (dict): Validation batch.
            batch_idx (int): Batch index.

        Returns (torch.Tensor): Loss tensor.
        """
        targets = batch["labels"]

        outputs = self.forward(batch)
        loss = self.loss(outputs, targets)
        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)

        accuracy = self.accuracy_fn(outputs, targets)
        self.log("val_accuracy", accuracy, on_epoch=True, on_step=False, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx: int):
        """
        Test step.

        Args:
            batch (dict): Test batch.
            batch_idx (int): Batch index.

        Returns (torch.Tensor): Loss tensor.
        """
        targets = batch["labels"]

        outputs = self.forward(batch)
        loss = self.loss(outputs, targets)

        accuracy = self.accuracy_fn(outputs, targets)

        return loss, accuracy

    def predict_step(self, batch, batch_idx: int = 0, dataloader_idx: int = 0):
        """
        Prediction step.

        Args:
            batch (dict): Prediction batch.
            batch_idx (int): Batch index.
            dataloader_idx (int): Dataloader index.

        Returns (torch.Tensor): Predictions.
        """
        outputs = self.forward(batch)
        return outputs

    def configure_optimizers(self):
        """
        Configure optimizer and scheduler for PyTorch Lightning.

        Returns: Optimizer (and scheduler configuration, if any) for PyTorch Lightning.
        """
        optimizer = self.optimizer(self.parameters(), **self.optimizer_params)

        if self.scheduler is None:
            return optimizer

        # Only use scheduler if it's not ReduceLROnPlateau or if we can ensure val_loss is available
        # For complex training setups, sometimes val_loss is not available every epoch
        if hasattr(self.scheduler, "__name__") and "ReduceLROnPlateau" in self.scheduler.__name__:
            # For ReduceLROnPlateau, use train_loss as it's always available
            scheduler = self.scheduler(optimizer, **self.scheduler_params)
            scheduler_config = {
                "scheduler": scheduler,
                "monitor": "train_loss",
                "interval": self.scheduler_interval,
            }
            return [optimizer], [scheduler_config]
        else:
            # For other schedulers (StepLR, etc.), no monitoring needed
            scheduler = self.scheduler(optimizer, **self.scheduler_params)
            return [optimizer], [scheduler]
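A rough sketch of how this Lightning module is meant to be driven. The loss, optimizer, scheduler, and hyperparameters below are illustrative assumptions, and model, train_loader, and val_loader are placeholders for a TextClassificationModel and DataLoaders built from the package's dataset classes.

import pytorch_lightning as pl
import torch

from torchTextClassifiers.model.lightning import TextClassificationModule

module = TextClassificationModule(
    model=model,  # a TextClassificationModel assembled elsewhere
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam,
    optimizer_params={"lr": 1e-3},
    scheduler=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 5, "gamma": 0.5},
    scheduler_interval="epoch",
)

trainer = pl.Trainer(max_epochs=10)
# Batches must be dicts with input_ids, attention_mask, labels (and optionally categorical_vars).
trainer.fit(module, train_loader, val_loader)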
torchTextClassifiers/model/model.py
@@ -0,0 +1,151 @@
"""Core PyTorch model for text classification.

This module contains TextClassificationModel, which combines an optional
text embedder, an optional categorical variable network, and a classification head.
"""

import logging
from typing import Annotated, Optional

import torch
from torch import nn

from torchTextClassifiers.model.components import (
    CategoricalForwardType,
    CategoricalVariableNet,
    ClassificationHead,
    TextEmbedder,
)

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler()],
)


# ============================================================================
# PyTorch Model

# It takes PyTorch tensors as input (not raw text!),
# and it outputs raw, not-softmaxed logits, not predictions
# ============================================================================


class TextClassificationModel(nn.Module):
    """PyTorch text classification model."""

    def __init__(
        self,
        classification_head: ClassificationHead,
        text_embedder: Optional[TextEmbedder] = None,
        categorical_variable_net: Optional[CategoricalVariableNet] = None,
    ):
        """
        Constructor for the TextClassificationModel class.

        Args:
            classification_head (ClassificationHead): The classification head module.
            text_embedder (Optional[TextEmbedder]): The text embedding module.
                If not provided, assumes that input text is already embedded (as tensors) and directly passed to the classification head.
            categorical_variable_net (Optional[CategoricalVariableNet]): The categorical variable network module.
                If not provided, assumes no categorical variables are used.
        """
        super().__init__()

        self.text_embedder = text_embedder

        self.categorical_variable_net = categorical_variable_net
        if not self.categorical_variable_net:
            logger.info("🔹 No categorical variable network provided; using only text embeddings.")

        self.classification_head = classification_head

        self._validate_component_connections()

        self.num_classes = self.classification_head.num_classes

        torch.nn.init.zeros_(self.classification_head.net.weight)
        if self.text_embedder is not None:
            self.text_embedder.init_weights()

    def _validate_component_connections(self):
        def _check_text_categorical_connection(self, text_embedder, cat_var_net):
            if cat_var_net.forward_type == CategoricalForwardType.SUM_TO_TEXT:
                if text_embedder.embedding_dim != cat_var_net.output_dim:
                    raise ValueError(
                        "Text embedding dimension must match categorical variable embedding dimension."
                    )
                self.expected_classification_head_input_dim = text_embedder.embedding_dim
            else:
                self.expected_classification_head_input_dim = (
                    text_embedder.embedding_dim + cat_var_net.output_dim
                )

        if self.text_embedder:
            if self.categorical_variable_net:
                _check_text_categorical_connection(
                    self, self.text_embedder, self.categorical_variable_net
                )
            else:
                self.expected_classification_head_input_dim = self.text_embedder.embedding_dim

            if self.expected_classification_head_input_dim != self.classification_head.input_dim:
                raise ValueError(
                    "Classification head input dimension does not match expected dimension from text embedder and categorical variable net."
                )
        else:
            logger.warning(
                "⚠️ No text embedder provided; assuming input text is already embedded or vectorized. Take care that the classification head input dimension matches the input text dimension."
            )

    def forward(
        self,
        input_ids: Annotated[torch.Tensor, "batch seq_len"],
        attention_mask: Annotated[torch.Tensor, "batch seq_len"],
        categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
        **kwargs,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args (output from the dataset collate_fn):
            input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
            attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
            categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)

        Returns:
            torch.Tensor: Model output scores for each class - shape (batch_size, num_classes).
                Raw, not softmaxed.
        """
        encoded_text = input_ids  # clearer name
        if self.text_embedder is None:
            x_text = encoded_text.float()
        else:
            x_text = self.text_embedder(input_ids=encoded_text, attention_mask=attention_mask)

        if self.categorical_variable_net:
            x_cat = self.categorical_variable_net(categorical_vars)

            if (
                self.categorical_variable_net.forward_type
                == CategoricalForwardType.AVERAGE_AND_CONCAT
                or self.categorical_variable_net.forward_type
                == CategoricalForwardType.CONCATENATE_ALL
            ):
                x_combined = torch.cat((x_text, x_cat), dim=1)
            else:
                assert (
                    self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT
                )
                x_combined = x_text + x_cat
        else:
            x_combined = x_text

        logits = self.classification_head(x_combined)

        return logits
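For orientation, a sketch of the forward contract and of the dimension rule enforced by _validate_component_connections. The ClassificationHead and CategoricalVariableNet constructors are not expanded in this diff, so `model` and the sizes below are placeholders rather than working construction code.

import torch

# Dimension rule checked at construction time:
#   SUM_TO_TEXT:                          head input_dim == embedding_dim, and the
#                                         categorical net's output_dim must equal embedding_dim.
#   AVERAGE_AND_CONCAT / CONCATENATE_ALL: head input_dim == embedding_dim + categorical output_dim.

# Forward contract, assuming `model` is an already-assembled TextClassificationModel
# with a text embedder and no categorical variable net:
input_ids = torch.tensor([[5, 12, 7, 0]])  # (batch_size, seq_len), long
attention_mask = (input_ids != 0).long()   # 1 for real tokens, 0 for padding
logits = model(input_ids=input_ids, attention_mask=attention_mask, categorical_vars=None)
probas = torch.softmax(logits, dim=1)      # forward returns raw logits, shape (batch_size, num_classes)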
torchTextClassifiers/tokenizers/WordPiece.py
@@ -0,0 +1,92 @@
import logging
import os
from typing import List, Optional

from torchTextClassifiers.tokenizers import HAS_HF, HuggingFaceTokenizer

if not HAS_HF:
    raise ImportError(
        "The HuggingFace dependencies are needed to use this tokenizer. Please run 'uv add torchTextClassifiers --extra huggingface'."
    )
else:
    from tokenizers import (
        Tokenizer,
        decoders,
        models,
        normalizers,
        pre_tokenizers,
        processors,
        trainers,
    )
    from transformers import PreTrainedTokenizerFast

logger = logging.getLogger(__name__)


class WordPieceTokenizer(HuggingFaceTokenizer):
    def __init__(self, vocab_size: int, trained: bool = False, output_dim: Optional[int] = None):
        """Largely inspired by https://huggingface.co/learn/llm-course/chapter6/8"""

        super().__init__(vocab_size=vocab_size, output_dim=output_dim)

        self.unk_token = "[UNK]"
        self.pad_token = "[PAD]"
        self.cls_token = "[CLS]"
        self.sep_token = "[SEP]"
        self.special_tokens = [
            self.unk_token,
            self.pad_token,
            self.cls_token,
            self.sep_token,
        ]
        self.vocab_size = vocab_size
        self.context_size = output_dim

        self.tokenizer = Tokenizer(models.WordPiece(unk_token=self.unk_token))

        self.tokenizer.normalizer = normalizers.BertNormalizer(
            lowercase=True
        )  # NFD, lowercase, strip accents - BERT style

        self.tokenizer.pre_tokenizer = (
            pre_tokenizers.BertPreTokenizer()
        )  # split on whitespace and punctuation - BERT style
        self.trained = trained

    def _post_training(self):
        if not self.trained:
            raise RuntimeError(
                "Tokenizer must be trained before applying post-training configurations."
            )

        # BertProcessing expects the (SEP, id) pair first, then the (CLS, id) pair
        self.tokenizer.post_processor = processors.BertProcessing(
            (self.sep_token, self.tokenizer.token_to_id(self.sep_token)),
            (self.cls_token, self.tokenizer.token_to_id(self.cls_token)),
        )
        self.tokenizer.decoder = decoders.WordPiece(prefix="##")
        self.padding_idx = self.tokenizer.token_to_id("[PAD]")
        self.tokenizer.enable_padding(pad_id=self.padding_idx, pad_token="[PAD]")

        self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
        self.vocab_size = len(self.tokenizer)

    def train(
        self, training_corpus: List[str], save_path: Optional[str] = None, filesystem=None, s3_save_path=None
    ):
        trainer = trainers.WordPieceTrainer(
            vocab_size=self.vocab_size,
            special_tokens=self.special_tokens,
        )
        self.tokenizer.train_from_iterator(training_corpus, trainer=trainer)
        self.trained = True
        self._post_training()

        if save_path:
            self.tokenizer.save(save_path)
            logger.info(f"💾 Tokenizer saved at {save_path}")
            if filesystem and s3_save_path:
                parent_dir = os.path.dirname(save_path)
                if not filesystem.exists(parent_dir):
                    filesystem.mkdirs(parent_dir)
                filesystem.put(save_path, s3_save_path)
                logger.info(f"💾 Tokenizer uploaded to S3 at {s3_save_path}")
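A small sketch of the intended training flow for the WordPiece tokenizer. The corpus and vocabulary size below are made up; after train(), self.tokenizer is a Hugging Face PreTrainedTokenizerFast, so standard Hugging Face calls apply.

from torchTextClassifiers.tokenizers import WordPieceTokenizer

corpus = [
    "the quick brown fox",
    "jumps over the lazy dog",
]  # illustrative only; a real vocabulary needs a much larger corpus

wp = WordPieceTokenizer(vocab_size=1000, output_dim=32)
wp.train(corpus)  # trains the WordPiece model, then wraps it in a PreTrainedTokenizerFast

encoded = wp.tokenizer("the lazy fox", return_tensors="pt")
print(encoded["input_ids"], encoded["attention_mask"])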
torchTextClassifiers/tokenizers/__init__.py
@@ -0,0 +1,10 @@
from .base import (
    HAS_HF as HAS_HF,
)
from .base import BaseTokenizer as BaseTokenizer
from .base import (
    HuggingFaceTokenizer as HuggingFaceTokenizer,
)
from .base import TokenizerOutput as TokenizerOutput
from .ngram import NGramTokenizer as NGramTokenizer
from .WordPiece import WordPieceTokenizer as WordPieceTokenizer