torchtextclassifiers 1.0.2__tar.gz → 1.0.4__tar.gz
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/PKG-INFO +2 -2
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/README.md +1 -1
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/pyproject.toml +1 -1
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/__init__.py +1 -0
- torchtextclassifiers-1.0.4/torchTextClassifiers/model/components/text_embedder.py +401 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/lightning.py +1 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/model.py +52 -11
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/torchTextClassifiers.py +59 -23
- torchtextclassifiers-1.0.2/torchTextClassifiers/model/components/text_embedder.py +0 -223
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/__init__.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/dataset/__init__.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/dataset/dataset.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/__init__.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/attention.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/categorical_var_net.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/classification_head.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/WordPiece.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/__init__.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/base.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/ngram.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/utilities/__init__.py +0 -0
- {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/utilities/plot_explainability.py +0 -0
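The headline change in 1.0.4 is an optional label-attention head for the text embedder: the new text_embedder.py below, plus the LabelAttentionConfig threaded through model.py and torchTextClassifiers.py. For orientation, here is a standalone, purely illustrative sketch of the idea in plain PyTorch (made-up dimensions, not code from the package): each class owns a query vector that cross-attends over the token embeddings, and a shared (embedding_dim, 1) head scores every class from its own class-specific embedding. This is also why the diff forces the ClassificationHead output dimension to 1 when label attention is enabled.

```python
import torch

# Illustrative sizes only
batch, seq_len, d_model, n_classes = 2, 7, 32, 5

token_embeddings = torch.randn(batch, seq_len, d_model)  # output of some text encoder
label_queries = torch.randn(n_classes, d_model)          # one learnable query per class

# Cross-attention: every label query attends over the token embeddings (keys/values)
scores = torch.einsum("cd,bsd->bcs", label_queries, token_embeddings) / d_model**0.5
weights = scores.softmax(dim=-1)                              # (batch, n_classes, seq_len)
class_embeddings = torch.einsum("bcs,bsd->bcd", weights, token_embeddings)

# A shared (d_model -> 1) head scores each class from its own embedding
head = torch.nn.Linear(d_model, 1)
logits = head(class_embeddings).squeeze(-1)                   # (batch, n_classes)
print(logits.shape)  # torch.Size([2, 5])
```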
{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: torchtextclassifiers
-Version: 1.0.2
+Version: 1.0.4
 Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
 Keywords: fastText,text classification,NLP,automatic coding,deep learning
 Author: Cédric Couralet, Meilame Tayebjee
@@ -49,7 +49,7 @@ A unified, extensible framework for text classification with categorical variabl
 ```bash
 # Clone the repository
 git clone https://github.com/InseeFrLab/torchTextClassifiers.git
-cd
+cd torchTextClassifiers
 
 # Install with uv (recommended)
 uv sync
{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/README.md
@@ -23,7 +23,7 @@ A unified, extensible framework for text classification with categorical variabl
 ```bash
 # Clone the repository
 git clone https://github.com/InseeFrLab/torchTextClassifiers.git
-cd
+cd torchTextClassifiers
 
 # Install with uv (recommended)
 uv sync
{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/__init__.py
@@ -8,5 +8,6 @@ from .categorical_var_net import (
     CategoricalVariableNet as CategoricalVariableNet,
 )
 from .classification_head import ClassificationHead as ClassificationHead
+from .text_embedder import LabelAttentionConfig as LabelAttentionConfig
 from .text_embedder import TextEmbedder as TextEmbedder
 from .text_embedder import TextEmbedderConfig as TextEmbedderConfig
torchtextclassifiers-1.0.4/torchTextClassifiers/model/components/text_embedder.py
ADDED
@@ -0,0 +1,401 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm


@dataclass
class LabelAttentionConfig:
    n_head: int
    num_classes: int


@dataclass
class TextEmbedderConfig:
    vocab_size: int
    embedding_dim: int
    padding_idx: int
    attention_config: Optional[AttentionConfig] = None
    label_attention_config: Optional[LabelAttentionConfig] = None


class TextEmbedder(nn.Module):
    def __init__(self, text_embedder_config: TextEmbedderConfig):
        super().__init__()

        self.config = text_embedder_config

        self.attention_config = text_embedder_config.attention_config
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)

        # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig
        self.label_attention_config = text_embedder_config.label_attention_config
        if isinstance(self.label_attention_config, dict):
            self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
        # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier)
        # always see a LabelAttentionConfig instance rather than a raw dict.
        self.config.label_attention_config = self.label_attention_config

        self.enable_label_attention = self.label_attention_config is not None
        if self.enable_label_attention:
            self.label_attention_module = LabelAttentionClassifier(self.config)

        self.vocab_size = text_embedder_config.vocab_size
        self.embedding_dim = text_embedder_config.embedding_dim
        self.padding_idx = text_embedder_config.padding_idx

        self.embedding_layer = nn.Embedding(
            embedding_dim=self.embedding_dim,
            num_embeddings=self.vocab_size,
            padding_idx=self.padding_idx,
        )

        if self.attention_config is not None:
            self.attention_config.n_embd = text_embedder_config.embedding_dim
            self.transformer = nn.ModuleDict(
                {
                    "h": nn.ModuleList(
                        [
                            Block(self.attention_config, layer_idx)
                            for layer_idx in range(self.attention_config.n_layers)
                        ]
                    ),
                }
            )

            head_dim = self.attention_config.n_embd // self.attention_config.n_head

            if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
                raise ValueError("embedding_dim must be divisible by n_head.")

            if self.attention_config.positional_encoding:
                if head_dim % 2 != 0:
                    raise ValueError(
                        "embedding_dim / n_head must be even for rotary positional embeddings."
                    )

                if self.attention_config.sequence_len is None:
                    raise ValueError(
                        "sequence_len must be specified in AttentionConfig when positional_encoding is True."
                    )

                self.rotary_seq_len = self.attention_config.sequence_len * 10
                cos, sin = self._precompute_rotary_embeddings(
                    seq_len=self.rotary_seq_len, head_dim=head_dim
                )

                self.register_buffer(
                    "cos", cos, persistent=False
                )  # persistent=False means it's not saved to the checkpoint
                self.register_buffer("sin", sin, persistent=False)

    def init_weights(self):
        self.apply(self._init_weights)

        # zero out c_proj weights in all blocks
        if self.attention_config is not None:
            for block in self.transformer.h:
                torch.nn.init.zeros_(block.mlp.c_proj.weight)
                torch.nn.init.zeros_(block.attn.c_proj.weight)
            # init the rotary embeddings
            head_dim = self.attention_config.n_embd // self.attention_config.n_head
            cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
            self.cos, self.sin = cos, sin
        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        if self.embedding_layer.weight.device.type == "cuda":
            self.embedding_layer.to(dtype=torch.bfloat16)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # https://arxiv.org/pdf/2310.17813
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        return_label_attention_matrix: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Converts input token IDs to their corresponding embeddings.

        Args:
            input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized
            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
            return_label_attention_matrix (bool): Whether to return the label attention matrix.

        Returns:
            dict: A dictionary with the following keys:

            - "sentence_embedding" (torch.Tensor): Text embeddings of shape
              (batch_size, embedding_dim) if ``self.enable_label_attention`` is False,
              else (batch_size, num_classes, embedding_dim), where ``num_classes``
              is the number of label classes.

            - "label_attention_matrix" (Optional[torch.Tensor]): Label attention
              matrix of shape (batch_size, n_head, num_classes, seq_len) if
              ``return_label_attention_matrix`` is True and label attention is
              enabled, otherwise ``None``. The dimensions correspond to
              (batch_size, attention heads, label classes, sequence length).
        """

        encoded_text = input_ids  # clearer name
        if encoded_text.dtype != torch.long:
            encoded_text = encoded_text.to(torch.long)

        batch_size, seq_len = encoded_text.shape
        batch_size_check, seq_len_check = attention_mask.shape

        if batch_size != batch_size_check or seq_len != seq_len_check:
            raise ValueError(
                f"Input IDs and attention mask must have the same batch size and sequence length. "
                f"Got input_ids shape {encoded_text.shape} and attention_mask shape {attention_mask.shape}."
            )

        token_embeddings = self.embedding_layer(
            encoded_text
        )  # (batch_size, seq_len, embedding_dim)

        token_embeddings = norm(token_embeddings)

        if self.attention_config is not None:
            if self.attention_config.positional_encoding:
                cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
            else:
                cos_sin = None

            for block in self.transformer.h:
                token_embeddings = block(token_embeddings, cos_sin)

            token_embeddings = norm(token_embeddings)

        out = self._get_sentence_embedding(
            token_embeddings=token_embeddings,
            attention_mask=attention_mask,
            return_label_attention_matrix=return_label_attention_matrix,
        )

        text_embedding = out["sentence_embedding"]
        label_attention_matrix = out["label_attention_matrix"]
        return {
            "sentence_embedding": text_embedding,
            "label_attention_matrix": label_attention_matrix,
        }

    def _get_sentence_embedding(
        self,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
        return_label_attention_matrix: bool = False,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """
        Compute sentence embedding from embedded tokens - "remove" second dimension.

        Args (output from dataset collate_fn):
            token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
            return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
        Returns:
            Dict[str, Optional[torch.Tensor]]: A dictionary containing:
                - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
                - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
        """

        # average over non-pad token embeddings
        # attention mask has 1 for non-pad tokens and 0 for pad token positions

        # mask pad-tokens

        if self.attention_config is not None:
            if self.attention_config.aggregation_method is not None:  # default is "mean"
                if self.attention_config.aggregation_method == "first":
                    return {
                        "sentence_embedding": token_embeddings[:, 0, :],
                        "label_attention_matrix": None,
                    }
                elif self.attention_config.aggregation_method == "last":
                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
                    return {
                        "sentence_embedding": token_embeddings[
                            torch.arange(token_embeddings.size(0)),
                            lengths - 1,
                            :,
                        ],
                        "label_attention_matrix": None,
                    }
                else:
                    if self.attention_config.aggregation_method != "mean":
                        raise ValueError(
                            f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
                        )

        assert self.attention_config is None or self.attention_config.aggregation_method == "mean"

        if self.enable_label_attention:
            label_attention_result = self.label_attention_module(
                token_embeddings,
                attention_mask=attention_mask,
                compute_attention_matrix=return_label_attention_matrix,
            )
            sentence_embedding = label_attention_result[
                "sentence_embedding"
            ]  # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
            label_attention_matrix = label_attention_result["attention_matrix"]

        else:  # sentence embedding = mean of (non-pad) token embeddings
            mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
            masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)
            sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
                min=1.0
            )  # avoid division by zero

            sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
            label_attention_matrix = None

        return {
            "sentence_embedding": sentence_embedding,
            "label_attention_matrix": label_attention_matrix,
        }

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # autodetect the device from model embeddings
        if device is None:
            device = next(self.parameters()).device

        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        # stride the time steps
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
        cos, sin = (
            cos[None, :, None, :],
            sin[None, :, None, :],
        )  # add batch and head dims for later broadcasting

        return cos, sin


class LabelAttentionClassifier(nn.Module):
    """
    A head for aggregating token embeddings into label-specific sentence embeddings using cross-attention mechanism.
    Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings.

    """

    def __init__(self, config: TextEmbedderConfig):
        super().__init__()

        label_attention_config = config.label_attention_config
        self.embedding_dim = config.embedding_dim
        self.num_classes = label_attention_config.num_classes
        self.n_head = label_attention_config.n_head

        # Validate head configuration
        self.head_dim = self.embedding_dim // self.n_head

        if self.head_dim * self.n_head != self.embedding_dim:
            raise ValueError(
                f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). "
                f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}"
            )

        self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)

        self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)

    def forward(
        self,
        token_embeddings,
        attention_mask: Optional[torch.Tensor] = None,
        compute_attention_matrix: Optional[bool] = False,
    ):
        """
        Args:
            token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
            attention_mask (torch.Tensor, optional), shape (batch, seq_len): Attention mask indicating non-pad tokens (1 for real tokens, 0 for padding).
            compute_attention_matrix (bool): Whether to compute and return the attention matrix.
        Returns:
            dict: {
                "sentence_embedding": torch.Tensor, shape (batch, num_classes, d_model): Label-specific sentence embeddings.
                "attention_matrix": Optional[torch.Tensor], shape (batch, n_head, num_classes, seq_len): Attention weights if compute_attention_matrix is True, else None.
            }

        """
        B, T, C = token_embeddings.size()
        if isinstance(compute_attention_matrix, torch.Tensor):
            compute_attention_matrix = compute_attention_matrix[0].item()
        compute_attention_matrix = bool(compute_attention_matrix)

        # 1. Create label indices [0, 1, ..., C-1] for the whole batch
        label_indices = torch.arange(
            self.num_classes, dtype=torch.long, device=token_embeddings.device
        ).expand(B, -1)

        all_label_embeddings = self.label_embeds(
            label_indices
        )  # Shape: [batch, num_classes, d_model]
        all_label_embeddings = norm(all_label_embeddings)

        q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim)
        k = self.c_k(token_embeddings).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(token_embeddings).view(B, T, self.n_head, self.head_dim)

        q, k = norm(q), norm(k)  # QK norm
        q, k, v = (
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )  # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)

        # Prepare attention mask for scaled_dot_product_attention
        # attention_mask: (B, T) with 1 for real tokens, 0 for padding
        # scaled_dot_product_attention expects attn_mask: (B, H, Q, K) or broadcastable shape
        # where True means "mask out" (ignore), False means "attend to"
        attn_mask = None
        if attention_mask is not None:
            # Convert: 0 (padding) -> True (mask out), 1 (real) -> False (attend to)
            attn_mask = attention_mask == 0  # (B, T)
            # Expand to (B, 1, 1, T) for broadcasting across heads and queries
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)  # (B, 1, 1, T)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

        # Re-assemble the heads side by side and project back to residual stream
        y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1)  # (bs, n_labels, d_model)
        y = self.c_proj(y)

        attention_matrix = None
        if compute_attention_matrix:
            # Compute attention scores
            # size (B, n_head, n_labels, seq_len)
            attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)

            # Apply mask to attention scores before softmax
            if attention_mask is not None:
                # attn_mask is already in the right shape: (B, 1, 1, T)
                # We need to apply it to scores of shape (B, n_head, n_labels, T)
                # Set masked positions to -inf so they become 0 after softmax
                attention_scores = attention_scores.masked_fill(attn_mask, float("-inf"))

            attention_matrix = torch.softmax(attention_scores, dim=-1)

        return {"sentence_embedding": y, "attention_matrix": attention_matrix}
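A hedged usage sketch of the new module, based only on what the listing above shows (the constructor fields, the forward signature, and the returned keys); the concrete numbers are invented, and attention_config=None is chosen so that no transformer blocks are involved:

```python
import torch

from torchTextClassifiers.model.components import (
    LabelAttentionConfig,
    TextEmbedder,
    TextEmbedderConfig,
)

config = TextEmbedderConfig(
    vocab_size=100,         # illustrative values, not package defaults
    embedding_dim=32,
    padding_idx=0,
    attention_config=None,  # skip the transformer blocks entirely
    label_attention_config=LabelAttentionConfig(n_head=4, num_classes=5),
)
embedder = TextEmbedder(config)

input_ids = torch.tensor([[5, 12, 7, 3, 0, 0, 0], [9, 4, 22, 41, 17, 2, 0]])  # (2, 7), 0 = padding
attention_mask = (input_ids != 0).long()

out = embedder(input_ids, attention_mask, return_label_attention_matrix=True)
print(out["sentence_embedding"].shape)      # (2, 5, 32): one embedding per class
print(out["label_attention_matrix"].shape)  # (2, 4, 5, 7): (batch, heads, classes, seq_len)
```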
{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/lightning.py
RENAMED
@@ -102,6 +102,7 @@ class TextClassificationModule(pl.LightningModule):
         targets = batch["labels"]
 
         outputs = self.forward(batch)
+
         loss = self.loss(outputs, targets)
         self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)
 
{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/model.py
RENAMED
@@ -1,12 +1,12 @@
-"""
+"""TextClassification model components.
 
 This module contains the PyTorch model, Lightning module, and dataset classes
-for
+for text classification. Consolidates what was previously in pytorch_model.py,
 lightning_module.py, and dataset.py.
 """
 
 import logging
-from typing import Annotated, Optional
+from typing import Annotated, Optional, Union
 
 import torch
 from torch import nn
@@ -17,6 +17,7 @@ from torchTextClassifiers.model.components import (
     ClassificationHead,
     TextEmbedder,
 )
+from torchTextClassifiers.model.components.attention import norm
 
 logger = logging.getLogger(__name__)
 
@@ -67,8 +68,6 @@ class TextClassificationModel(nn.Module):
 
         self._validate_component_connections()
 
-        self.num_classes = self.classification_head.num_classes
-
         torch.nn.init.zeros_(self.classification_head.net.weight)
         if self.text_embedder is not None:
             self.text_embedder.init_weights()
@@ -98,6 +97,17 @@ class TextClassificationModel(nn.Module):
                 raise ValueError(
                     "Classification head input dimension does not match expected dimension from text embedder and categorical variable net."
                 )
+            if self.text_embedder.enable_label_attention:
+                self.enable_label_attention = True
+                if self.classification_head.num_classes != 1:
+                    raise ValueError(
+                        "Label attention is enabled. TextEmbedder outputs a (num_classes, embedding_dim) tensor, so the ClassificationHead should have an output dimension of 1."
+                    )
+                # if enable_label_attention is True, label_attention_config exists - and contains num_classes necessarily
+                self.num_classes = self.text_embedder.config.label_attention_config.num_classes
+            else:
+                self.enable_label_attention = False
+                self.num_classes = self.classification_head.num_classes
         else:
             logger.warning(
                 "⚠️ No text embedder provided; assuming input text is already embedded or vectorized. Take care that the classification head input dimension matches the input text dimension."
@@ -108,8 +118,9 @@ class TextClassificationModel(nn.Module):
         input_ids: Annotated[torch.Tensor, "batch seq_len"],
         attention_mask: Annotated[torch.Tensor, "batch seq_len"],
         categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
+        return_label_attention_matrix: bool = False,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
         """
         Memory-efficient forward pass implementation.
 
@@ -117,35 +128,65 @@ class TextClassificationModel(nn.Module):
             input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
             attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
             categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
+            return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix
 
         Returns:
-            torch.Tensor
-
+            Union[torch.Tensor, dict[str, torch.Tensor]]:
+            - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
+              containing raw logits (not softmaxed)
+            - If return_label_attention_matrix is True: dict with keys:
+                - "logits": torch.Tensor of shape (batch_size, num_classes)
+                - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len)
         """
         encoded_text = input_ids  # clearer name
+        label_attention_matrix = None
         if self.text_embedder is None:
             x_text = encoded_text.float()
+            if return_label_attention_matrix:
+                raise ValueError(
+                    "return_label_attention_matrix=True requires a text_embedder with label attention enabled"
+                )
         else:
-
+            text_embed_output = self.text_embedder(
+                input_ids=encoded_text,
+                attention_mask=attention_mask,
+                return_label_attention_matrix=return_label_attention_matrix,
+            )
+            x_text = text_embed_output["sentence_embedding"]
+            if isinstance(return_label_attention_matrix, torch.Tensor):
+                return_label_attention_matrix = return_label_attention_matrix[0].item()
+            if return_label_attention_matrix:
+                label_attention_matrix = text_embed_output["label_attention_matrix"]
 
         if self.categorical_variable_net:
             x_cat = self.categorical_variable_net(categorical_vars)
 
+            if self.enable_label_attention:
+                # x_text is (batch_size, num_classes, embedding_dim)
+                # x_cat is (batch_size, cat_embedding_dim)
+                # We need to expand x_cat to (batch_size, num_classes, cat_embedding_dim)
+                # x_cat will be appended to x_text along the last dimension for each class
+                x_cat = x_cat.unsqueeze(1).expand(-1, self.num_classes, -1)
+
            if (
                 self.categorical_variable_net.forward_type
                 == CategoricalForwardType.AVERAGE_AND_CONCAT
                 or self.categorical_variable_net.forward_type
                 == CategoricalForwardType.CONCATENATE_ALL
             ):
-                x_combined = torch.cat((x_text, x_cat), dim
+                x_combined = torch.cat((x_text, x_cat), dim=-1)
             else:
                 assert (
                     self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT
                 )
+
                 x_combined = x_text + x_cat
         else:
             x_combined = x_text
 
-        logits = self.classification_head(x_combined)
+        logits = self.classification_head(norm(x_combined)).squeeze(-1)
+
+        if return_label_attention_matrix:
+            return {"logits": logits, "label_attention_matrix": label_attention_matrix}
 
         return logits
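The shape handling that forward gains for the label-attention case (expanding the categorical embedding per class, concatenating along the last dimension, then a width-1 head followed by squeeze(-1)) can be seen in isolation in this standalone sketch with illustrative dimensions; it is not the package's own code:

```python
import torch

batch, num_classes, d_embed, d_cat = 2, 5, 32, 8
x_text = torch.randn(batch, num_classes, d_embed)  # label-attention output: one embedding per class
x_cat = torch.randn(batch, d_cat)                   # one categorical embedding per sample

# Same broadcasting as the new code path: repeat the categorical embedding for every
# class, then concatenate it onto each class-specific text embedding.
x_cat_expanded = x_cat.unsqueeze(1).expand(-1, num_classes, -1)  # (batch, num_classes, d_cat)
x_combined = torch.cat((x_text, x_cat_expanded), dim=-1)         # (batch, num_classes, d_embed + d_cat)

# A (d_embed + d_cat) -> 1 classification head then yields one logit per class,
# and squeeze(-1) collapses the trailing dimension to (batch, num_classes).
head = torch.nn.Linear(d_embed + d_cat, 1)
logits = head(x_combined).squeeze(-1)
print(logits.shape)  # torch.Size([2, 5])
```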
{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/torchTextClassifiers.py
RENAMED
@@ -29,6 +29,7 @@ from torchTextClassifiers.model.components import (
     CategoricalForwardType,
     CategoricalVariableNet,
     ClassificationHead,
+    LabelAttentionConfig,
     TextEmbedder,
     TextEmbedderConfig,
 )
@@ -53,6 +54,7 @@ class ModelConfig:
     categorical_embedding_dims: Optional[Union[List[int], int]] = None
     num_classes: Optional[int] = None
     attention_config: Optional[AttentionConfig] = None
+    label_attention_config: Optional[LabelAttentionConfig] = None
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
@@ -140,6 +142,7 @@ class torchTextClassifiers:
         self.embedding_dim = model_config.embedding_dim
         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
         self.num_classes = model_config.num_classes
+        self.enable_label_attention = model_config.label_attention_config is not None
 
         if self.tokenizer.output_vectorized:
             self.text_embedder = None
@@ -153,6 +156,7 @@ class torchTextClassifiers:
                 embedding_dim=self.embedding_dim,
                 padding_idx=tokenizer.padding_idx,
                 attention_config=model_config.attention_config,
+                label_attention_config=model_config.label_attention_config,
             )
             self.text_embedder = TextEmbedder(
                 text_embedder_config=text_embedder_config,
@@ -174,7 +178,9 @@ class torchTextClassifiers:
 
         self.classification_head = ClassificationHead(
             input_dim=classif_head_input_dim,
-            num_classes=
+            num_classes=1
+            if self.enable_label_attention
+            else model_config.num_classes,  # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim)
         )
 
         self.pytorch_model = TextClassificationModel(
@@ -486,13 +492,15 @@ class torchTextClassifiers:
         self,
         X_test: np.ndarray,
         top_k=1,
-
+        explain_with_label_attention: bool = False,
+        explain_with_captum=False,
     ):
         """
         Args:
             X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
             top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
-
+            explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False)
+            explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False)
 
         Returns: A dictionary containing the following fields:
             - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
@@ -501,6 +509,7 @@ class torchTextClassifiers:
             - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
         """
 
+        explain = explain_with_label_attention or explain_with_captum
         if explain:
             return_offsets_mapping = True  # to be passed to the tokenizer
             return_word_ids = True
@@ -509,13 +518,19 @@ class torchTextClassifiers:
                     "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
                 )
             else:
-                if
-
-
-
-
-
-
+                if explain_with_captum:
+                    if not HAS_CAPTUM:
+                        raise ImportError(
+                            "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
+                        )
+                    lig = LayerIntegratedGradients(
+                        self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
+                    )  # initialize a Captum layer gradient integrator
+                if explain_with_label_attention:
+                    if not self.enable_label_attention:
+                        raise RuntimeError(
+                            "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain."
+                        )
         else:
             return_offsets_mapping = False
             return_word_ids = False
@@ -547,9 +562,19 @@ class torchTextClassifiers:
         else:
             categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32)
 
-
-            encoded_text,
+        model_output = self.pytorch_model(
+            encoded_text,
+            attention_mask,
+            categorical_vars,
+            return_label_attention_matrix=explain_with_label_attention,
         )  # forward pass, contains the prediction scores (len(text), num_classes)
+        pred = (
+            model_output["logits"] if explain_with_label_attention else model_output
+        )  # (batch_size, num_classes)
+
+        label_attention_matrix = (
+            model_output["label_attention_matrix"] if explain_with_label_attention else None
+        )
 
         label_scores = pred.detach().cpu().softmax(dim=1)  # convert to probabilities
 
@@ -559,21 +584,28 @@ class torchTextClassifiers:
         confidence = torch.round(label_scores_topk.values, decimals=2)  # and their scores
 
         if explain:
-
-
-
-
-
-
-
-
-
-
+            if explain_with_captum:
+                # Captum explanations
+                captum_attributions = []
+                for k in range(top_k):
+                    attributions = lig.attribute(
+                        (encoded_text, attention_mask, categorical_vars),
+                        target=torch.Tensor(predictions[:, k]).long(),
+                    )  # (batch_size, seq_len)
+                    attributions = attributions.sum(dim=-1)
+                    captum_attributions.append(attributions.detach().cpu())
+
+                captum_attributions = torch.stack(
+                    captum_attributions, dim=1
+                )  # (batch_size, top_k, seq_len)
+            else:
+                captum_attributions = None
 
         return {
             "prediction": predictions,
             "confidence": confidence,
-            "
+            "captum_attributions": captum_attributions,
+            "label_attention_attributions": label_attention_matrix,
             "offset_mapping": tokenize_output.offset_mapping,
             "word_ids": tokenize_output.word_ids,
         }
@@ -665,6 +697,10 @@ class torchTextClassifiers:
 
         # Reconstruct model_config
         model_config = ModelConfig.from_dict(metadata["model_config"])
+        if isinstance(model_config.label_attention_config, dict):
+            model_config.label_attention_config = LabelAttentionConfig(
+                **model_config.label_attention_config
+            )
 
         # Create instance
         instance = cls(
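With explain_with_label_attention=True, predict now returns the label-attention matrix under label_attention_attributions. One plausible way to turn it into per-token importances for the top prediction, shown here as a standalone sketch with random tensors standing in for real model output (not an API provided by the package), is to take the attention row of the predicted class and average over heads:

```python
import torch

batch, n_head, num_classes, seq_len = 2, 4, 5, 7

# Stand-ins for what predict() would return
label_attention_matrix = torch.rand(batch, n_head, num_classes, seq_len).softmax(dim=-1)
logits = torch.randn(batch, num_classes)

top1 = logits.argmax(dim=-1)  # predicted class per sample
# Pick each sample's attention row for its predicted class: (batch, n_head, seq_len)
per_head = label_attention_matrix[torch.arange(batch), :, top1, :]
token_importance = per_head.mean(dim=1)  # average over heads -> (batch, seq_len)
print(token_importance.shape)  # torch.Size([2, 7])
```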
torchtextclassifiers-1.0.2/torchTextClassifiers/model/components/text_embedder.py
REMOVED
@@ -1,223 +0,0 @@
import math
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm


@dataclass
class TextEmbedderConfig:
    vocab_size: int
    embedding_dim: int
    padding_idx: int
    attention_config: Optional[AttentionConfig] = None


class TextEmbedder(nn.Module):
    def __init__(self, text_embedder_config: TextEmbedderConfig):
        super().__init__()

        self.config = text_embedder_config

        self.attention_config = text_embedder_config.attention_config
        if isinstance(self.attention_config, dict):
            self.attention_config = AttentionConfig(**self.attention_config)

        if self.attention_config is not None:
            self.attention_config.n_embd = text_embedder_config.embedding_dim

        self.vocab_size = text_embedder_config.vocab_size
        self.embedding_dim = text_embedder_config.embedding_dim
        self.padding_idx = text_embedder_config.padding_idx

        self.embedding_layer = nn.Embedding(
            embedding_dim=self.embedding_dim,
            num_embeddings=self.vocab_size,
            padding_idx=self.padding_idx,
        )

        if self.attention_config is not None:
            self.transformer = nn.ModuleDict(
                {
                    "h": nn.ModuleList(
                        [
                            Block(self.attention_config, layer_idx)
                            for layer_idx in range(self.attention_config.n_layers)
                        ]
                    ),
                }
            )

            head_dim = self.attention_config.n_embd // self.attention_config.n_head

            if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
                raise ValueError("embedding_dim must be divisible by n_head.")

            if self.attention_config.positional_encoding:
                if head_dim % 2 != 0:
                    raise ValueError(
                        "embedding_dim / n_head must be even for rotary positional embeddings."
                    )

                if self.attention_config.sequence_len is None:
                    raise ValueError(
                        "sequence_len must be specified in AttentionConfig when positional_encoding is True."
                    )

                self.rotary_seq_len = self.attention_config.sequence_len * 10
                cos, sin = self._precompute_rotary_embeddings(
                    seq_len=self.rotary_seq_len, head_dim=head_dim
                )

                self.register_buffer(
                    "cos", cos, persistent=False
                )  # persistent=False means it's not saved to the checkpoint
                self.register_buffer("sin", sin, persistent=False)

    def init_weights(self):
        self.apply(self._init_weights)

        # zero out c_proj weights in all blocks
        if self.attention_config is not None:
            for block in self.transformer.h:
                torch.nn.init.zeros_(block.mlp.c_proj.weight)
                torch.nn.init.zeros_(block.attn.c_proj.weight)
            # init the rotary embeddings
            head_dim = self.attention_config.n_embd // self.attention_config.n_head
            cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
            self.cos, self.sin = cos, sin
        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        if self.embedding_layer.weight.device.type == "cuda":
            self.embedding_layer.to(dtype=torch.bfloat16)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # https://arxiv.org/pdf/2310.17813
            fan_out = module.weight.size(0)
            fan_in = module.weight.size(1)
            std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Converts input token IDs to their corresponding embeddings."""

        encoded_text = input_ids  # clearer name
        if encoded_text.dtype != torch.long:
            encoded_text = encoded_text.to(torch.long)

        batch_size, seq_len = encoded_text.shape
        batch_size_check, seq_len_check = attention_mask.shape

        if batch_size != batch_size_check or seq_len != seq_len_check:
            raise ValueError(
                f"Input IDs and attention mask must have the same batch size and sequence length. "
                f"Got input_ids shape {encoded_text.shape} and attention_mask shape {attention_mask.shape}."
            )

        token_embeddings = self.embedding_layer(
            encoded_text
        )  # (batch_size, seq_len, embedding_dim)

        token_embeddings = norm(token_embeddings)

        if self.attention_config is not None:
            if self.attention_config.positional_encoding:
                cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
            else:
                cos_sin = None

            for block in self.transformer.h:
                token_embeddings = block(token_embeddings, cos_sin)

            token_embeddings = norm(token_embeddings)

        text_embedding = self._get_sentence_embedding(
            token_embeddings=token_embeddings, attention_mask=attention_mask
        )

        return text_embedding

    def _get_sentence_embedding(
        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute sentence embedding from embedded tokens - "remove" second dimension.

        Args (output from dataset collate_fn):
            token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
            attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
        Returns:
            torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
        """

        # average over non-pad token embeddings
        # attention mask has 1 for non-pad tokens and 0 for pad token positions

        # mask pad-tokens

        if self.attention_config is not None:
            if self.attention_config.aggregation_method is not None:
                if self.attention_config.aggregation_method == "first":
                    return token_embeddings[:, 0, :]
                elif self.attention_config.aggregation_method == "last":
                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
                    return token_embeddings[
                        torch.arange(token_embeddings.size(0)),
                        lengths - 1,
                        :,
                    ]
                else:
                    if self.attention_config.aggregation_method != "mean":
                        raise ValueError(
                            f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
                        )

        assert self.attention_config is None or self.attention_config.aggregation_method == "mean"

        mask = attention_mask.unsqueeze(-1).float()  # (batch_size, seq_len, 1)
        masked_embeddings = token_embeddings * mask  # (batch_size, seq_len, embedding_dim)

        sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
            min=1.0
        )  # avoid division by zero

        sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)

        return sentence_embedding

    def __call__(self, *args, **kwargs):
        out = super().__call__(*args, **kwargs)
        if out.dim() != 2:
            raise ValueError(
                f"Output of {self.__class__.__name__}.forward must be 2D "
                f"(got shape {tuple(out.shape)})"
            )
        return out

    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # autodetect the device from model embeddings
        if device is None:
            device = next(self.parameters()).device

        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        # stride the time steps
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()  # keep them in bfloat16
        cos, sin = (
            cos[None, :, None, :],
            sin[None, :, None, :],
        )  # add batch and head dims for later broadcasting

        return cos, sin
All other files listed above with +0 -0 are carried over from 1.0.2 to 1.0.4 without content changes (renamed with the package directory only).